From 47b7ebad12a17218f6ca0301fc802c0e0a81d873 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 28 Jul 2012 20:03:26 -0700 Subject: Added the Spark Streaing code, ported to Akka 2 --- core/src/main/scala/spark/BlockRDD.scala | 42 ++ core/src/main/scala/spark/SparkContext.scala | 5 + project/SparkBuild.scala | 6 +- run | 2 + sentences.txt | 3 + startTrigger.sh | 3 + .../src/main/scala/spark/stream/BlockID.scala | 20 + .../scala/spark/stream/ConnectionHandler.scala | 157 ++++++ .../spark/stream/DumbTopKWordCount2_Special.scala | 138 +++++ .../spark/stream/DumbWordCount2_Special.scala | 92 ++++ .../scala/spark/stream/FileStreamReceiver.scala | 70 +++ .../src/main/scala/spark/stream/GrepCount.scala | 39 ++ .../src/main/scala/spark/stream/GrepCount2.scala | 113 ++++ .../main/scala/spark/stream/GrepCountApprox.scala | 54 ++ .../main/scala/spark/stream/IdealPerformance.scala | 36 ++ .../src/main/scala/spark/stream/Interval.scala | 75 +++ streaming/src/main/scala/spark/stream/Job.scala | 21 + .../src/main/scala/spark/stream/JobManager.scala | 112 ++++ .../src/main/scala/spark/stream/JobManager2.scala | 37 ++ .../scala/spark/stream/NetworkStreamReceiver.scala | 184 +++++++ streaming/src/main/scala/spark/stream/RDS.scala | 607 +++++++++++++++++++++ .../scala/spark/stream/ReducedWindowedRDS.scala | 218 ++++++++ .../src/main/scala/spark/stream/Scheduler.scala | 181 ++++++ .../stream/SenGeneratorForPerformanceTest.scala | 78 +++ .../scala/spark/stream/SenderReceiverTest.scala | 63 +++ .../scala/spark/stream/SentenceFileGenerator.scala | 92 ++++ .../scala/spark/stream/SentenceGenerator.scala | 103 ++++ .../src/main/scala/spark/stream/ShuffleTest.scala | 22 + .../main/scala/spark/stream/SimpleWordCount.scala | 30 + .../main/scala/spark/stream/SimpleWordCount2.scala | 51 ++ .../spark/stream/SimpleWordCount2_Special.scala | 83 +++ .../scala/spark/stream/SparkStreamContext.scala | 105 ++++ .../main/scala/spark/stream/TestGenerator.scala | 107 ++++ .../main/scala/spark/stream/TestGenerator2.scala | 119 ++++ .../main/scala/spark/stream/TestGenerator4.scala | 244 +++++++++ .../scala/spark/stream/TestInputBlockTracker.scala | 42 ++ .../scala/spark/stream/TestStreamCoordinator.scala | 38 ++ .../scala/spark/stream/TestStreamReceiver3.scala | 420 ++++++++++++++ .../scala/spark/stream/TestStreamReceiver4.scala | 373 +++++++++++++ streaming/src/main/scala/spark/stream/Time.scala | 85 +++ .../main/scala/spark/stream/TopContentCount.scala | 97 ++++ .../main/scala/spark/stream/TopKWordCount2.scala | 103 ++++ .../spark/stream/TopKWordCount2_Special.scala | 142 +++++ .../src/main/scala/spark/stream/WindowedRDS.scala | 68 +++ .../src/main/scala/spark/stream/WordCount.scala | 62 +++ .../src/main/scala/spark/stream/WordCount1.scala | 46 ++ .../src/main/scala/spark/stream/WordCount2.scala | 55 ++ .../scala/spark/stream/WordCount2_Special.scala | 94 ++++ .../src/main/scala/spark/stream/WordCount3.scala | 49 ++ .../src/main/scala/spark/stream/WordCountEc2.scala | 41 ++ .../spark/stream/WordCountTrivialWindow.scala | 51 ++ .../src/main/scala/spark/stream/WordMax.scala | 64 +++ 52 files changed, 5141 insertions(+), 1 deletion(-) create mode 100644 core/src/main/scala/spark/BlockRDD.scala create mode 100644 sentences.txt create mode 100755 startTrigger.sh create mode 100644 streaming/src/main/scala/spark/stream/BlockID.scala create mode 100644 streaming/src/main/scala/spark/stream/ConnectionHandler.scala create mode 100644 streaming/src/main/scala/spark/stream/DumbTopKWordCount2_Special.scala create mode 100644 streaming/src/main/scala/spark/stream/DumbWordCount2_Special.scala create mode 100644 streaming/src/main/scala/spark/stream/FileStreamReceiver.scala create mode 100644 streaming/src/main/scala/spark/stream/GrepCount.scala create mode 100644 streaming/src/main/scala/spark/stream/GrepCount2.scala create mode 100644 streaming/src/main/scala/spark/stream/GrepCountApprox.scala create mode 100644 streaming/src/main/scala/spark/stream/IdealPerformance.scala create mode 100644 streaming/src/main/scala/spark/stream/Interval.scala create mode 100644 streaming/src/main/scala/spark/stream/Job.scala create mode 100644 streaming/src/main/scala/spark/stream/JobManager.scala create mode 100644 streaming/src/main/scala/spark/stream/JobManager2.scala create mode 100644 streaming/src/main/scala/spark/stream/NetworkStreamReceiver.scala create mode 100644 streaming/src/main/scala/spark/stream/RDS.scala create mode 100644 streaming/src/main/scala/spark/stream/ReducedWindowedRDS.scala create mode 100644 streaming/src/main/scala/spark/stream/Scheduler.scala create mode 100644 streaming/src/main/scala/spark/stream/SenGeneratorForPerformanceTest.scala create mode 100644 streaming/src/main/scala/spark/stream/SenderReceiverTest.scala create mode 100644 streaming/src/main/scala/spark/stream/SentenceFileGenerator.scala create mode 100644 streaming/src/main/scala/spark/stream/SentenceGenerator.scala create mode 100644 streaming/src/main/scala/spark/stream/ShuffleTest.scala create mode 100644 streaming/src/main/scala/spark/stream/SimpleWordCount.scala create mode 100644 streaming/src/main/scala/spark/stream/SimpleWordCount2.scala create mode 100644 streaming/src/main/scala/spark/stream/SimpleWordCount2_Special.scala create mode 100644 streaming/src/main/scala/spark/stream/SparkStreamContext.scala create mode 100644 streaming/src/main/scala/spark/stream/TestGenerator.scala create mode 100644 streaming/src/main/scala/spark/stream/TestGenerator2.scala create mode 100644 streaming/src/main/scala/spark/stream/TestGenerator4.scala create mode 100644 streaming/src/main/scala/spark/stream/TestInputBlockTracker.scala create mode 100644 streaming/src/main/scala/spark/stream/TestStreamCoordinator.scala create mode 100644 streaming/src/main/scala/spark/stream/TestStreamReceiver3.scala create mode 100644 streaming/src/main/scala/spark/stream/TestStreamReceiver4.scala create mode 100644 streaming/src/main/scala/spark/stream/Time.scala create mode 100644 streaming/src/main/scala/spark/stream/TopContentCount.scala create mode 100644 streaming/src/main/scala/spark/stream/TopKWordCount2.scala create mode 100644 streaming/src/main/scala/spark/stream/TopKWordCount2_Special.scala create mode 100644 streaming/src/main/scala/spark/stream/WindowedRDS.scala create mode 100644 streaming/src/main/scala/spark/stream/WordCount.scala create mode 100644 streaming/src/main/scala/spark/stream/WordCount1.scala create mode 100644 streaming/src/main/scala/spark/stream/WordCount2.scala create mode 100644 streaming/src/main/scala/spark/stream/WordCount2_Special.scala create mode 100644 streaming/src/main/scala/spark/stream/WordCount3.scala create mode 100644 streaming/src/main/scala/spark/stream/WordCountEc2.scala create mode 100644 streaming/src/main/scala/spark/stream/WordCountTrivialWindow.scala create mode 100644 streaming/src/main/scala/spark/stream/WordMax.scala diff --git a/core/src/main/scala/spark/BlockRDD.scala b/core/src/main/scala/spark/BlockRDD.scala new file mode 100644 index 0000000000..ea009f0f4f --- /dev/null +++ b/core/src/main/scala/spark/BlockRDD.scala @@ -0,0 +1,42 @@ +package spark + +import scala.collection.mutable.HashMap + +class BlockRDDSplit(val blockId: String, idx: Int) extends Split { + val index = idx +} + + +class BlockRDD[T: ClassManifest](sc: SparkContext, blockIds: Array[String]) extends RDD[T](sc) { + + @transient + val splits_ = (0 until blockIds.size).map(i => { + new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split] + }).toArray + + @transient + lazy val locations_ = { + val blockManager = SparkEnv.get.blockManager + /*val locations = blockIds.map(id => blockManager.getLocations(id))*/ + val locations = blockManager.getLocations(blockIds) + HashMap(blockIds.zip(locations):_*) + } + + override def splits = splits_ + + override def compute(split: Split): Iterator[T] = { + val blockManager = SparkEnv.get.blockManager + val blockId = split.asInstanceOf[BlockRDDSplit].blockId + blockManager.get(blockId) match { + case Some(block) => block.asInstanceOf[Iterator[T]] + case None => + throw new Exception("Could not compute split, block " + blockId + " not found") + } + } + + override def preferredLocations(split: Split) = + locations_(split.asInstanceOf[BlockRDDSplit].blockId) + + override val dependencies: List[Dependency[_]] = Nil +} + diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index dd17d4d6b3..78c7618542 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -409,6 +409,11 @@ class SparkContext( * various Spark features. */ object SparkContext { + + // TODO: temporary hack for using HDFS as input in streaing + var inputFile: String = null + var idealPartitions: Int = 1 + implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 def zero(initialValue: Double) = 0.0 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 726d490738..c4ada2bf2a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -8,7 +8,7 @@ object SparkBuild extends Build { // "1.0.1" for Apache releases, or "0.20.2-cdh3u3" for Cloudera Hadoop. val HADOOP_VERSION = "0.20.205.0" - lazy val root = Project("root", file("."), settings = sharedSettings) aggregate(core, repl, examples, bagel) + lazy val root = Project("root", file("."), settings = sharedSettings) aggregate(core, repl, examples, bagel, streaming) lazy val core = Project("core", file("core"), settings = coreSettings) @@ -18,6 +18,8 @@ object SparkBuild extends Build { lazy val bagel = Project("bagel", file("bagel"), settings = bagelSettings) dependsOn (core) + lazy val streaming = Project("streaming", file("streaming"), settings = streamingSettings) dependsOn (core) + def sharedSettings = Defaults.defaultSettings ++ Seq( organization := "org.spark-project", version := "0.6.0-SNAPSHOT", @@ -82,6 +84,8 @@ object SparkBuild extends Build { def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel") + def streamingSettings = sharedSettings ++ Seq(name := "spark-streaming") + def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq( mergeStrategy in assembly := { case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard diff --git a/run b/run index 8f7256b4e5..e3e98f4280 100755 --- a/run +++ b/run @@ -46,6 +46,7 @@ CORE_DIR="$FWDIR/core" REPL_DIR="$FWDIR/repl" EXAMPLES_DIR="$FWDIR/examples" BAGEL_DIR="$FWDIR/bagel" +STREAMING_DIR="$FWDIR/streaming" # Build up classpath CLASSPATH="$SPARK_CLASSPATH" @@ -55,6 +56,7 @@ CLASSPATH+=":$CORE_DIR/target/scala-$SCALA_VERSION/classes" CLASSPATH+=":$CORE_DIR/src/main/resources" CLASSPATH+=":$REPL_DIR/target/scala-$SCALA_VERSION/classes" CLASSPATH+=":$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes" +CLASSPATH+=":$STREAMING_DIR/target/scala-$SCALA_VERSION/classes" for jar in `find $CORE_DIR/lib -name '*jar'`; do CLASSPATH+=":$jar" done diff --git a/sentences.txt b/sentences.txt new file mode 100644 index 0000000000..fedf96c66e --- /dev/null +++ b/sentences.txt @@ -0,0 +1,3 @@ +Hello world! +What's up? +There is no cow level diff --git a/startTrigger.sh b/startTrigger.sh new file mode 100755 index 0000000000..0afce91a3e --- /dev/null +++ b/startTrigger.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +./run spark.stream.SentenceGenerator localhost 7078 sentences.txt 1 diff --git a/streaming/src/main/scala/spark/stream/BlockID.scala b/streaming/src/main/scala/spark/stream/BlockID.scala new file mode 100644 index 0000000000..a3fd046c9a --- /dev/null +++ b/streaming/src/main/scala/spark/stream/BlockID.scala @@ -0,0 +1,20 @@ +package spark.stream + +case class BlockID(sRds: String, sInterval: Interval, sPartition: Int) { + override def toString : String = ( + sRds + BlockID.sConnector + + sInterval.beginTime + BlockID.sConnector + + sInterval.endTime + BlockID.sConnector + + sPartition + ) +} + +object BlockID { + val sConnector = '-' + + def parse(name : String) = BlockID( + name.split(BlockID.sConnector)(0), + new Interval(name.split(BlockID.sConnector)(1).toLong, + name.split(BlockID.sConnector)(2).toLong), + name.split(BlockID.sConnector)(3).toInt) +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/stream/ConnectionHandler.scala b/streaming/src/main/scala/spark/stream/ConnectionHandler.scala new file mode 100644 index 0000000000..73b82b76b8 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/ConnectionHandler.scala @@ -0,0 +1,157 @@ +package spark.stream + +import spark.Logging + +import scala.collection.mutable.{ArrayBuffer, SynchronizedQueue} + +import java.net._ +import java.io._ +import java.nio._ +import java.nio.charset._ +import java.nio.channels._ +import java.nio.channels.spi._ + +abstract class ConnectionHandler(host: String, port: Int, connect: Boolean) +extends Thread with Logging { + + val selector = SelectorProvider.provider.openSelector() + val interestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] + + initLogging() + + override def run() { + try { + if (connect) { + connect() + } else { + listen() + } + + var interrupted = false + while(!interrupted) { + + preSelect() + + while(!interestChangeRequests.isEmpty) { + val (key, ops) = interestChangeRequests.dequeue + val lastOps = key.interestOps() + key.interestOps(ops) + + def intToOpStr(op: Int): String = { + val opStrs = new ArrayBuffer[String]() + if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" + if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" + if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" + if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" + if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " + } + + logTrace("Changed ops from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") + } + + selector.select() + interrupted = Thread.currentThread.isInterrupted + + val selectedKeys = selector.selectedKeys().iterator() + while (selectedKeys.hasNext) { + val key = selectedKeys.next.asInstanceOf[SelectionKey] + selectedKeys.remove() + if (key.isValid) { + if (key.isAcceptable) { + accept(key) + } else if (key.isConnectable) { + finishConnect(key) + } else if (key.isReadable) { + read(key) + } else if (key.isWritable) { + write(key) + } + } + } + } + } catch { + case e: Exception => { + logError("Error in select loop", e) + } + } + } + + def connect() { + val socketAddress = new InetSocketAddress(host, port) + val channel = SocketChannel.open() + channel.configureBlocking(false) + channel.socket.setReuseAddress(true) + channel.socket.setTcpNoDelay(true) + channel.connect(socketAddress) + channel.register(selector, SelectionKey.OP_CONNECT) + logInfo("Initiating connection to [" + socketAddress + "]") + } + + def listen() { + val channel = ServerSocketChannel.open() + channel.configureBlocking(false) + channel.socket.setReuseAddress(true) + channel.socket.setReceiveBufferSize(256 * 1024) + channel.socket.bind(new InetSocketAddress(port)) + channel.register(selector, SelectionKey.OP_ACCEPT) + logInfo("Listening on port " + port) + } + + def finishConnect(key: SelectionKey) { + try { + val channel = key.channel.asInstanceOf[SocketChannel] + val address = channel.socket.getRemoteSocketAddress + channel.finishConnect() + logInfo("Connected to [" + host + ":" + port + "]") + ready(key) + } catch { + case e: IOException => { + logError("Error finishing connect to " + host + ":" + port) + close(key) + } + } + } + + def accept(key: SelectionKey) { + try { + val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] + val channel = serverChannel.accept() + val address = channel.socket.getRemoteSocketAddress + channel.configureBlocking(false) + logInfo("Accepted connection from [" + address + "]") + ready(channel.register(selector, 0)) + } catch { + case e: IOException => { + logError("Error accepting connection", e) + } + } + } + + def changeInterest(key: SelectionKey, ops: Int) { + logTrace("Added request to change ops to " + ops) + interestChangeRequests += ((key, ops)) + } + + def ready(key: SelectionKey) + + def preSelect() { + } + + def read(key: SelectionKey) { + throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString) + } + + def write(key: SelectionKey) { + throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString) + } + + def close(key: SelectionKey) { + try { + key.channel.close() + key.cancel() + Thread.currentThread.interrupt + } catch { + case e: Exception => logError("Error closing connection", e) + } + } +} diff --git a/streaming/src/main/scala/spark/stream/DumbTopKWordCount2_Special.scala b/streaming/src/main/scala/spark/stream/DumbTopKWordCount2_Special.scala new file mode 100644 index 0000000000..bd43f44b1a --- /dev/null +++ b/streaming/src/main/scala/spark/stream/DumbTopKWordCount2_Special.scala @@ -0,0 +1,138 @@ +package spark.stream + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting +import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.mutable.Queue + +import java.lang.{Long => JLong} + +object DumbTopKWordCount2_Special { + + def moreWarmup(sc: SparkContext) { + (0 until 20).foreach {i => + sc.parallelize(1 to 20000000, 500) + .map(_ % 100).map(_.toString) + .map(x => (x, 1)).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SparkStreamContext <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + GrepCount2.warmConnectionManagers(ssc.sc) + moreWarmup(ssc.sc) + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray + ) + + + def add(v1: JLong, v2: JLong) = (v1 + v2) + def subtract(v1: JLong, v2: JLong) = (v1 - v2) + + def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { + val map = new java.util.HashMap[String, JLong] + var i = 0 + var j = 0 + while (iter.hasNext) { + val s = iter.next() + i = 0 + while (i < s.length) { + j = i + while (j < s.length && s.charAt(j) != ' ') { + j += 1 + } + if (j > i) { + val w = s.substring(i, j) + val c = map.get(w) + if (c == null) { + map.put(w, 1) + } else { + map.put(w, c + 1) + } + } + i = j + while (i < s.length && s.charAt(i) == ' ') { + i += 1 + } + } + } + map.toIterator + } + + + val wordCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) + wordCounts.persist(StorageLevel.MEMORY_ONLY) + val windowedCounts = wordCounts.window(Seconds(10), Seconds(1)).reduceByKey(_ + _, 10) + + def topK(data: Iterator[(String, JLong)], k: Int): Iterator[(String, JLong)] = { + val taken = new Array[(String, JLong)](k) + + var i = 0 + var len = 0 + var done = false + var value: (String, JLong) = null + var swap: (String, JLong) = null + var count = 0 + + while(data.hasNext) { + value = data.next + count += 1 + /*println("count = " + count)*/ + if (len == 0) { + taken(0) = value + len = 1 + } else if (len < k || value._2 > taken(len - 1)._2) { + if (len < k) { + len += 1 + } + taken(len - 1) = value + i = len - 1 + while(i > 0 && taken(i - 1)._2 < taken(i)._2) { + swap = taken(i) + taken(i) = taken(i-1) + taken(i - 1) = swap + i -= 1 + } + } + } + println("Took " + len + " out of " + count + " items") + return taken.toIterator + } + + val k = 10 + val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) + partialTopKWindowedCounts.foreachRDD(rdd => { + val collectedCounts = rdd.collect + println("Collected " + collectedCounts.size + " items") + topK(collectedCounts.toIterator, k).foreach(println) + }) + + /* + windowedCounts.filter(_ == null).foreachRDD(rdd => { + val count = rdd.count + println("# of nulls = " + count) + })*/ + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/stream/DumbWordCount2_Special.scala b/streaming/src/main/scala/spark/stream/DumbWordCount2_Special.scala new file mode 100644 index 0000000000..31d682348a --- /dev/null +++ b/streaming/src/main/scala/spark/stream/DumbWordCount2_Special.scala @@ -0,0 +1,92 @@ +package spark.stream + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.Queue +import scala.collection.JavaConversions.mapAsScalaMap + +import java.lang.{Long => JLong} +import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} + +object DumbWordCount2_Special { + + def moreWarmup(sc: SparkContext) { + (0 until 20).foreach {i => + sc.parallelize(1 to 20000000, 500) + .map(_ % 100).map(_.toString) + .map(x => (x, 1)).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SparkStreamContext <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + GrepCount2.warmConnectionManagers(ssc.sc) + moreWarmup(ssc.sc) + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray + ) + + def add(v1: JLong, v2: JLong) = (v1 + v2) + def subtract(v1: JLong, v2: JLong) = (v1 - v2) + + def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { + val map = new java.util.HashMap[String, JLong] + var i = 0 + var j = 0 + while (iter.hasNext) { + val s = iter.next() + i = 0 + while (i < s.length) { + j = i + while (j < s.length && s.charAt(j) != ' ') { + j += 1 + } + if (j > i) { + val w = s.substring(i, j) + val c = map.get(w) + if (c == null) { + map.put(w, 1) + } else { + map.put(w, c + 1) + } + } + i = j + while (i < s.length && s.charAt(i) == ' ') { + i += 1 + } + } + } + + map.toIterator + } + + val wordCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) + wordCounts.persist(StorageLevel.MEMORY_ONLY) + val windowedCounts = wordCounts.window(Seconds(10), Seconds(1)).reduceByKey(_ + _, 10) + windowedCounts.foreachRDD(_.collect) + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/stream/FileStreamReceiver.scala b/streaming/src/main/scala/spark/stream/FileStreamReceiver.scala new file mode 100644 index 0000000000..026254d6e1 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/FileStreamReceiver.scala @@ -0,0 +1,70 @@ +package spark.stream + +import spark.Logging + +import scala.collection.mutable.HashSet +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ +import scala.actors.remote.RemoteActor._ + +import org.apache.hadoop.fs._ +import org.apache.hadoop.conf._ +import org.apache.hadoop.io._ +import org.apache.hadoop.mapred._ +import org.apache.hadoop.util._ + +class FileStreamReceiver ( + inputName: String, + rootDirectory: String, + intervalDuration: Long) + extends Logging { + + val pollInterval = 100 + val sparkstreamScheduler = { + val host = System.getProperty("spark.master.host") + val port = System.getProperty("spark.master.port").toInt + 1 + RemoteActor.select(Node(host, port), 'SparkStreamScheduler) + } + val directory = new Path(rootDirectory) + val fs = directory.getFileSystem(new Configuration()) + val files = new HashSet[String]() + var time: Long = 0 + + def start() { + fs.mkdirs(directory) + files ++= getFiles() + + actor { + logInfo("Monitoring directory - " + rootDirectory) + while(true) { + testFiles(getFiles()) + Thread.sleep(pollInterval) + } + } + } + + def getFiles(): Iterable[String] = { + fs.listStatus(directory).map(_.getPath.toString) + } + + def testFiles(fileList: Iterable[String]) { + fileList.foreach(file => { + if (!files.contains(file)) { + if (!file.endsWith("_tmp")) { + notifyFile(file) + } + files += file + } + }) + } + + def notifyFile(file: String) { + logInfo("Notifying file " + file) + time += intervalDuration + val interval = Interval(LongTime(time), LongTime(time + intervalDuration)) + sparkstreamScheduler ! InputGenerated(inputName, interval, file) + } +} + + diff --git a/streaming/src/main/scala/spark/stream/GrepCount.scala b/streaming/src/main/scala/spark/stream/GrepCount.scala new file mode 100644 index 0000000000..45b90d4837 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/GrepCount.scala @@ -0,0 +1,39 @@ +package spark.stream + +import SparkStreamContext._ + +import scala.util.Sorting + +import spark.SparkContext +import spark.storage.StorageLevel + +object GrepCount { + var inputFile : String = null + var HDFS : String = null + var idealPartitions : Int = 0 + + def main (args: Array[String]) { + + if (args.length != 4) { + println ("Usage: GrepCount ") + System.exit(1) + } + + HDFS = args(1) + inputFile = HDFS + args(2) + idealPartitions = args(3).toInt + println ("Input file: " + inputFile) + + val ssc = new SparkStreamContext(args(0), "GrepCount") + + SparkContext.idealPartitions = idealPartitions + SparkContext.inputFile = inputFile + + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) + //sentences.print + val matching = sentences.filter(_.contains("light")) + matching.foreachRDD(rdd => println(rdd.count)) + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/stream/GrepCount2.scala b/streaming/src/main/scala/spark/stream/GrepCount2.scala new file mode 100644 index 0000000000..4eb65ba906 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/GrepCount2.scala @@ -0,0 +1,113 @@ +package spark.stream + +import SparkStreamContext._ + +import scala.util.Sorting + +import spark.SparkEnv +import spark.SparkContext +import spark.storage.StorageLevel +import spark.network.Message +import spark.network.ConnectionManagerId + +import java.nio.ByteBuffer + +object GrepCount2 { + + def startSparkEnvs(sc: SparkContext) { + + val dummy = sc.parallelize(0 to 1000, 100).persist(StorageLevel.DISK_AND_MEMORY) + sc.runJob(dummy, (_: Iterator[Int]) => {}) + + println("SparkEnvs started") + Thread.sleep(1000) + /*sc.runJob(sc.parallelize(0 to 1000, 100), (_: Iterator[Int]) => {})*/ + } + + def warmConnectionManagers(sc: SparkContext) { + val slaveConnManagerIds = sc.parallelize(0 to 100, 100).map( + i => SparkEnv.get.connectionManager.id).collect().distinct + println("\nSlave ConnectionManagerIds") + slaveConnManagerIds.foreach(println) + println + + Thread.sleep(1000) + val numSlaves = slaveConnManagerIds.size + val count = 3 + val size = 5 * 1024 * 1024 + val iterations = (500 * 1024 * 1024 / (numSlaves * size)).toInt + println("count = " + count + ", size = " + size + ", iterations = " + iterations) + + (0 until count).foreach(i => { + val resultStrs = sc.parallelize(0 until numSlaves, numSlaves).map(i => { + val connManager = SparkEnv.get.connectionManager + val thisConnManagerId = connManager.id + /*connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + println("Received [" + msg + "] from [" + id + "]") + None + })*/ + + + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + + val startTime = System.currentTimeMillis + val futures = (0 until iterations).map(i => { + slaveConnManagerIds.filter(_ != thisConnManagerId).map(slaveConnManagerId => { + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + println("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") + connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) + }) + }).flatMap(x => x) + val results = futures.map(f => f()) + val finishTime = System.currentTimeMillis + + + val mb = size * results.size / 1024.0 / 1024.0 + val ms = finishTime - startTime + + val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s" + println(resultStr) + System.gc() + resultStr + }).collect() + + println("---------------------") + println("Run " + i) + resultStrs.foreach(println) + println("---------------------") + }) + } + + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: GrepCount2 <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "GrepCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + /*startSparkEnvs(ssc.sc)*/ + warmConnectionManagers(ssc.sc) + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-"+i, 500)).toArray + ) + + val matching = sentences.filter(_.contains("light")) + matching.foreachRDD(rdd => println(rdd.count)) + + ssc.run + } +} + + + + diff --git a/streaming/src/main/scala/spark/stream/GrepCountApprox.scala b/streaming/src/main/scala/spark/stream/GrepCountApprox.scala new file mode 100644 index 0000000000..a4be2cc936 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/GrepCountApprox.scala @@ -0,0 +1,54 @@ +package spark.stream + +import SparkStreamContext._ + +import scala.util.Sorting + +import spark.SparkContext +import spark.storage.StorageLevel + +object GrepCountApprox { + var inputFile : String = null + var hdfs : String = null + var idealPartitions : Int = 0 + + def main (args: Array[String]) { + + if (args.length != 5) { + println ("Usage: GrepCountApprox ") + System.exit(1) + } + + hdfs = args(1) + inputFile = hdfs + args(2) + idealPartitions = args(3).toInt + val timeout = args(4).toLong + println ("Input file: " + inputFile) + + val ssc = new SparkStreamContext(args(0), "GrepCount") + + SparkContext.idealPartitions = idealPartitions + SparkContext.inputFile = inputFile + ssc.setTempDir(hdfs + "/tmp") + + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) + //sentences.print + val matching = sentences.filter(_.contains("light")) + var i = 0 + val startTime = System.currentTimeMillis + matching.foreachRDD { rdd => + val myNum = i + val result = rdd.countApprox(timeout) + val initialTime = (System.currentTimeMillis - startTime) / 1000.0 + printf("APPROX\t%.2f\t%d\tinitial\t%.1f\t%.1f\n", initialTime, myNum, result.initialValue.mean, + result.initialValue.high - result.initialValue.low) + result.onComplete { r => + val finalTime = (System.currentTimeMillis - startTime) / 1000.0 + printf("APPROX\t%.2f\t%d\tfinal\t%.1f\t0.0\t%.1f\n", finalTime, myNum, r.mean, finalTime - initialTime) + } + i += 1 + } + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/stream/IdealPerformance.scala b/streaming/src/main/scala/spark/stream/IdealPerformance.scala new file mode 100644 index 0000000000..589fb2def0 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/IdealPerformance.scala @@ -0,0 +1,36 @@ +package spark.stream + +import scala.collection.mutable.Map + +object IdealPerformance { + val base: String = "The medium researcher counts around the pinched troop The empire breaks " + + "Matei Matei announces HY with a theorem " + + def main (args: Array[String]) { + val sentences: String = base * 100000 + + for (i <- 1 to 30) { + val start = System.nanoTime + + val words = sentences.split(" ") + + val pairs = words.map(word => (word, 1)) + + val counts = Map[String, Int]() + + println("Job " + i + " position A at " + (System.nanoTime - start) / 1e9) + + pairs.foreach((pair) => { + var t = counts.getOrElse(pair._1, 0) + counts(pair._1) = t + pair._2 + }) + println("Job " + i + " position B at " + (System.nanoTime - start) / 1e9) + + for ((word, count) <- counts) { + print(word + " " + count + "; ") + } + println + println("Job " + i + " finished in " + (System.nanoTime - start) / 1e9) + } + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/stream/Interval.scala b/streaming/src/main/scala/spark/stream/Interval.scala new file mode 100644 index 0000000000..08d0ed95b4 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/Interval.scala @@ -0,0 +1,75 @@ +package spark.stream + +case class Interval (val beginTime: Time, val endTime: Time) { + + def this(beginMs: Long, endMs: Long) = this(new LongTime(beginMs), new LongTime(endMs)) + + def duration(): Time = endTime - beginTime + + def += (time: Time) { + beginTime += time + endTime += time + this + } + + def + (time: Time): Interval = { + new Interval(beginTime + time, endTime + time) + } + + def < (that: Interval): Boolean = { + if (this.duration != that.duration) { + throw new Exception("Comparing two intervals with different durations [" + this + ", " + that + "]") + } + this.endTime < that.endTime + } + + def <= (that: Interval) = (this < that || this == that) + + def > (that: Interval) = !(this <= that) + + def >= (that: Interval) = !(this < that) + + def next(): Interval = { + this + (endTime - beginTime) + } + + def isZero() = (beginTime.isZero && endTime.isZero) + + def toFormattedString = beginTime.toFormattedString + "-" + endTime.toFormattedString + + override def toString = "[" + beginTime + ", " + endTime + "]" +} + +object Interval { + + /* + implicit def longTupleToInterval (longTuple: (Long, Long)) = + Interval(longTuple._1, longTuple._2) + + implicit def intTupleToInterval (intTuple: (Int, Int)) = + Interval(intTuple._1, intTuple._2) + + implicit def string2Interval (str: String): Interval = { + val parts = str.split(",") + if (parts.length == 1) + return Interval.zero + return Interval (parts(0).toInt, parts(1).toInt) + } + + def getInterval (timeMs: Long, intervalDurationMs: Long): Interval = { + val intervalBeginMs = timeMs / intervalDurationMs * intervalDurationMs + Interval(intervalBeginMs, intervalBeginMs + intervalDurationMs) + } + */ + + def zero() = new Interval (Time.zero, Time.zero) + + def currentInterval(intervalDuration: LongTime): Interval = { + val time = LongTime(System.currentTimeMillis) + val intervalBegin = time.floor(intervalDuration) + Interval(intervalBegin, intervalBegin + intervalDuration) + } + +} + + diff --git a/streaming/src/main/scala/spark/stream/Job.scala b/streaming/src/main/scala/spark/stream/Job.scala new file mode 100644 index 0000000000..bfdd5db645 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/Job.scala @@ -0,0 +1,21 @@ +package spark.stream + +class Job(val time: Time, func: () => _) { + val id = Job.getNewId() + + def run() { + func() + } + + override def toString = "SparkStream Job " + id + ":" + time +} + +object Job { + var lastId = 1 + + def getNewId() = synchronized { + lastId += 1 + lastId + } +} + diff --git a/streaming/src/main/scala/spark/stream/JobManager.scala b/streaming/src/main/scala/spark/stream/JobManager.scala new file mode 100644 index 0000000000..5ea80b92aa --- /dev/null +++ b/streaming/src/main/scala/spark/stream/JobManager.scala @@ -0,0 +1,112 @@ +package spark.stream + +import spark.SparkEnv +import spark.Logging + +import scala.collection.mutable.PriorityQueue +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ +import scala.actors.remote.RemoteActor._ +import scala.actors.scheduler.ResizableThreadPoolScheduler +import scala.actors.scheduler.ForkJoinScheduler + +sealed trait JobManagerMessage +case class RunJob(job: Job) extends JobManagerMessage +case class JobCompleted(handlerId: Int) extends JobManagerMessage + +class JobHandler(ssc: SparkStreamContext, val id: Int) extends DaemonActor with Logging { + + var busy = false + + def act() { + loop { + receive { + case job: Job => { + SparkEnv.set(ssc.env) + try { + logInfo("Starting " + job) + job.run() + logInfo("Finished " + job) + if (job.time.isInstanceOf[LongTime]) { + val longTime = job.time.asInstanceOf[LongTime] + logInfo("Total pushing + skew + processing delay for " + longTime + " is " + + (System.currentTimeMillis - longTime.milliseconds) / 1000.0 + " s") + } + } catch { + case e: Exception => logError("SparkStream job failed", e) + } + busy = false + reply(JobCompleted(id)) + } + } + } + } +} + +class JobManager(ssc: SparkStreamContext, numThreads: Int = 2) extends DaemonActor with Logging { + + implicit private val jobOrdering = new Ordering[Job] { + override def compare(job1: Job, job2: Job): Int = { + if (job1.time < job2.time) { + return 1 + } else if (job2.time < job1.time) { + return -1 + } else { + return 0 + } + } + } + + private val jobs = new PriorityQueue[Job]() + private val handlers = (0 until numThreads).map(i => new JobHandler(ssc, i)) + + def act() { + handlers.foreach(_.start) + loop { + receive { + case RunJob(job) => { + jobs += job + logInfo("Job " + job + " submitted") + runJob() + } + case JobCompleted(handlerId) => { + runJob() + } + } + } + } + + def runJob(): Unit = { + logInfo("Attempting to allocate job ") + if (jobs.size > 0) { + handlers.find(!_.busy).foreach(handler => { + val job = jobs.dequeue + logInfo("Allocating job " + job + " to handler " + handler.id) + handler.busy = true + handler ! job + }) + } + } +} + +object JobManager { + def main(args: Array[String]) { + val ssc = new SparkStreamContext("local[4]", "JobManagerTest") + val jobManager = new JobManager(ssc) + jobManager.start() + + val t = System.currentTimeMillis + for (i <- 1 to 10) { + jobManager ! RunJob(new Job( + LongTime(i), + () => { + Thread.sleep(500) + println("Job " + i + " took " + (System.currentTimeMillis - t) + " ms") + } + )) + } + Thread.sleep(6000) + } +} + diff --git a/streaming/src/main/scala/spark/stream/JobManager2.scala b/streaming/src/main/scala/spark/stream/JobManager2.scala new file mode 100644 index 0000000000..b69653b9a4 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/JobManager2.scala @@ -0,0 +1,37 @@ +package spark.stream + +import spark.{Logging, SparkEnv} +import java.util.concurrent.Executors + + +class JobManager2(ssc: SparkStreamContext, numThreads: Int = 1) extends Logging { + + class JobHandler(ssc: SparkStreamContext, job: Job) extends Runnable { + def run() { + SparkEnv.set(ssc.env) + try { + logInfo("Starting " + job) + job.run() + logInfo("Finished " + job) + if (job.time.isInstanceOf[LongTime]) { + val longTime = job.time.asInstanceOf[LongTime] + logInfo("Total notification + skew + processing delay for " + longTime + " is " + + (System.currentTimeMillis - longTime.milliseconds) / 1000.0 + " s") + if (System.getProperty("spark.stream.distributed", "false") == "true") { + TestInputBlockTracker.setEndTime(job.time) + } + } + } catch { + case e: Exception => logError("SparkStream job failed", e) + } + } + } + + initLogging() + + val jobExecutor = Executors.newFixedThreadPool(numThreads) + + def runJob(job: Job) { + jobExecutor.execute(new JobHandler(ssc, job)) + } +} diff --git a/streaming/src/main/scala/spark/stream/NetworkStreamReceiver.scala b/streaming/src/main/scala/spark/stream/NetworkStreamReceiver.scala new file mode 100644 index 0000000000..8be46cc927 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/NetworkStreamReceiver.scala @@ -0,0 +1,184 @@ +package spark.stream + +import spark.Logging +import spark.storage.StorageLevel + +import scala.math._ +import scala.collection.mutable.{Queue, HashMap, ArrayBuffer} +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ +import scala.actors.remote.RemoteActor._ + +import java.io.BufferedWriter +import java.io.OutputStreamWriter + +import org.apache.hadoop.fs._ +import org.apache.hadoop.conf._ +import org.apache.hadoop.io._ +import org.apache.hadoop.mapred._ +import org.apache.hadoop.util._ + +/*import akka.actor.Actor._*/ + +class NetworkStreamReceiver[T: ClassManifest] ( + inputName: String, + intervalDuration: Time, + splitId: Int, + ssc: SparkStreamContext, + tempDirectory: String) + extends DaemonActor + with Logging { + + /** + * Assume all data coming in has non-decreasing timestamp. + */ + final class Inbox[T: ClassManifest] (intervalDuration: Time) { + var currentBucket: (Interval, ArrayBuffer[T]) = null + val filledBuckets = new Queue[(Interval, ArrayBuffer[T])]() + + def += (tuple: (Time, T)) = addTuple(tuple) + + def addTuple(tuple: (Time, T)) { + val (time, data) = tuple + val interval = getInterval (time) + + filledBuckets.synchronized { + if (currentBucket == null) { + currentBucket = (interval, new ArrayBuffer[T]()) + } + + if (interval != currentBucket._1) { + filledBuckets += currentBucket + currentBucket = (interval, new ArrayBuffer[T]()) + } + + currentBucket._2 += data + } + } + + def getInterval(time: Time): Interval = { + val intervalBegin = time.floor(intervalDuration) + Interval (intervalBegin, intervalBegin + intervalDuration) + } + + def hasFilledBuckets(): Boolean = { + filledBuckets.synchronized { + return filledBuckets.size > 0 + } + } + + def popFilledBucket(): (Interval, ArrayBuffer[T]) = { + filledBuckets.synchronized { + if (filledBuckets.size == 0) { + return null + } + return filledBuckets.dequeue() + } + } + } + + val inbox = new Inbox[T](intervalDuration) + lazy val sparkstreamScheduler = { + val host = System.getProperty("spark.master.host") + val port = System.getProperty("spark.master.port").toInt + val url = "akka://spark@%s:%s/user/SparkStreamScheduler".format(host, port) + ssc.actorSystem.actorFor(url) + } + /*sparkstreamScheduler ! Test()*/ + + val intervalDurationMillis = intervalDuration.asInstanceOf[LongTime].milliseconds + val useBlockManager = true + + initLogging() + + override def act() { + // register the InputReceiver + val port = 7078 + RemoteActor.alive(port) + RemoteActor.register(Symbol("NetworkStreamReceiver-"+inputName), self) + logInfo("Registered actor on port " + port) + + loop { + reactWithin (getSleepTime) { + case TIMEOUT => + flushInbox() + case data => + val t = data.asInstanceOf[T] + inbox += (getTimeFromData(t), t) + } + } + } + + def getSleepTime(): Long = { + (System.currentTimeMillis / intervalDurationMillis + 1) * + intervalDurationMillis - System.currentTimeMillis + } + + def getTimeFromData(data: T): Time = { + LongTime(System.currentTimeMillis) + } + + def flushInbox() { + while (inbox.hasFilledBuckets) { + inbox.synchronized { + val (interval, data) = inbox.popFilledBucket() + val dataArray = data.toArray + logInfo("Received " + dataArray.length + " items at interval " + interval) + val reference = { + if (useBlockManager) { + writeToBlockManager(dataArray, interval) + } else { + writeToDisk(dataArray, interval) + } + } + if (reference != null) { + logInfo("Notifying scheduler") + sparkstreamScheduler ! InputGenerated(inputName, interval, reference.toString) + } + } + } + } + + def writeToDisk(data: Array[T], interval: Interval): String = { + try { + // TODO(Haoyuan): For current test, the following writing to file lines could be + // commented. + val fs = new Path(tempDirectory).getFileSystem(new Configuration()) + val inputDir = new Path( + tempDirectory, + inputName + "-" + interval.toFormattedString) + val inputFile = new Path(inputDir, "part-" + splitId) + logInfo("Writing to file " + inputFile) + if (System.getProperty("spark.fake", "false") != "true") { + val writer = new BufferedWriter(new OutputStreamWriter(fs.create(inputFile, true))) + data.foreach(x => writer.write(x.toString + "\n")) + writer.close() + } else { + logInfo("Fake file") + } + inputFile.toString + }catch { + case e: Exception => + logError("Exception writing to file at interval " + interval + ": " + e.getMessage, e) + null + } + } + + def writeToBlockManager(data: Array[T], interval: Interval): String = { + try{ + val blockId = inputName + "-" + interval.toFormattedString + "-" + splitId + if (System.getProperty("spark.fake", "false") != "true") { + logInfo("Writing as block " + blockId ) + ssc.env.blockManager.put(blockId.toString, data.toIterator, StorageLevel.DISK_AND_MEMORY) + } else { + logInfo("Fake block") + } + blockId + } catch { + case e: Exception => + logError("Exception writing to block manager at interval " + interval + ": " + e.getMessage, e) + null + } + } +} diff --git a/streaming/src/main/scala/spark/stream/RDS.scala b/streaming/src/main/scala/spark/stream/RDS.scala new file mode 100644 index 0000000000..b83181b0d1 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/RDS.scala @@ -0,0 +1,607 @@ +package spark.stream + +import spark.stream.SparkStreamContext._ + +import spark.RDD +import spark.BlockRDD +import spark.UnionRDD +import spark.Logging +import spark.SparkContext +import spark.SparkContext._ +import spark.storage.StorageLevel + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap + +import java.net.InetSocketAddress + +abstract class RDS[T: ClassManifest] (@transient val ssc: SparkStreamContext) +extends Logging with Serializable { + + initLogging() + + /* ---------------------------------------------- */ + /* Methods that must be implemented by subclasses */ + /* ---------------------------------------------- */ + + // Time by which the window slides in this RDS + def slideTime: Time + + // List of parent RDSs on which this RDS depends on + def dependencies: List[RDS[_]] + + // Key method that computes RDD for a valid time + def compute (validTime: Time): Option[RDD[T]] + + /* --------------------------------------- */ + /* Other general fields and methods of RDS */ + /* --------------------------------------- */ + + // Variable to store the RDDs generated earlier in time + @transient private val generatedRDDs = new HashMap[Time, RDD[T]] () + + // Variable to be set to the first time seen by the RDS (effective time zero) + private[stream] var zeroTime: Time = null + + // Variable to specify storage level + private var storageLevel: StorageLevel = StorageLevel.NONE + + // Checkpoint level and checkpoint interval + private var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint + private var checkpointInterval: Time = null + + // Change this RDD's storage level + def persist( + storageLevel: StorageLevel, + checkpointLevel: StorageLevel, + checkpointInterval: Time): RDS[T] = { + if (this.storageLevel != StorageLevel.NONE && this.storageLevel != storageLevel) { + // TODO: not sure this is necessary for RDSes + throw new UnsupportedOperationException( + "Cannot change storage level of an RDS after it was already assigned a level") + } + this.storageLevel = storageLevel + this.checkpointLevel = checkpointLevel + this.checkpointInterval = checkpointInterval + this + } + + def persist(newLevel: StorageLevel): RDS[T] = persist(newLevel, StorageLevel.NONE, null) + + // Turn on the default caching level for this RDD + def persist(): RDS[T] = persist(StorageLevel.MEMORY_ONLY_DESER) + + // Turn on the default caching level for this RDD + def cache(): RDS[T] = persist() + + def isInitialized = (zeroTime != null) + + // This method initializes the RDS by setting the "zero" time, based on which + // the validity of future times is calculated. This method also recursively initializes + // its parent RDSs. + def initialize(firstInterval: Interval) { + if (zeroTime == null) { + zeroTime = firstInterval.beginTime + } + logInfo(this + " initialized") + dependencies.foreach(_.initialize(firstInterval)) + } + + // This method checks whether the 'time' is valid wrt slideTime for generating RDD + private def isTimeValid (time: Time): Boolean = { + if (!isInitialized) + throw new Exception (this.toString + " has not been initialized") + if ((time - zeroTime).isMultipleOf(slideTime)) { + true + } else { + false + } + } + + // This method either retrieves a precomputed RDD of this RDS, + // or computes the RDD (if the time is valid) + def getOrCompute(time: Time): Option[RDD[T]] = { + + // if RDD was already generated, then retrieve it from HashMap + generatedRDDs.get(time) match { + + // If an RDD was already generated and is being reused, then + // probably all RDDs in this RDS will be reused and hence should be cached + case Some(oldRDD) => Some(oldRDD) + + // if RDD was not generated, and if the time is valid + // (based on sliding time of this RDS), then generate the RDD + case None => + if (isTimeValid(time)) { + compute(time) match { + case Some(newRDD) => + if (System.getProperty("spark.fake", "false") != "true" || + newRDD.getStorageLevel == StorageLevel.NONE) { + if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { + newRDD.persist(checkpointLevel) + logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) + } else if (storageLevel != StorageLevel.NONE) { + newRDD.persist(storageLevel) + logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) + } + } + generatedRDDs.put(time.copy(), newRDD) + Some(newRDD) + case None => + None + } + } else { + None + } + } + } + + // This method generates a SparkStream job for the given time + // and may require to be overriden by subclasses + def generateJob(time: Time): Option[Job] = { + getOrCompute(time) match { + case Some(rdd) => { + val jobFunc = () => { + val emptyFunc = { (iterator: Iterator[T]) => {} } + ssc.sc.runJob(rdd, emptyFunc) + } + Some(new Job(time, jobFunc)) + } + case None => None + } + } + + /* -------------- */ + /* RDS operations */ + /* -------------- */ + + def map[U: ClassManifest](mapFunc: T => U) = new MappedRDS(this, ssc.sc.clean(mapFunc)) + + def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]) = + new FlatMappedRDS(this, ssc.sc.clean(flatMapFunc)) + + def filter(filterFunc: T => Boolean) = new FilteredRDS(this, filterFunc) + + def glom() = new GlommedRDS(this) + + def mapPartitions[U: ClassManifest](mapPartFunc: Iterator[T] => Iterator[U]) = + new MapPartitionedRDS(this, ssc.sc.clean(mapPartFunc)) + + def reduce(reduceFunc: (T, T) => T) = this.map(x => (1, x)).reduceByKey(reduceFunc, 1).map(_._2) + + def count() = this.map(_ => 1).reduce(_ + _) + + def collect() = this.map(x => (1, x)).groupByKey(1).map(_._2) + + def foreach(foreachFunc: T => Unit) = { + val newrds = new PerElementForEachRDS(this, ssc.sc.clean(foreachFunc)) + ssc.registerOutputStream(newrds) + newrds + } + + def foreachRDD(foreachFunc: RDD[T] => Unit) = { + val newrds = new PerRDDForEachRDS(this, ssc.sc.clean(foreachFunc)) + ssc.registerOutputStream(newrds) + newrds + } + + def print() = { + def foreachFunc = (rdd: RDD[T], time: Time) => { + val first11 = rdd.take(11) + println ("-------------------------------------------") + println ("Time: " + time) + println ("-------------------------------------------") + first11.take(10).foreach(println) + if (first11.size > 10) println("...") + println() + } + val newrds = new PerRDDForEachRDS(this, ssc.sc.clean(foreachFunc)) + ssc.registerOutputStream(newrds) + newrds + } + + def window(windowTime: Time, slideTime: Time) = new WindowedRDS(this, windowTime, slideTime) + + def batch(batchTime: Time) = window(batchTime, batchTime) + + def reduceByWindow(reduceFunc: (T, T) => T, windowTime: Time, slideTime: Time) = + this.window(windowTime, slideTime).reduce(reduceFunc) + + def reduceByWindow( + reduceFunc: (T, T) => T, + invReduceFunc: (T, T) => T, + windowTime: Time, + slideTime: Time) = { + this.map(x => (1, x)) + .reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowTime, slideTime, 1) + .map(_._2) + } + + def countByWindow(windowTime: Time, slideTime: Time) = { + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + this.map(_ => 1).reduceByWindow(add _, subtract _, windowTime, slideTime) + } + + def union(that: RDS[T]) = new UnifiedRDS(Array(this, that)) + + def register() = ssc.registerOutputStream(this) +} + + +class PairRDSFunctions[K: ClassManifest, V: ClassManifest](rds: RDS[(K,V)]) +extends Serializable { + + def ssc = rds.ssc + + /* ---------------------------------- */ + /* RDS operations for key-value pairs */ + /* ---------------------------------- */ + + def groupByKey(numPartitions: Int = 0): ShuffledRDS[K, V, ArrayBuffer[V]] = { + def createCombiner(v: V) = ArrayBuffer[V](v) + def mergeValue(c: ArrayBuffer[V], v: V) = (c += v) + def mergeCombiner(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = (c1 ++ c2) + combineByKey[ArrayBuffer[V]](createCombiner, mergeValue, mergeCombiner, numPartitions) + } + + def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int = 0): ShuffledRDS[K, V, V] = { + val cleanedReduceFunc = ssc.sc.clean(reduceFunc) + combineByKey[V]((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, numPartitions) + } + + private def combineByKey[C: ClassManifest]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiner: (C, C) => C, + numPartitions: Int) : ShuffledRDS[K, V, C] = { + new ShuffledRDS[K, V, C](rds, createCombiner, mergeValue, mergeCombiner, numPartitions) + } + + def groupByKeyAndWindow( + windowTime: Time, + slideTime: Time, + numPartitions: Int = 0): ShuffledRDS[K, V, ArrayBuffer[V]] = { + rds.window(windowTime, slideTime).groupByKey(numPartitions) + } + + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time, + numPartitions: Int = 0): ShuffledRDS[K, V, V] = { + rds.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), numPartitions) + } + + // This method is the efficient sliding window reduce operation, + // which requires the specification of an inverse reduce function, + // so that new elements introduced in the window can be "added" using + // reduceFunc to the previous window's result and old elements can be + // "subtracted using invReduceFunc. + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time, + numPartitions: Int): ReducedWindowedRDS[K, V] = { + + new ReducedWindowedRDS[K, V]( + rds, + ssc.sc.clean(reduceFunc), + ssc.sc.clean(invReduceFunc), + windowTime, + slideTime, + numPartitions) + } +} + + +abstract class InputRDS[T: ClassManifest] ( + val inputName: String, + val batchDuration: Time, + ssc: SparkStreamContext) +extends RDS[T](ssc) { + + override def dependencies = List() + + override def slideTime = batchDuration + + def setReference(time: Time, reference: AnyRef) +} + + +class FileInputRDS( + val fileInputName: String, + val directory: String, + ssc: SparkStreamContext) +extends InputRDS[String](fileInputName, LongTime(1000), ssc) { + + @transient val generatedFiles = new HashMap[Time,String] + + // TODO(Haoyuan): This is for the performance test. + @transient + val rdd = ssc.sc.textFile(SparkContext.inputFile, + SparkContext.idealPartitions).asInstanceOf[RDD[String]] + + override def compute(validTime: Time): Option[RDD[String]] = { + generatedFiles.get(validTime) match { + case Some(file) => + logInfo("Reading from file " + file + " for time " + validTime) + // Some(ssc.sc.textFile(file).asInstanceOf[RDD[String]]) + // The following line is for HDFS performance test. Sould comment out the above line. + Some(rdd) + case None => + throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") + None + } + } + + def setReference(time: Time, reference: AnyRef) { + generatedFiles += ((time, reference.toString)) + logInfo("Reference added for time " + time + " - " + reference.toString) + } +} + +class NetworkInputRDS[T: ClassManifest]( + val networkInputName: String, + val addresses: Array[InetSocketAddress], + batchDuration: Time, + ssc: SparkStreamContext) +extends InputRDS[T](networkInputName, batchDuration, ssc) { + + + // TODO(Haoyuan): This is for the performance test. + @transient var rdd: RDD[T] = null + + if (System.getProperty("spark.fake", "false") == "true") { + logInfo("Running initial count to cache fake RDD") + rdd = ssc.sc.textFile(SparkContext.inputFile, + SparkContext.idealPartitions).asInstanceOf[RDD[T]] + val fakeCacheLevel = System.getProperty("spark.fake.cache", "") + if (fakeCacheLevel == "MEMORY_ONLY_2") { + rdd.persist(StorageLevel.MEMORY_ONLY_2) + } else if (fakeCacheLevel == "MEMORY_ONLY_DESER_2") { + rdd.persist(StorageLevel.MEMORY_ONLY_2) + } else if (fakeCacheLevel != "") { + logError("Invalid fake cache level: " + fakeCacheLevel) + System.exit(1) + } + rdd.count() + } + + @transient val references = new HashMap[Time,String] + + override def compute(validTime: Time): Option[RDD[T]] = { + if (System.getProperty("spark.fake", "false") == "true") { + logInfo("Returning fake RDD at " + validTime) + return Some(rdd) + } + references.get(validTime) match { + case Some(reference) => + if (reference.startsWith("file") || reference.startsWith("hdfs")) { + logInfo("Reading from file " + reference + " for time " + validTime) + Some(ssc.sc.textFile(reference).asInstanceOf[RDD[T]]) + } else { + logInfo("Getting from BlockManager " + reference + " for time " + validTime) + Some(new BlockRDD(ssc.sc, Array(reference))) + } + case None => + throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") + None + } + } + + def setReference(time: Time, reference: AnyRef) { + references += ((time, reference.toString)) + logInfo("Reference added for time " + time + " - " + reference.toString) + } +} + + +class TestInputRDS( + val testInputName: String, + batchDuration: Time, + ssc: SparkStreamContext) +extends InputRDS[String](testInputName, batchDuration, ssc) { + + @transient val references = new HashMap[Time,Array[String]] + + override def compute(validTime: Time): Option[RDD[String]] = { + references.get(validTime) match { + case Some(reference) => + Some(new BlockRDD[String](ssc.sc, reference)) + case None => + throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") + None + } + } + + def setReference(time: Time, reference: AnyRef) { + references += ((time, reference.asInstanceOf[Array[String]])) + } +} + + +class MappedRDS[T: ClassManifest, U: ClassManifest] ( + parent: RDS[T], + mapFunc: T => U) +extends RDS[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.map[U](mapFunc)) + } +} + + +class FlatMappedRDS[T: ClassManifest, U: ClassManifest]( + parent: RDS[T], + flatMapFunc: T => Traversable[U]) +extends RDS[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.flatMap(flatMapFunc)) + } +} + + +class FilteredRDS[T: ClassManifest](parent: RDS[T], filterFunc: T => Boolean) +extends RDS[T](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[T]] = { + parent.getOrCompute(validTime).map(_.filter(filterFunc)) + } +} + +class MapPartitionedRDS[T: ClassManifest, U: ClassManifest]( + parent: RDS[T], + mapPartFunc: Iterator[T] => Iterator[U]) +extends RDS[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc)) + } +} + +class GlommedRDS[T: ClassManifest](parent: RDS[T]) extends RDS[Array[T]](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[Array[T]]] = { + parent.getOrCompute(validTime).map(_.glom()) + } +} + + +class ShuffledRDS[K: ClassManifest, V: ClassManifest, C: ClassManifest]( + parent: RDS[(K,V)], + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiner: (C, C) => C, + numPartitions: Int) + extends RDS [(K,C)] (parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[(K,C)]] = { + parent.getOrCompute(validTime) match { + case Some(rdd) => + val newrdd = { + if (numPartitions > 0) { + rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner, numPartitions) + } else { + rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner) + } + } + Some(newrdd) + case None => None + } + } +} + + +class UnifiedRDS[T: ClassManifest](parents: Array[RDS[T]]) +extends RDS[T](parents(0).ssc) { + + if (parents.length == 0) { + throw new IllegalArgumentException("Empty array of parents") + } + + if (parents.map(_.ssc).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different SparkStreamContexts") + } + + if (parents.map(_.slideTime).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different slide times") + } + + override def dependencies = parents.toList + + override def slideTime: Time = parents(0).slideTime + + override def compute(validTime: Time): Option[RDD[T]] = { + val rdds = new ArrayBuffer[RDD[T]]() + parents.map(_.getOrCompute(validTime)).foreach(_ match { + case Some(rdd) => rdds += rdd + case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime) + }) + if (rdds.size > 0) { + Some(new UnionRDD(ssc.sc, rdds)) + } else { + None + } + } +} + + +class PerElementForEachRDS[T: ClassManifest] ( + parent: RDS[T], + foreachFunc: T => Unit) +extends RDS[Unit](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[Unit]] = None + + override def generateJob(time: Time): Option[Job] = { + parent.getOrCompute(time) match { + case Some(rdd) => + val jobFunc = () => { + val sparkJobFunc = { + (iterator: Iterator[T]) => iterator.foreach(foreachFunc) + } + ssc.sc.runJob(rdd, sparkJobFunc) + } + Some(new Job(time, jobFunc)) + case None => None + } + } +} + + +class PerRDDForEachRDS[T: ClassManifest] ( + parent: RDS[T], + foreachFunc: (RDD[T], Time) => Unit) +extends RDS[Unit](parent.ssc) { + + def this(parent: RDS[T], altForeachFunc: (RDD[T]) => Unit) = + this(parent, (rdd: RDD[T], time: Time) => altForeachFunc(rdd)) + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[Unit]] = None + + override def generateJob(time: Time): Option[Job] = { + parent.getOrCompute(time) match { + case Some(rdd) => + val jobFunc = () => { + foreachFunc(rdd, time) + } + Some(new Job(time, jobFunc)) + case None => None + } + } +} diff --git a/streaming/src/main/scala/spark/stream/ReducedWindowedRDS.scala b/streaming/src/main/scala/spark/stream/ReducedWindowedRDS.scala new file mode 100644 index 0000000000..d47654ccb9 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/ReducedWindowedRDS.scala @@ -0,0 +1,218 @@ +package spark.stream + +import spark.stream.SparkStreamContext._ + +import spark.RDD +import spark.UnionRDD +import spark.CoGroupedRDD +import spark.HashPartitioner +import spark.SparkContext._ +import spark.storage.StorageLevel + +import scala.collection.mutable.ArrayBuffer + +class ReducedWindowedRDS[K: ClassManifest, V: ClassManifest]( + parent: RDS[(K, V)], + reduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + _windowTime: Time, + _slideTime: Time, + numPartitions: Int) +extends RDS[(K,V)](parent.ssc) { + + if (!_windowTime.isMultipleOf(parent.slideTime)) + throw new Exception("The window duration of ReducedWindowedRDS (" + _slideTime + ") " + + "must be multiple of the slide duration of parent RDS (" + parent.slideTime + ")") + + if (!_slideTime.isMultipleOf(parent.slideTime)) + throw new Exception("The slide duration of ReducedWindowedRDS (" + _slideTime + ") " + + "must be multiple of the slide duration of parent RDS (" + parent.slideTime + ")") + + val reducedRDS = parent.reduceByKey(reduceFunc, numPartitions) + val allowPartialWindows = true + //reducedRDS.persist(StorageLevel.MEMORY_ONLY_DESER_2) + + override def dependencies = List(reducedRDS) + + def windowTime: Time = _windowTime + + override def slideTime: Time = _slideTime + + override def persist( + storageLevel: StorageLevel, + checkpointLevel: StorageLevel, + checkpointInterval: Time): RDS[(K,V)] = { + super.persist(storageLevel, checkpointLevel, checkpointInterval) + reducedRDS.persist(storageLevel, checkpointLevel, checkpointInterval) + } + + override def compute(validTime: Time): Option[RDD[(K, V)]] = { + + + // Notation: + // _____________________________ + // | previous window _________|___________________ + // |___________________| current window | --------------> Time + // |_____________________________| + // + // |________ _________| |________ _________| + // | | + // V V + // old time steps new time steps + // + def getAdjustedWindow(endTime: Time, windowTime: Time): Interval = { + val beginTime = + if (allowPartialWindows && endTime - windowTime < parent.zeroTime) { + parent.zeroTime + } else { + endTime - windowTime + } + Interval(beginTime, endTime) + } + + val currentTime = validTime.copy + val currentWindow = getAdjustedWindow(currentTime, windowTime) + val previousWindow = getAdjustedWindow(currentTime - slideTime, windowTime) + + logInfo("Current window = " + currentWindow) + logInfo("Previous window = " + previousWindow) + logInfo("Parent.zeroTime = " + parent.zeroTime) + + if (allowPartialWindows) { + if (currentTime - slideTime == parent.zeroTime) { + reducedRDS.getOrCompute(currentTime) match { + case Some(rdd) => return Some(rdd) + case None => throw new Exception("Could not get first reduced RDD for time " + currentTime) + } + } + } else { + if (previousWindow.beginTime < parent.zeroTime) { + if (currentWindow.beginTime < parent.zeroTime) { + return None + } else { + // If this is the first feasible window, then generate reduced value in the naive manner + val reducedRDDs = new ArrayBuffer[RDD[(K, V)]]() + var t = currentWindow.endTime + while (t > currentWindow.beginTime) { + reducedRDS.getOrCompute(t) match { + case Some(rdd) => reducedRDDs += rdd + case None => throw new Exception("Could not get reduced RDD for time " + t) + } + t -= reducedRDS.slideTime + } + if (reducedRDDs.size == 0) { + throw new Exception("Could not generate the first RDD for time " + validTime) + } + return Some(new UnionRDD(ssc.sc, reducedRDDs).reduceByKey(reduceFunc, numPartitions)) + } + } + } + + // Get the RDD of the reduced value of the previous window + val previousWindowRDD = getOrCompute(previousWindow.endTime) match { + case Some(rdd) => rdd.asInstanceOf[RDD[(_, _)]] + case None => throw new Exception("Could not get previous RDD for time " + previousWindow.endTime) + } + + val oldRDDs = new ArrayBuffer[RDD[(_, _)]]() + val newRDDs = new ArrayBuffer[RDD[(_, _)]]() + + // Get the RDDs of the reduced values in "old time steps" + var t = currentWindow.beginTime + while (t > previousWindow.beginTime) { + reducedRDS.getOrCompute(t) match { + case Some(rdd) => oldRDDs += rdd.asInstanceOf[RDD[(_, _)]] + case None => throw new Exception("Could not get old reduced RDD for time " + t) + } + t -= reducedRDS.slideTime + } + + // Get the RDDs of the reduced values in "new time steps" + t = currentWindow.endTime + while (t > previousWindow.endTime) { + reducedRDS.getOrCompute(t) match { + case Some(rdd) => newRDDs += rdd.asInstanceOf[RDD[(_, _)]] + case None => throw new Exception("Could not get new reduced RDD for time " + t) + } + t -= reducedRDS.slideTime + } + + val partitioner = new HashPartitioner(numPartitions) + val allRDDs = new ArrayBuffer[RDD[(_, _)]]() + allRDDs += previousWindowRDD + allRDDs ++= oldRDDs + allRDDs ++= newRDDs + + + val numOldRDDs = oldRDDs.size + val numNewRDDs = newRDDs.size + logInfo("Generated numOldRDDs = " + numOldRDDs + ", numNewRDDs = " + numNewRDDs) + logInfo("Generating CoGroupedRDD with " + allRDDs.size + " RDDs") + val newRDD = new CoGroupedRDD[K](allRDDs.toSeq, partitioner).asInstanceOf[RDD[(K,Seq[Seq[V]])]].map(x => { + val (key, value) = x + logDebug("value.size = " + value.size + ", numOldRDDs = " + numOldRDDs + ", numNewRDDs = " + numNewRDDs) + if (value.size != 1 + numOldRDDs + numNewRDDs) { + throw new Exception("Number of groups not odd!") + } + + // old values = reduced values of the "old time steps" that are eliminated from current window + // new values = reduced values of the "new time steps" that are introduced to the current window + // previous value = reduced value of the previous window + + /*val numOldValues = (value.size - 1) / 2*/ + // Getting reduced values "old time steps" + val oldValues = + (0 until numOldRDDs).map(i => value(1 + i)).filter(_.size > 0).map(x => x(0)) + // Getting reduced values "new time steps" + val newValues = + (0 until numNewRDDs).map(i => value(1 + numOldRDDs + i)).filter(_.size > 0).map(x => x(0)) + + // If reduced value for the key does not exist in previous window, it should not exist in "old time steps" + if (value(0).size == 0 && oldValues.size != 0) { + throw new Exception("Unexpected: Key exists in old reduced values but not in previous reduced values") + } + + // For the key, at least one of "old time steps", "new time steps" and previous window should have reduced values + if (value(0).size == 0 && oldValues.size == 0 && newValues.size == 0) { + throw new Exception("Unexpected: Key does not exist in any of old, new, or previour reduced values") + } + + // Logic to generate the final reduced value for current window: + // + // If previous window did not have reduced value for the key + // Then, return reduced value of "new time steps" as the final value + // Else, reduced value exists in previous window + // If "old" time steps did not have reduced value for the key + // Then, reduce previous window's reduced value with that of "new time steps" for final value + // Else, reduced values exists in "old time steps" + // If "new values" did not have reduced value for the key + // Then, inverse-reduce "old values" from previous window's reduced value for final value + // Else, all 3 values exist, combine all of them together + // + logDebug("# old values = " + oldValues.size + ", # new values = " + newValues) + val finalValue = { + if (value(0).size == 0) { + newValues.reduce(reduceFunc) + } else { + val prevValue = value(0)(0) + logDebug("prev value = " + prevValue) + if (oldValues.size == 0) { + // assuming newValue.size > 0 (all 3 cannot be zero, as checked earlier) + val temp = newValues.reduce(reduceFunc) + reduceFunc(prevValue, temp) + } else if (newValues.size == 0) { + invReduceFunc(prevValue, oldValues.reduce(reduceFunc)) + } else { + val tempValue = invReduceFunc(prevValue, oldValues.reduce(reduceFunc)) + reduceFunc(tempValue, newValues.reduce(reduceFunc)) + } + } + } + (key, finalValue) + }) + //newRDD.persist(StorageLevel.MEMORY_ONLY_DESER_2) + Some(newRDD) + } +} + + diff --git a/streaming/src/main/scala/spark/stream/Scheduler.scala b/streaming/src/main/scala/spark/stream/Scheduler.scala new file mode 100644 index 0000000000..38946fef11 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/Scheduler.scala @@ -0,0 +1,181 @@ +package spark.stream + +import spark.SparkEnv +import spark.Logging + +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet +import scala.collection.mutable.ArrayBuffer + +import akka.actor._ +import akka.actor.Actor +import akka.actor.Actor._ +import akka.util.duration._ + +sealed trait SchedulerMessage +case class InputGenerated(inputName: String, interval: Interval, reference: AnyRef = null) extends SchedulerMessage +case class Test extends SchedulerMessage + +class Scheduler( + ssc: SparkStreamContext, + inputRDSs: Array[InputRDS[_]], + outputRDSs: Array[RDS[_]]) +extends Actor with Logging { + + class InputState (inputNames: Array[String]) { + val inputsLeft = new HashSet[String]() + inputsLeft ++= inputNames + + val startTime = System.currentTimeMillis + + def delay() = System.currentTimeMillis - startTime + + def addGeneratedInput(inputName: String) = inputsLeft -= inputName + + def areAllInputsGenerated() = (inputsLeft.size == 0) + + override def toString(): String = { + val left = if (inputsLeft.size == 0) "" else inputsLeft.reduceLeft(_ + ", " + _) + return "Inputs left = [ " + left + " ]" + } + } + + + initLogging() + + val inputNames = inputRDSs.map(_.inputName).toArray + val inputStates = new HashMap[Interval, InputState]() + val currentJobs = System.getProperty("spark.stream.currentJobs", "1").toInt + val jobManager = new JobManager2(ssc, currentJobs) + + // TODO(Haoyuan): The following line is for performance test only. + var cnt: Int = System.getProperty("spark.stream.fake.cnt", "60").toInt + var lastInterval: Interval = null + + + /*remote.register("SparkStreamScheduler", actorOf[Scheduler])*/ + logInfo("Registered actor on port ") + + /*jobManager.start()*/ + startStreamReceivers() + + def receive = { + case InputGenerated(inputName, interval, reference) => { + addGeneratedInput(inputName, interval, reference) + } + case Test() => logInfo("TEST PASSED") + } + + def addGeneratedInput(inputName: String, interval: Interval, reference: AnyRef = null) { + logInfo("Input " + inputName + " generated for interval " + interval) + inputStates.get(interval) match { + case None => inputStates.put(interval, new InputState(inputNames)) + case _ => + } + inputStates(interval).addGeneratedInput(inputName) + + inputRDSs.filter(_.inputName == inputName).foreach(inputRDS => { + inputRDS.setReference(interval.endTime, reference) + if (inputRDS.isInstanceOf[TestInputRDS]) { + TestInputBlockTracker.addBlocks(interval.endTime, reference) + } + } + ) + + def getNextInterval(): Option[Interval] = { + logDebug("Last interval is " + lastInterval) + val readyIntervals = inputStates.filter(_._2.areAllInputsGenerated).keys + /*inputState.foreach(println) */ + logDebug("InputState has " + inputStates.size + " intervals, " + readyIntervals.size + " ready intervals") + return readyIntervals.find(lastInterval == null || _.beginTime == lastInterval.endTime) + } + + var nextInterval = getNextInterval() + var count = 0 + while(nextInterval.isDefined) { + val inputState = inputStates.get(nextInterval.get).get + generateRDDsForInterval(nextInterval.get) + logInfo("Skew delay for " + nextInterval.get.endTime + " is " + (inputState.delay / 1000.0) + " s") + inputStates.remove(nextInterval.get) + lastInterval = nextInterval.get + nextInterval = getNextInterval() + count += 1 + /*if (nextInterval.size == 0 && inputState.size > 0) { + logDebug("Next interval not ready, pending intervals " + inputState.size) + }*/ + } + logDebug("RDDs generated for " + count + " intervals") + + /* + if (inputState(interval).areAllInputsGenerated) { + generateRDDsForInterval(interval) + lastInterval = interval + inputState.remove(interval) + } else { + logInfo("All inputs not generated for interval " + interval) + } + */ + } + + def generateRDDsForInterval (interval: Interval) { + logInfo("Generating RDDs for interval " + interval) + outputRDSs.foreach(outputRDS => { + if (!outputRDS.isInitialized) outputRDS.initialize(interval) + outputRDS.generateJob(interval.endTime) match { + case Some(job) => submitJob(job) + case None => + } + } + ) + // TODO(Haoyuan): This comment is for performance test only. + if (System.getProperty("spark.fake", "false") == "true" || System.getProperty("spark.stream.fake", "false") == "true") { + cnt -= 1 + if (cnt <= 0) { + logInfo("My time is up! " + cnt) + System.exit(1) + } + } + } + + def submitJob(job: Job) { + logInfo("Submitting " + job + " to JobManager") + /*jobManager ! RunJob(job)*/ + jobManager.runJob(job) + } + + def startStreamReceivers() { + val testStreamReceiverNames = new ArrayBuffer[(String, Long)]() + inputRDSs.foreach (inputRDS => { + inputRDS match { + case fileInputRDS: FileInputRDS => { + val fileStreamReceiver = new FileStreamReceiver( + fileInputRDS.inputName, + fileInputRDS.directory, + fileInputRDS.batchDuration.asInstanceOf[LongTime].milliseconds) + fileStreamReceiver.start() + } + case networkInputRDS: NetworkInputRDS[_] => { + val networkStreamReceiver = new NetworkStreamReceiver( + networkInputRDS.inputName, + networkInputRDS.batchDuration, + 0, + ssc, + if (ssc.tempDir == null) null else ssc.tempDir.toString) + networkStreamReceiver.start() + } + case testInputRDS: TestInputRDS => { + testStreamReceiverNames += + ((testInputRDS.inputName, testInputRDS.batchDuration.asInstanceOf[LongTime].milliseconds)) + } + } + }) + if (testStreamReceiverNames.size > 0) { + /*val testStreamCoordinator = new TestStreamCoordinator(testStreamReceiverNames.toArray)*/ + /*testStreamCoordinator.start()*/ + val actor = ssc.actorSystem.actorOf( + Props(new TestStreamCoordinator(testStreamReceiverNames.toArray)), + name = "TestStreamCoordinator") + } + } +} + diff --git a/streaming/src/main/scala/spark/stream/SenGeneratorForPerformanceTest.scala b/streaming/src/main/scala/spark/stream/SenGeneratorForPerformanceTest.scala new file mode 100644 index 0000000000..74fd54072f --- /dev/null +++ b/streaming/src/main/scala/spark/stream/SenGeneratorForPerformanceTest.scala @@ -0,0 +1,78 @@ +package spark.stream + +import scala.util.Random +import scala.io.Source +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ +import scala.actors.remote.RemoteActor._ + +import java.net.InetSocketAddress + +/*import akka.actor.Actor._*/ +/*import akka.actor.ActorRef*/ + + +object SenGeneratorForPerformanceTest { + + def printUsage () { + println ("Usage: SentenceGenerator []") + System.exit(0) + } + + def main (args: Array[String]) { + if (args.length < 3) { + printUsage + } + + val inputManagerIP = args(0) + val inputManagerPort = args(1).toInt + val sentenceFile = args(2) + val sentencesPerSecond = { + if (args.length > 3) args(3).toInt + else 10 + } + + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n") + source.close () + + try { + /*val inputManager = remote.actorFor("InputReceiver-Sentences",*/ + /* inputManagerIP, inputManagerPort)*/ + val inputManager = select(Node(inputManagerIP, inputManagerPort), Symbol("InputReceiver-Sentences")) + val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 + val random = new Random () + println ("Sending " + sentencesPerSecond + " sentences per second to " + inputManagerIP + ":" + inputManagerPort) + var lastPrintTime = System.currentTimeMillis() + var count = 0 + + while (true) { + /*if (!inputManager.tryTell (lines (random.nextInt (lines.length))))*/ + /*throw new Exception ("disconnected")*/ +// inputManager ! lines (random.nextInt (lines.length)) + for (i <- 0 to sentencesPerSecond) inputManager ! lines (0) + println(System.currentTimeMillis / 1000 + " s") +/* count += 1 + + if (System.currentTimeMillis - lastPrintTime >= 1000) { + println (count + " sentences sent last second") + count = 0 + lastPrintTime = System.currentTimeMillis + } + + Thread.sleep (sleepBetweenSentences.toLong) +*/ + val currentMs = System.currentTimeMillis / 1000; + Thread.sleep ((currentMs * 1000 + 1000) - System.currentTimeMillis) + } + } catch { + case e: Exception => + /*Thread.sleep (1000)*/ + } + } +} + + + + diff --git a/streaming/src/main/scala/spark/stream/SenderReceiverTest.scala b/streaming/src/main/scala/spark/stream/SenderReceiverTest.scala new file mode 100644 index 0000000000..69879b621c --- /dev/null +++ b/streaming/src/main/scala/spark/stream/SenderReceiverTest.scala @@ -0,0 +1,63 @@ +package spark.stream +import java.net.{Socket, ServerSocket} +import java.io.{ByteArrayOutputStream, DataOutputStream, DataInputStream, BufferedInputStream} + +object Receiver { + def main(args: Array[String]) { + val port = args(0).toInt + val lsocket = new ServerSocket(port) + println("Listening on port " + port ) + while(true) { + val socket = lsocket.accept() + (new Thread() { + override def run() { + val buffer = new Array[Byte](100000) + var count = 0 + val time = System.currentTimeMillis + try { + val is = new DataInputStream(new BufferedInputStream(socket.getInputStream)) + var loop = true + var string: String = null + while((string = is.readUTF) != null) { + count += 28 + } + } catch { + case e: Exception => e.printStackTrace + } + val timeTaken = System.currentTimeMillis - time + val tput = (count / 1024.0) / (timeTaken / 1000.0) + println("Data = " + count + " bytes\nTime = " + timeTaken + " ms\nTput = " + tput + " KB/s") + } + }).start() + } + } + +} + +object Sender { + + def main(args: Array[String]) { + try { + val host = args(0) + val port = args(1).toInt + val size = args(2).toInt + + val byteStream = new ByteArrayOutputStream() + val stringDataStream = new DataOutputStream(byteStream) + (0 until size).foreach(_ => stringDataStream.writeUTF("abcdedfghijklmnopqrstuvwxy")) + val bytes = byteStream.toByteArray() + println("Generated array of " + bytes.length + " bytes") + + /*val bytes = new Array[Byte](size)*/ + val socket = new Socket(host, port) + val os = socket.getOutputStream + os.write(bytes) + os.flush + socket.close() + + } catch { + case e: Exception => e.printStackTrace + } + } +} + diff --git a/streaming/src/main/scala/spark/stream/SentenceFileGenerator.scala b/streaming/src/main/scala/spark/stream/SentenceFileGenerator.scala new file mode 100644 index 0000000000..9aa441d9bb --- /dev/null +++ b/streaming/src/main/scala/spark/stream/SentenceFileGenerator.scala @@ -0,0 +1,92 @@ +package spark.stream + +import spark._ + +import scala.collection.mutable.ArrayBuffer +import scala.util.Random +import scala.io.Source + +import java.net.InetSocketAddress + +import org.apache.hadoop.fs._ +import org.apache.hadoop.conf._ +import org.apache.hadoop.io._ +import org.apache.hadoop.mapred._ +import org.apache.hadoop.util._ + +object SentenceFileGenerator { + + def printUsage () { + println ("Usage: SentenceFileGenerator <# partitions> []") + System.exit(0) + } + + def main (args: Array[String]) { + if (args.length < 4) { + printUsage + } + + val master = args(0) + val fs = new Path(args(1)).getFileSystem(new Configuration()) + val targetDirectory = new Path(args(1)).makeQualified(fs) + val numPartitions = args(2).toInt + val sentenceFile = args(3) + val sentencesPerSecond = { + if (args.length > 4) args(4).toInt + else 10 + } + + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n").toArray + source.close () + println("Read " + lines.length + " lines from file " + sentenceFile) + + val sentences = { + val buffer = ArrayBuffer[String]() + val random = new Random() + var i = 0 + while (i < sentencesPerSecond) { + buffer += lines(random.nextInt(lines.length)) + i += 1 + } + buffer.toArray + } + println("Generated " + sentences.length + " sentences") + + val sc = new SparkContext(master, "SentenceFileGenerator") + val sentencesRDD = sc.parallelize(sentences, numPartitions) + + val tempDirectory = new Path(targetDirectory, "_tmp") + + fs.mkdirs(targetDirectory) + fs.mkdirs(tempDirectory) + + var saveTimeMillis = System.currentTimeMillis + try { + while (true) { + val newDir = new Path(targetDirectory, "Sentences-" + saveTimeMillis) + val tmpNewDir = new Path(tempDirectory, "Sentences-" + saveTimeMillis) + println("Writing to file " + newDir) + sentencesRDD.saveAsTextFile(tmpNewDir.toString) + fs.rename(tmpNewDir, newDir) + saveTimeMillis += 1000 + val sleepTimeMillis = { + val currentTimeMillis = System.currentTimeMillis + if (saveTimeMillis < currentTimeMillis) { + 0 + } else { + saveTimeMillis - currentTimeMillis + } + } + println("Sleeping for " + sleepTimeMillis + " ms") + Thread.sleep(sleepTimeMillis) + } + } catch { + case e: Exception => + } + } +} + + + + diff --git a/streaming/src/main/scala/spark/stream/SentenceGenerator.scala b/streaming/src/main/scala/spark/stream/SentenceGenerator.scala new file mode 100644 index 0000000000..ef66e66047 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/SentenceGenerator.scala @@ -0,0 +1,103 @@ +package spark.stream + +import scala.util.Random +import scala.io.Source +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ +import scala.actors.remote.RemoteActor._ + +import java.net.InetSocketAddress + + +object SentenceGenerator { + + def printUsage { + println ("Usage: SentenceGenerator []") + System.exit(0) + } + + def generateRandomSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { + val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 + val random = new Random () + + try { + var lastPrintTime = System.currentTimeMillis() + var count = 0 + while(true) { + streamReceiver ! lines(random.nextInt(lines.length)) + count += 1 + if (System.currentTimeMillis - lastPrintTime >= 1000) { + println (count + " sentences sent last second") + count = 0 + lastPrintTime = System.currentTimeMillis + } + Thread.sleep(sleepBetweenSentences.toLong) + } + } catch { + case e: Exception => + } + } + + def generateSameSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { + try { + val numSentences = if (sentencesPerSecond <= 0) { + lines.length + } else { + sentencesPerSecond + } + var nextSendingTime = System.currentTimeMillis() + val pingInterval = if (System.getenv("INTERVAL") != null) { + System.getenv("INTERVAL").toInt + } else { + 2000 + } + while(true) { + (0 until numSentences).foreach(i => { + streamReceiver ! lines(i % lines.length) + }) + println ("Sent " + numSentences + " sentences") + nextSendingTime += pingInterval + val sleepTime = nextSendingTime - System.currentTimeMillis + if (sleepTime > 0) { + println ("Sleeping for " + sleepTime + " ms") + Thread.sleep(sleepTime) + } + } + } catch { + case e: Exception => + } + } + + def main(args: Array[String]) { + if (args.length < 3) { + printUsage + } + + val generateRandomly = false + + val streamReceiverIP = args(0) + val streamReceiverPort = args(1).toInt + val sentenceFile = args(2) + val sentencesPerSecond = if (args.length > 3) args(3).toInt else 10 + val sentenceInputName = if (args.length > 4) args(4) else "Sentences" + + println("Sending " + sentencesPerSecond + " sentences per second to " + + streamReceiverIP + ":" + streamReceiverPort + "/NetworkStreamReceiver-" + sentenceInputName) + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n") + source.close () + + val streamReceiver = select( + Node(streamReceiverIP, streamReceiverPort), + Symbol("NetworkStreamReceiver-" + sentenceInputName)) + if (generateRandomly) { + generateRandomSentences(lines, sentencesPerSecond, streamReceiver) + } else { + generateSameSentences(lines, sentencesPerSecond, streamReceiver) + } + } +} + + + diff --git a/streaming/src/main/scala/spark/stream/ShuffleTest.scala b/streaming/src/main/scala/spark/stream/ShuffleTest.scala new file mode 100644 index 0000000000..5ad56f6777 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/ShuffleTest.scala @@ -0,0 +1,22 @@ +package spark.stream +import spark.SparkContext +import SparkContext._ + +object ShuffleTest { + def main(args: Array[String]) { + + if (args.length < 1) { + println ("Usage: ShuffleTest ") + System.exit(1) + } + + val sc = new spark.SparkContext(args(0), "ShuffleTest") + val rdd = sc.parallelize(1 to 1000, 500).cache + + def time(f: => Unit) { val start = System.nanoTime; f; println((System.nanoTime - start) * 1.0e-6) } + + time { for (i <- 0 until 50) time { rdd.map(x => (x % 100, x)).reduceByKey(_ + _, 10).count } } + System.exit(0) + } +} + diff --git a/streaming/src/main/scala/spark/stream/SimpleWordCount.scala b/streaming/src/main/scala/spark/stream/SimpleWordCount.scala new file mode 100644 index 0000000000..c53fe35f44 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/SimpleWordCount.scala @@ -0,0 +1,30 @@ +package spark.stream + +import SparkStreamContext._ + +import scala.util.Sorting + +object SimpleWordCount { + + def main (args: Array[String]) { + + if (args.length < 1) { + println ("Usage: SparkStreamContext []") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCount") + if (args.length > 1) { + ssc.setTempDir(args(1)) + } + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) + /*sentences.print*/ + + val words = sentences.flatMap(_.split(" ")) + + val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 1) + counts.print + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/stream/SimpleWordCount2.scala b/streaming/src/main/scala/spark/stream/SimpleWordCount2.scala new file mode 100644 index 0000000000..1a2c67cd4d --- /dev/null +++ b/streaming/src/main/scala/spark/stream/SimpleWordCount2.scala @@ -0,0 +1,51 @@ +package spark.stream + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import scala.util.Sorting + +object SimpleWordCount2 { + + def moreWarmup(sc: SparkContext) { + (0 until 20).foreach {i => + sc.parallelize(1 to 20000000, 500) + .map(_ % 100).map(_.toString) + .map(x => (x, 1)).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SimpleWordCount2 <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "SimpleWordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + GrepCount2.warmConnectionManagers(ssc.sc) + moreWarmup(ssc.sc) + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray + ) + + + val words = sentences.flatMap(_.split(" ")) + + val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 10) + counts.foreachRDD(_.collect()) + /*words.foreachRDD(_.countByValue())*/ + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/stream/SimpleWordCount2_Special.scala b/streaming/src/main/scala/spark/stream/SimpleWordCount2_Special.scala new file mode 100644 index 0000000000..9003a5dbb3 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/SimpleWordCount2_Special.scala @@ -0,0 +1,83 @@ +package spark.stream + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import scala.collection.JavaConversions.mapAsScalaMap +import scala.util.Sorting +import java.lang.{Long => JLong} + +object SimpleWordCount2_Special { + + def moreWarmup(sc: SparkContext) { + (0 until 20).foreach {i => + sc.parallelize(1 to 20000000, 500) + .map(_ % 100).map(_.toString) + .map(x => (x, 1)).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SimpleWordCount2 <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "SimpleWordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + GrepCount2.warmConnectionManagers(ssc.sc) + moreWarmup(ssc.sc) + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 400)).toArray + ) + + + def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { + val map = new java.util.HashMap[String, JLong] + var i = 0 + var j = 0 + while (iter.hasNext) { + val s = iter.next() + i = 0 + while (i < s.length) { + j = i + while (j < s.length && s.charAt(j) != ' ') { + j += 1 + } + if (j > i) { + val w = s.substring(i, j) + val c = map.get(w) + if (c == null) { + map.put(w, 1) + } else { + map.put(w, c + 1) + } + } + i = j + while (i < s.length && s.charAt(i) == ' ') { + i += 1 + } + } + } + map.toIterator + } + + + /*val words = sentences.flatMap(_.split(" "))*/ + /*val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 10)*/ + val counts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) + counts.foreachRDD(_.collect()) + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/stream/SparkStreamContext.scala b/streaming/src/main/scala/spark/stream/SparkStreamContext.scala new file mode 100644 index 0000000000..0e65196e46 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/SparkStreamContext.scala @@ -0,0 +1,105 @@ +package spark.stream + +import spark.SparkContext +import spark.SparkEnv +import spark.Utils +import spark.Logging + +import scala.collection.mutable.ArrayBuffer + +import java.net.InetSocketAddress +import java.io.IOException +import java.util.UUID + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration + +import akka.actor._ +import akka.actor.Actor +import akka.util.duration._ + +class SparkStreamContext ( + master: String, + frameworkName: String, + val sparkHome: String = null, + val jars: Seq[String] = Nil) + extends Logging { + + initLogging() + + val sc = new SparkContext(master, frameworkName, sparkHome, jars) + val env = SparkEnv.get + val actorSystem = env.actorSystem + + @transient val inputRDSs = new ArrayBuffer[InputRDS[_]]() + @transient val outputRDSs = new ArrayBuffer[RDS[_]]() + + var tempDirRoot: String = null + var tempDir: Path = null + + def readNetworkStream[T: ClassManifest]( + name: String, + addresses: Array[InetSocketAddress], + batchDuration: Time): RDS[T] = { + + val inputRDS = new NetworkInputRDS[T](name, addresses, batchDuration, this) + inputRDSs += inputRDS + inputRDS + } + + def readNetworkStream[T: ClassManifest]( + name: String, + addresses: Array[String], + batchDuration: Long): RDS[T] = { + + def stringToInetSocketAddress (str: String): InetSocketAddress = { + val parts = str.split(":") + if (parts.length != 2) { + throw new IllegalArgumentException ("Address format error") + } + new InetSocketAddress(parts(0), parts(1).toInt) + } + + readNetworkStream( + name, + addresses.map(stringToInetSocketAddress).toArray, + LongTime(batchDuration)) + } + + def readFileStream(name: String, directory: String): RDS[String] = { + val path = new Path(directory) + val fs = path.getFileSystem(new Configuration()) + val qualPath = path.makeQualified(fs) + val inputRDS = new FileInputRDS(name, qualPath.toString, this) + inputRDSs += inputRDS + inputRDS + } + + def readTestStream(name: String, batchDuration: Long): RDS[String] = { + val inputRDS = new TestInputRDS(name, LongTime(batchDuration), this) + inputRDSs += inputRDS + inputRDS + } + + def registerOutputStream (outputRDS: RDS[_]) { + outputRDSs += outputRDS + } + + def setTempDir(dir: String) { + tempDirRoot = dir + } + + def run () { + val ctxt = this + val actor = actorSystem.actorOf( + Props(new Scheduler(ctxt, inputRDSs.toArray, outputRDSs.toArray)), + name = "SparkStreamScheduler") + logInfo("Registered actor") + actorSystem.awaitTermination() + } +} + +object SparkStreamContext { + implicit def rdsToPairRdsFunctions [K: ClassManifest, V: ClassManifest] (rds: RDS[(K,V)]) = + new PairRDSFunctions (rds) +} diff --git a/streaming/src/main/scala/spark/stream/TestGenerator.scala b/streaming/src/main/scala/spark/stream/TestGenerator.scala new file mode 100644 index 0000000000..738ce17452 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/TestGenerator.scala @@ -0,0 +1,107 @@ +package spark.stream + +import scala.util.Random +import scala.io.Source +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ +import scala.actors.remote.RemoteActor._ + +import java.net.InetSocketAddress + + +object TestGenerator { + + def printUsage { + println ("Usage: SentenceGenerator []") + System.exit(0) + } + /* + def generateRandomSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { + val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 + val random = new Random () + + try { + var lastPrintTime = System.currentTimeMillis() + var count = 0 + while(true) { + streamReceiver ! lines(random.nextInt(lines.length)) + count += 1 + if (System.currentTimeMillis - lastPrintTime >= 1000) { + println (count + " sentences sent last second") + count = 0 + lastPrintTime = System.currentTimeMillis + } + Thread.sleep(sleepBetweenSentences.toLong) + } + } catch { + case e: Exception => + } + }*/ + + def generateSameSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { + try { + val numSentences = if (sentencesPerSecond <= 0) { + lines.length + } else { + sentencesPerSecond + } + val sentences = lines.take(numSentences).toArray + + var nextSendingTime = System.currentTimeMillis() + val sendAsArray = true + while(true) { + if (sendAsArray) { + println("Sending as array") + streamReceiver !? sentences + } else { + println("Sending individually") + sentences.foreach(sentence => { + streamReceiver !? sentence + }) + } + println ("Sent " + numSentences + " sentences in " + (System.currentTimeMillis - nextSendingTime) + " ms") + nextSendingTime += 1000 + val sleepTime = nextSendingTime - System.currentTimeMillis + if (sleepTime > 0) { + println ("Sleeping for " + sleepTime + " ms") + Thread.sleep(sleepTime) + } + } + } catch { + case e: Exception => + } + } + + def main(args: Array[String]) { + if (args.length < 3) { + printUsage + } + + val generateRandomly = false + + val streamReceiverIP = args(0) + val streamReceiverPort = args(1).toInt + val sentenceFile = args(2) + val sentencesPerSecond = if (args.length > 3) args(3).toInt else 10 + val sentenceInputName = if (args.length > 4) args(4) else "Sentences" + + println("Sending " + sentencesPerSecond + " sentences per second to " + + streamReceiverIP + ":" + streamReceiverPort + "/NetworkStreamReceiver-" + sentenceInputName) + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n") + source.close () + + val streamReceiver = select( + Node(streamReceiverIP, streamReceiverPort), + Symbol("NetworkStreamReceiver-" + sentenceInputName)) + if (generateRandomly) { + /*generateRandomSentences(lines, sentencesPerSecond, streamReceiver)*/ + } else { + generateSameSentences(lines, sentencesPerSecond, streamReceiver) + } + } +} + + + diff --git a/streaming/src/main/scala/spark/stream/TestGenerator2.scala b/streaming/src/main/scala/spark/stream/TestGenerator2.scala new file mode 100644 index 0000000000..ceb4730e72 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/TestGenerator2.scala @@ -0,0 +1,119 @@ +package spark.stream + +import scala.util.Random +import scala.io.Source +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ +import scala.actors.remote.RemoteActor._ + +import java.io.{DataOutputStream, ByteArrayOutputStream, DataInputStream} +import java.net.Socket + +object TestGenerator2 { + + def printUsage { + println ("Usage: SentenceGenerator []") + System.exit(0) + } + + def sendSentences(streamReceiverHost: String, streamReceiverPort: Int, numSentences: Int, bytes: Array[Byte], intervalTime: Long){ + try { + println("Connecting to " + streamReceiverHost + ":" + streamReceiverPort) + val socket = new Socket(streamReceiverHost, streamReceiverPort) + + println("Sending " + numSentences+ " sentences / " + (bytes.length / 1024.0 / 1024.0) + " MB per " + intervalTime + " ms to " + streamReceiverHost + ":" + streamReceiverPort ) + val currentTime = System.currentTimeMillis + var targetTime = (currentTime / intervalTime + 1).toLong * intervalTime + Thread.sleep(targetTime - currentTime) + + while(true) { + val startTime = System.currentTimeMillis() + println("Sending at " + startTime + " ms with delay of " + (startTime - targetTime) + " ms") + val socketOutputStream = socket.getOutputStream + val parts = 10 + (0 until parts).foreach(i => { + val partStartTime = System.currentTimeMillis + + val offset = (i * bytes.length / parts).toInt + val len = math.min(((i + 1) * bytes.length / parts).toInt - offset, bytes.length) + socketOutputStream.write(bytes, offset, len) + socketOutputStream.flush() + val partFinishTime = System.currentTimeMillis + println("Sending part " + i + " of " + len + " bytes took " + (partFinishTime - partStartTime) + " ms") + val sleepTime = math.max(0, 1000 / parts - (partFinishTime - partStartTime) - 1) + Thread.sleep(sleepTime) + }) + + socketOutputStream.flush() + /*val socketInputStream = new DataInputStream(socket.getInputStream)*/ + /*val reply = socketInputStream.readUTF()*/ + val finishTime = System.currentTimeMillis() + println ("Sent " + bytes.length + " bytes in " + (finishTime - startTime) + " ms for interval [" + targetTime + ", " + (targetTime + intervalTime) + "]") + /*println("Received = " + reply)*/ + targetTime = targetTime + intervalTime + val sleepTime = (targetTime - finishTime) + 10 + if (sleepTime > 0) { + println("Sleeping for " + sleepTime + " ms") + Thread.sleep(sleepTime) + } else { + println("############################") + println("###### Skipping sleep ######") + println("############################") + } + } + } catch { + case e: Exception => println(e) + } + println("Stopped sending") + } + + def main(args: Array[String]) { + if (args.length < 4) { + printUsage + } + + val streamReceiverHost = args(0) + val streamReceiverPort = args(1).toInt + val sentenceFile = args(2) + val intervalTime = args(3).toLong + val sentencesPerInterval = if (args.length > 4) args(4).toInt else 0 + + println("Reading the file " + sentenceFile) + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n") + source.close() + + val numSentences = if (sentencesPerInterval <= 0) { + lines.length + } else { + sentencesPerInterval + } + + println("Generating sentences") + val sentences: Array[String] = if (numSentences <= lines.length) { + lines.take(numSentences).toArray + } else { + (0 until numSentences).map(i => lines(i % lines.length)).toArray + } + + println("Converting to byte array") + val byteStream = new ByteArrayOutputStream() + val stringDataStream = new DataOutputStream(byteStream) + /*stringDataStream.writeInt(sentences.size)*/ + sentences.foreach(stringDataStream.writeUTF) + val bytes = byteStream.toByteArray() + stringDataStream.close() + println("Generated array of " + bytes.length + " bytes") + + /*while(true) { */ + sendSentences(streamReceiverHost, streamReceiverPort, numSentences, bytes, intervalTime) + /*println("Sleeping for 5 seconds")*/ + /*Thread.sleep(5000)*/ + /*System.gc()*/ + /*}*/ + } +} + + + diff --git a/streaming/src/main/scala/spark/stream/TestGenerator4.scala b/streaming/src/main/scala/spark/stream/TestGenerator4.scala new file mode 100644 index 0000000000..edeb969d7c --- /dev/null +++ b/streaming/src/main/scala/spark/stream/TestGenerator4.scala @@ -0,0 +1,244 @@ +package spark.stream + +import spark.Logging + +import scala.util.Random +import scala.io.Source +import scala.collection.mutable.{ArrayBuffer, Queue} + +import java.net._ +import java.io._ +import java.nio._ +import java.nio.charset._ +import java.nio.channels._ + +import it.unimi.dsi.fastutil.io._ + +class TestGenerator4(targetHost: String, targetPort: Int, sentenceFile: String, intervalDuration: Long, sentencesPerInterval: Int) +extends Logging { + + class SendingConnectionHandler(host: String, port: Int, generator: TestGenerator4) + extends ConnectionHandler(host, port, true) { + + val buffers = new ArrayBuffer[ByteBuffer] + val newBuffers = new Queue[ByteBuffer] + var activeKey: SelectionKey = null + + def send(buffer: ByteBuffer) { + logDebug("Sending: " + buffer) + newBuffers.synchronized { + newBuffers.enqueue(buffer) + } + selector.wakeup() + buffer.synchronized { + buffer.wait() + } + } + + override def ready(key: SelectionKey) { + logDebug("Ready") + activeKey = key + val channel = key.channel.asInstanceOf[SocketChannel] + channel.register(selector, SelectionKey.OP_WRITE) + generator.startSending() + } + + override def preSelect() { + newBuffers.synchronized { + while(!newBuffers.isEmpty) { + val buffer = newBuffers.dequeue + buffers += buffer + logDebug("Added: " + buffer) + changeInterest(activeKey, SelectionKey.OP_WRITE) + } + } + } + + override def write(key: SelectionKey) { + try { + /*while(true) {*/ + val channel = key.channel.asInstanceOf[SocketChannel] + if (buffers.size > 0) { + val buffer = buffers(0) + val newBuffer = buffer.slice() + newBuffer.limit(math.min(newBuffer.remaining, 32768)) + val bytesWritten = channel.write(newBuffer) + buffer.position(buffer.position + bytesWritten) + if (bytesWritten == 0) return + if (buffer.remaining == 0) { + buffers -= buffer + buffer.synchronized { + buffer.notify() + } + } + /*changeInterest(key, SelectionKey.OP_WRITE)*/ + } else { + changeInterest(key, 0) + } + /*}*/ + } catch { + case e: IOException => { + if (e.toString.contains("pipe") || e.toString.contains("reset")) { + logError("Connection broken") + } else { + logError("Connection error", e) + } + close(key) + } + } + } + + override def close(key: SelectionKey) { + buffers.clear() + super.close(key) + } + } + + initLogging() + + val connectionHandler = new SendingConnectionHandler(targetHost, targetPort, this) + var sendingThread: Thread = null + var sendCount = 0 + val sendBatches = 5 + + def run() { + logInfo("Connection handler started") + connectionHandler.start() + connectionHandler.join() + if (sendingThread != null && !sendingThread.isInterrupted) { + sendingThread.interrupt + } + logInfo("Connection handler stopped") + } + + def startSending() { + sendingThread = new Thread() { + override def run() { + logInfo("STARTING TO SEND") + sendSentences() + logInfo("SENDING STOPPED AFTER " + sendCount) + connectionHandler.interrupt() + } + } + sendingThread.start() + } + + def stopSending() { + sendingThread.interrupt() + } + + def sendSentences() { + logInfo("Reading the file " + sentenceFile) + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n") + source.close() + + val numSentences = if (sentencesPerInterval <= 0) { + lines.length + } else { + sentencesPerInterval + } + + logInfo("Generating sentence buffer") + val sentences: Array[String] = if (numSentences <= lines.length) { + lines.take(numSentences).toArray + } else { + (0 until numSentences).map(i => lines(i % lines.length)).toArray + } + + /* + val sentences: Array[String] = if (numSentences <= lines.length) { + lines.take((numSentences / sendBatches).toInt).toArray + } else { + (0 until (numSentences/sendBatches)).map(i => lines(i % lines.length)).toArray + }*/ + + + val serializer = new spark.KryoSerializer().newInstance() + val byteStream = new FastByteArrayOutputStream(100 * 1024 * 1024) + serializer.serializeStream(byteStream).writeAll(sentences.toIterator.asInstanceOf[Iterator[Any]]).close() + byteStream.trim() + val sentenceBuffer = ByteBuffer.wrap(byteStream.array) + + logInfo("Sending " + numSentences+ " sentences / " + sentenceBuffer.limit + " bytes per " + intervalDuration + " ms to " + targetHost + ":" + targetPort ) + val currentTime = System.currentTimeMillis + var targetTime = (currentTime / intervalDuration + 1).toLong * intervalDuration + Thread.sleep(targetTime - currentTime) + + val totalBytes = sentenceBuffer.limit + + while(true) { + val batchesInCurrentInterval = sendBatches // if (sendCount < 10) 1 else sendBatches + + val startTime = System.currentTimeMillis() + logDebug("Sending # " + sendCount + " at " + startTime + " ms with delay of " + (startTime - targetTime) + " ms") + + (0 until batchesInCurrentInterval).foreach(i => { + try { + val position = (i * totalBytes / sendBatches).toInt + val limit = if (i == sendBatches - 1) { + totalBytes + } else { + ((i + 1) * totalBytes / sendBatches).toInt - 1 + } + + val partStartTime = System.currentTimeMillis + sentenceBuffer.limit(limit) + connectionHandler.send(sentenceBuffer) + val partFinishTime = System.currentTimeMillis + val sleepTime = math.max(0, intervalDuration / sendBatches - (partFinishTime - partStartTime) - 1) + Thread.sleep(sleepTime) + + } catch { + case ie: InterruptedException => return + case e: Exception => e.printStackTrace() + } + }) + sentenceBuffer.rewind() + + val finishTime = System.currentTimeMillis() + /*logInfo ("Sent " + sentenceBuffer.limit + " bytes in " + (finishTime - startTime) + " ms")*/ + targetTime = targetTime + intervalDuration //+ (if (sendCount < 3) 1000 else 0) + + val sleepTime = (targetTime - finishTime) + 20 + if (sleepTime > 0) { + logInfo("Sleeping for " + sleepTime + " ms") + Thread.sleep(sleepTime) + } else { + logInfo("###### Skipping sleep ######") + } + if (Thread.currentThread.isInterrupted) { + return + } + sendCount += 1 + } + } +} + +object TestGenerator4 { + def printUsage { + println("Usage: TestGenerator4 []") + System.exit(0) + } + + def main(args: Array[String]) { + println("GENERATOR STARTED") + if (args.length < 4) { + printUsage + } + + + val streamReceiverHost = args(0) + val streamReceiverPort = args(1).toInt + val sentenceFile = args(2) + val intervalDuration = args(3).toLong + val sentencesPerInterval = if (args.length > 4) args(4).toInt else 0 + + while(true) { + val generator = new TestGenerator4(streamReceiverHost, streamReceiverPort, sentenceFile, intervalDuration, sentencesPerInterval) + generator.run() + Thread.sleep(2000) + } + println("GENERATOR STOPPED") + } +} diff --git a/streaming/src/main/scala/spark/stream/TestInputBlockTracker.scala b/streaming/src/main/scala/spark/stream/TestInputBlockTracker.scala new file mode 100644 index 0000000000..da3b964407 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/TestInputBlockTracker.scala @@ -0,0 +1,42 @@ +package spark.stream +import spark.Logging +import scala.collection.mutable.{ArrayBuffer, HashMap} + +object TestInputBlockTracker extends Logging { + initLogging() + val allBlockIds = new HashMap[Time, ArrayBuffer[String]]() + + def addBlocks(intervalEndTime: Time, reference: AnyRef) { + allBlockIds.getOrElseUpdate(intervalEndTime, new ArrayBuffer[String]()) ++= reference.asInstanceOf[Array[String]] + } + + def setEndTime(intervalEndTime: Time) { + try { + val endTime = System.currentTimeMillis + allBlockIds.get(intervalEndTime) match { + case Some(blockIds) => { + val numBlocks = blockIds.size + var totalDelay = 0d + blockIds.foreach(blockId => { + val inputTime = getInputTime(blockId) + val delay = (endTime - inputTime) / 1000.0 + totalDelay += delay + logInfo("End-to-end delay for block " + blockId + " is " + delay + " s") + }) + logInfo("Average end-to-end delay for time " + intervalEndTime + " is " + (totalDelay / numBlocks) + " s") + allBlockIds -= intervalEndTime + } + case None => throw new Exception("Unexpected") + } + } catch { + case e: Exception => logError(e.toString) + } + } + + def getInputTime(blockId: String): Long = { + val parts = blockId.split("-") + /*logInfo(blockId + " -> " + parts(4)) */ + parts(4).toLong + } +} + diff --git a/streaming/src/main/scala/spark/stream/TestStreamCoordinator.scala b/streaming/src/main/scala/spark/stream/TestStreamCoordinator.scala new file mode 100644 index 0000000000..add166fbd9 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/TestStreamCoordinator.scala @@ -0,0 +1,38 @@ +package spark.stream + +import spark.Logging + +import akka.actor._ +import akka.actor.Actor +import akka.actor.Actor._ + +sealed trait TestStreamCoordinatorMessage +case class GetStreamDetails extends TestStreamCoordinatorMessage +case class GotStreamDetails(name: String, duration: Long) extends TestStreamCoordinatorMessage +case class TestStarted extends TestStreamCoordinatorMessage + +class TestStreamCoordinator(streamDetails: Array[(String, Long)]) extends Actor with Logging { + + var index = 0 + + initLogging() + + logInfo("Created") + + def receive = { + case TestStarted => { + sender ! "OK" + } + + case GetStreamDetails => { + val streamDetail = if (index >= streamDetails.length) null else streamDetails(index) + sender ! GotStreamDetails(streamDetail._1, streamDetail._2) + index += 1 + if (streamDetail != null) { + logInfo("Allocated " + streamDetail._1 + " (" + index + "/" + streamDetails.length + ")" ) + } + } + } + +} + diff --git a/streaming/src/main/scala/spark/stream/TestStreamReceiver3.scala b/streaming/src/main/scala/spark/stream/TestStreamReceiver3.scala new file mode 100644 index 0000000000..9cc342040b --- /dev/null +++ b/streaming/src/main/scala/spark/stream/TestStreamReceiver3.scala @@ -0,0 +1,420 @@ +package spark.stream + +import spark._ +import spark.storage._ +import spark.util.AkkaUtils + +import scala.math._ +import scala.collection.mutable.{Queue, HashMap, ArrayBuffer, SynchronizedMap} + +import akka.actor._ +import akka.actor.Actor +import akka.dispatch._ +import akka.pattern.ask +import akka.util.duration._ + +import java.io.DataInputStream +import java.io.BufferedInputStream +import java.net.Socket +import java.net.ServerSocket +import java.util.LinkedHashMap + +import org.apache.hadoop.fs._ +import org.apache.hadoop.conf._ +import org.apache.hadoop.io._ +import org.apache.hadoop.mapred._ +import org.apache.hadoop.util._ + +import spark.Utils + + +class TestStreamReceiver3(actorSystem: ActorSystem, blockManager: BlockManager) +extends Thread with Logging { + + + class DataHandler( + inputName: String, + longIntervalDuration: LongTime, + shortIntervalDuration: LongTime, + blockManager: BlockManager + ) + extends Logging { + + class Block(var id: String, var shortInterval: Interval) { + val data = ArrayBuffer[String]() + var pushed = false + def longInterval = getLongInterval(shortInterval) + def empty() = (data.size == 0) + def += (str: String) = (data += str) + override def toString() = "Block " + id + } + + class Bucket(val longInterval: Interval) { + val blocks = new ArrayBuffer[Block]() + var filled = false + def += (block: Block) = blocks += block + def empty() = (blocks.size == 0) + def ready() = (filled && !blocks.exists(! _.pushed)) + def blockIds() = blocks.map(_.id).toArray + override def toString() = "Bucket [" + longInterval + ", " + blocks.size + " blocks]" + } + + initLogging() + + val shortIntervalDurationMillis = shortIntervalDuration.asInstanceOf[LongTime].milliseconds + val longIntervalDurationMillis = longIntervalDuration.asInstanceOf[LongTime].milliseconds + + var currentBlock: Block = null + var currentBucket: Bucket = null + + val blocksForPushing = new Queue[Block]() + val buckets = new HashMap[Interval, Bucket]() with SynchronizedMap[Interval, Bucket] + + val blockUpdatingThread = new Thread() { override def run() { keepUpdatingCurrentBlock() } } + val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } + + def start() { + blockUpdatingThread.start() + blockPushingThread.start() + } + + def += (data: String) = addData(data) + + def addData(data: String) { + if (currentBlock == null) { + updateCurrentBlock() + } + currentBlock.synchronized { + currentBlock += data + } + } + + def getShortInterval(time: Time): Interval = { + val intervalBegin = time.floor(shortIntervalDuration) + Interval(intervalBegin, intervalBegin + shortIntervalDuration) + } + + def getLongInterval(shortInterval: Interval): Interval = { + val intervalBegin = shortInterval.beginTime.floor(longIntervalDuration) + Interval(intervalBegin, intervalBegin + longIntervalDuration) + } + + def updateCurrentBlock() { + /*logInfo("Updating current block")*/ + val currentTime: LongTime = LongTime(System.currentTimeMillis) + val shortInterval = getShortInterval(currentTime) + val longInterval = getLongInterval(shortInterval) + + def createBlock(reuseCurrentBlock: Boolean = false) { + val newBlockId = inputName + "-" + longInterval.toFormattedString + "-" + currentBucket.blocks.size + if (!reuseCurrentBlock) { + val newBlock = new Block(newBlockId, shortInterval) + /*logInfo("Created " + currentBlock)*/ + currentBlock = newBlock + } else { + currentBlock.shortInterval = shortInterval + currentBlock.id = newBlockId + } + } + + def createBucket() { + val newBucket = new Bucket(longInterval) + buckets += ((longInterval, newBucket)) + currentBucket = newBucket + /*logInfo("Created " + currentBucket + ", " + buckets.size + " buckets")*/ + } + + if (currentBlock == null || currentBucket == null) { + createBucket() + currentBucket.synchronized { + createBlock() + } + return + } + + currentBlock.synchronized { + var reuseCurrentBlock = false + + if (shortInterval != currentBlock.shortInterval) { + if (!currentBlock.empty) { + blocksForPushing.synchronized { + blocksForPushing += currentBlock + blocksForPushing.notifyAll() + } + } + + currentBucket.synchronized { + if (currentBlock.empty) { + reuseCurrentBlock = true + } else { + currentBucket += currentBlock + } + + if (longInterval != currentBucket.longInterval) { + currentBucket.filled = true + if (currentBucket.ready) { + currentBucket.notifyAll() + } + createBucket() + } + } + + createBlock(reuseCurrentBlock) + } + } + } + + def pushBlock(block: Block) { + try{ + if (blockManager != null) { + logInfo("Pushing block") + val startTime = System.currentTimeMillis + + val bytes = blockManager.dataSerialize(block.data.toIterator) + val finishTime = System.currentTimeMillis + logInfo(block + " serialization delay is " + (finishTime - startTime) / 1000.0 + " s") + + blockManager.putBytes(block.id.toString, bytes, StorageLevel.DISK_AND_MEMORY_2) + /*blockManager.putBytes(block.id.toString, bytes, StorageLevel.DISK_AND_MEMORY_DESER_2)*/ + /*blockManager.put(block.id.toString, block.data.toIterator, StorageLevel.DISK_AND_MEMORY_DESER)*/ + /*blockManager.put(block.id.toString, block.data.toIterator, StorageLevel.DISK_AND_MEMORY)*/ + val finishTime1 = System.currentTimeMillis + logInfo(block + " put delay is " + (finishTime1 - startTime) / 1000.0 + " s") + } else { + logWarning(block + " not put as block manager is null") + } + } catch { + case e: Exception => logError("Exception writing " + block + " to blockmanager" , e) + } + } + + def getBucket(longInterval: Interval): Option[Bucket] = { + buckets.get(longInterval) + } + + def clearBucket(longInterval: Interval) { + buckets.remove(longInterval) + } + + def keepUpdatingCurrentBlock() { + logInfo("Thread to update current block started") + while(true) { + updateCurrentBlock() + val currentTimeMillis = System.currentTimeMillis + val sleepTimeMillis = (currentTimeMillis / shortIntervalDurationMillis + 1) * + shortIntervalDurationMillis - currentTimeMillis + 1 + Thread.sleep(sleepTimeMillis) + } + } + + def keepPushingBlocks() { + var loop = true + logInfo("Thread to push blocks started") + while(loop) { + val block = blocksForPushing.synchronized { + if (blocksForPushing.size == 0) { + blocksForPushing.wait() + } + blocksForPushing.dequeue + } + pushBlock(block) + block.pushed = true + block.data.clear() + + val bucket = buckets(block.longInterval) + bucket.synchronized { + if (bucket.ready) { + bucket.notifyAll() + } + } + } + } + } + + + class ConnectionListener(port: Int, dataHandler: DataHandler) + extends Thread with Logging { + initLogging() + override def run { + try { + val listener = new ServerSocket(port) + logInfo("Listening on port " + port) + while (true) { + new ConnectionHandler(listener.accept(), dataHandler).start(); + } + listener.close() + } catch { + case e: Exception => logError("", e); + } + } + } + + class ConnectionHandler(socket: Socket, dataHandler: DataHandler) extends Thread with Logging { + initLogging() + override def run { + logInfo("New connection from " + socket.getInetAddress() + ":" + socket.getPort) + val bytes = new Array[Byte](100 * 1024 * 1024) + try { + + val inputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream, 1024 * 1024)) + /*val inputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream))*/ + var str: String = null + str = inputStream.readUTF + while(str != null) { + dataHandler += str + str = inputStream.readUTF() + } + + /* + var loop = true + while(loop) { + val numRead = inputStream.read(bytes) + if (numRead < 0) { + loop = false + } + inbox += ((LongTime(SystemTime.currentTimeMillis), "test")) + }*/ + + inputStream.close() + } catch { + case e => logError("Error receiving data", e) + } + socket.close() + } + } + + initLogging() + + val masterHost = System.getProperty("spark.master.host") + val masterPort = System.getProperty("spark.master.port").toInt + + val akkaPath = "akka://spark@%s:%s/user/".format(masterHost, masterPort) + val sparkstreamScheduler = actorSystem.actorFor(akkaPath + "/SparkStreamScheduler") + val testStreamCoordinator = actorSystem.actorFor(akkaPath + "/TestStreamCoordinator") + + logInfo("Getting stream details from master " + masterHost + ":" + masterPort) + + val timeout = 50 millis + + var started = false + while (!started) { + askActor[String](testStreamCoordinator, TestStarted) match { + case Some(str) => { + started = true + logInfo("TestStreamCoordinator started") + } + case None => { + logInfo("TestStreamCoordinator not started yet") + Thread.sleep(200) + } + } + } + + val streamDetails = askActor[GotStreamDetails](testStreamCoordinator, GetStreamDetails) match { + case Some(details) => details + case None => throw new Exception("Could not get stream details") + } + logInfo("Stream details received: " + streamDetails) + + val inputName = streamDetails.name + val intervalDurationMillis = streamDetails.duration + val intervalDuration = LongTime(intervalDurationMillis) + + val dataHandler = new DataHandler( + inputName, + intervalDuration, + LongTime(TestStreamReceiver3.SHORT_INTERVAL_MILLIS), + blockManager) + + val connListener = new ConnectionListener(TestStreamReceiver3.PORT, dataHandler) + + // Send a message to an actor and return an option with its reply, or None if this times out + def askActor[T](actor: ActorRef, message: Any): Option[T] = { + try { + val future = actor.ask(message)(timeout) + return Some(Await.result(future, timeout).asInstanceOf[T]) + } catch { + case e: Exception => + logInfo("Error communicating with " + actor, e) + return None + } + } + + override def run() { + connListener.start() + dataHandler.start() + + var interval = Interval.currentInterval(intervalDuration) + var dataStarted = false + + while(true) { + waitFor(interval.endTime) + logInfo("Woken up at " + System.currentTimeMillis + " for " + interval) + dataHandler.getBucket(interval) match { + case Some(bucket) => { + logInfo("Found " + bucket + " for " + interval) + bucket.synchronized { + if (!bucket.ready) { + logInfo("Waiting for " + bucket) + bucket.wait() + logInfo("Wait over for " + bucket) + } + if (dataStarted || !bucket.empty) { + logInfo("Notifying " + bucket) + notifyScheduler(interval, bucket.blockIds) + dataStarted = true + } + bucket.blocks.clear() + dataHandler.clearBucket(interval) + } + } + case None => { + logInfo("Found none for " + interval) + if (dataStarted) { + logInfo("Notifying none") + notifyScheduler(interval, Array[String]()) + } + } + } + interval = interval.next + } + } + + def waitFor(time: Time) { + val currentTimeMillis = System.currentTimeMillis + val targetTimeMillis = time.asInstanceOf[LongTime].milliseconds + if (currentTimeMillis < targetTimeMillis) { + val sleepTime = (targetTimeMillis - currentTimeMillis) + Thread.sleep(sleepTime + 1) + } + } + + def notifyScheduler(interval: Interval, blockIds: Array[String]) { + try { + sparkstreamScheduler ! InputGenerated(inputName, interval, blockIds.toArray) + val time = interval.endTime.asInstanceOf[LongTime] + val delay = (System.currentTimeMillis - time.milliseconds) / 1000.0 + logInfo("Pushing delay for " + time + " is " + delay + " s") + } catch { + case _ => logError("Exception notifying scheduler at interval " + interval) + } + } +} + +object TestStreamReceiver3 { + + val PORT = 9999 + val SHORT_INTERVAL_MILLIS = 100 + + def main(args: Array[String]) { + System.setProperty("spark.master.host", Utils.localHostName) + System.setProperty("spark.master.port", "7078") + val details = Array(("Sentences", 2000L)) + val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localHostName, 7078) + actorSystem.actorOf(Props(new TestStreamCoordinator(details)), name = "TestStreamCoordinator") + new TestStreamReceiver3(actorSystem, null).start() + } +} + + + diff --git a/streaming/src/main/scala/spark/stream/TestStreamReceiver4.scala b/streaming/src/main/scala/spark/stream/TestStreamReceiver4.scala new file mode 100644 index 0000000000..e7bef75391 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/TestStreamReceiver4.scala @@ -0,0 +1,373 @@ +package spark.stream + +import spark._ +import spark.storage._ +import spark.util.AkkaUtils + +import scala.math._ +import scala.collection.mutable.{Queue, HashMap, ArrayBuffer, SynchronizedMap} + +import java.io._ +import java.nio._ +import java.nio.charset._ +import java.nio.channels._ +import java.util.concurrent.Executors + +import akka.actor._ +import akka.actor.Actor +import akka.dispatch._ +import akka.pattern.ask +import akka.util.duration._ + +class TestStreamReceiver4(actorSystem: ActorSystem, blockManager: BlockManager) +extends Thread with Logging { + + class DataHandler( + inputName: String, + longIntervalDuration: LongTime, + shortIntervalDuration: LongTime, + blockManager: BlockManager + ) + extends Logging { + + class Block(val id: String, val shortInterval: Interval, val buffer: ByteBuffer) { + var pushed = false + def longInterval = getLongInterval(shortInterval) + override def toString() = "Block " + id + } + + class Bucket(val longInterval: Interval) { + val blocks = new ArrayBuffer[Block]() + var filled = false + def += (block: Block) = blocks += block + def empty() = (blocks.size == 0) + def ready() = (filled && !blocks.exists(! _.pushed)) + def blockIds() = blocks.map(_.id).toArray + override def toString() = "Bucket [" + longInterval + ", " + blocks.size + " blocks]" + } + + initLogging() + + val syncOnLastShortInterval = true + + val shortIntervalDurationMillis = shortIntervalDuration.asInstanceOf[LongTime].milliseconds + val longIntervalDurationMillis = longIntervalDuration.asInstanceOf[LongTime].milliseconds + + val buffer = ByteBuffer.allocateDirect(100 * 1024 * 1024) + var currentShortInterval = Interval.currentInterval(shortIntervalDuration) + + val blocksForPushing = new Queue[Block]() + val buckets = new HashMap[Interval, Bucket]() with SynchronizedMap[Interval, Bucket] + + val bufferProcessingThread = new Thread() { override def run() { keepProcessingBuffers() } } + val blockPushingExecutor = Executors.newFixedThreadPool(5) + + + def start() { + buffer.clear() + if (buffer.remaining == 0) { + throw new Exception("Buffer initialization error") + } + bufferProcessingThread.start() + } + + def readDataToBuffer(func: ByteBuffer => Int): Int = { + buffer.synchronized { + if (buffer.remaining == 0) { + logInfo("Received first data for interval " + currentShortInterval) + } + func(buffer) + } + } + + def getLongInterval(shortInterval: Interval): Interval = { + val intervalBegin = shortInterval.beginTime.floor(longIntervalDuration) + Interval(intervalBegin, intervalBegin + longIntervalDuration) + } + + def processBuffer() { + + def readInt(buffer: ByteBuffer): Int = { + var offset = 0 + var result = 0 + while (offset < 32) { + val b = buffer.get() + result |= ((b & 0x7F) << offset) + if ((b & 0x80) == 0) { + return result + } + offset += 7 + } + throw new Exception("Malformed zigzag-encoded integer") + } + + val currentLongInterval = getLongInterval(currentShortInterval) + val startTime = System.currentTimeMillis + val newBuffer: ByteBuffer = buffer.synchronized { + buffer.flip() + if (buffer.remaining == 0) { + buffer.clear() + null + } else { + logDebug("Processing interval " + currentShortInterval + " with delay of " + (System.currentTimeMillis - startTime) + " ms") + val startTime1 = System.currentTimeMillis + var loop = true + var count = 0 + while(loop) { + buffer.mark() + try { + val len = readInt(buffer) + buffer.position(buffer.position + len) + count += 1 + } catch { + case e: Exception => { + buffer.reset() + loop = false + } + } + } + val bytesToCopy = buffer.position + val newBuf = ByteBuffer.allocate(bytesToCopy) + buffer.position(0) + newBuf.put(buffer.slice().limit(bytesToCopy).asInstanceOf[ByteBuffer]) + newBuf.flip() + buffer.position(bytesToCopy) + buffer.compact() + newBuf + } + } + + if (newBuffer != null) { + val bucket = buckets.getOrElseUpdate(currentLongInterval, new Bucket(currentLongInterval)) + bucket.synchronized { + val newBlockId = inputName + "-" + currentLongInterval.toFormattedString + "-" + currentShortInterval.toFormattedString + val newBlock = new Block(newBlockId, currentShortInterval, newBuffer) + if (syncOnLastShortInterval) { + bucket += newBlock + } + logDebug("Created " + newBlock + " with " + newBuffer.remaining + " bytes, creation delay is " + (System.currentTimeMillis - currentShortInterval.endTime.asInstanceOf[LongTime].milliseconds) / 1000.0 + " s" ) + blockPushingExecutor.execute(new Runnable() { def run() { pushAndNotifyBlock(newBlock) } }) + } + } + + val newShortInterval = Interval.currentInterval(shortIntervalDuration) + val newLongInterval = getLongInterval(newShortInterval) + + if (newLongInterval != currentLongInterval) { + buckets.get(currentLongInterval) match { + case Some(bucket) => { + bucket.synchronized { + bucket.filled = true + if (bucket.ready) { + bucket.notifyAll() + } + } + } + case None => + } + buckets += ((newLongInterval, new Bucket(newLongInterval))) + } + + currentShortInterval = newShortInterval + } + + def pushBlock(block: Block) { + try{ + if (blockManager != null) { + val startTime = System.currentTimeMillis + logInfo(block + " put start delay is " + (startTime - block.shortInterval.endTime.asInstanceOf[LongTime].milliseconds) + " ms") + /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY)*/ + /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_2)*/ + blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.MEMORY_ONLY_2) + /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.MEMORY_ONLY)*/ + /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_DESER)*/ + /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_DESER_2)*/ + val finishTime = System.currentTimeMillis + logInfo(block + " put delay is " + (finishTime - startTime) + " ms") + } else { + logWarning(block + " not put as block manager is null") + } + } catch { + case e: Exception => logError("Exception writing " + block + " to blockmanager" , e) + } + } + + def getBucket(longInterval: Interval): Option[Bucket] = { + buckets.get(longInterval) + } + + def clearBucket(longInterval: Interval) { + buckets.remove(longInterval) + } + + def keepProcessingBuffers() { + logInfo("Thread to process buffers started") + while(true) { + processBuffer() + val currentTimeMillis = System.currentTimeMillis + val sleepTimeMillis = (currentTimeMillis / shortIntervalDurationMillis + 1) * + shortIntervalDurationMillis - currentTimeMillis + 1 + Thread.sleep(sleepTimeMillis) + } + } + + def pushAndNotifyBlock(block: Block) { + pushBlock(block) + block.pushed = true + val bucket = if (syncOnLastShortInterval) { + buckets(block.longInterval) + } else { + var longInterval = block.longInterval + while(!buckets.contains(longInterval)) { + logWarning("Skipping bucket of " + longInterval + " for " + block) + longInterval = longInterval.next + } + val chosenBucket = buckets(longInterval) + logDebug("Choosing bucket of " + longInterval + " for " + block) + chosenBucket += block + chosenBucket + } + + bucket.synchronized { + if (bucket.ready) { + bucket.notifyAll() + } + } + + } + } + + + class ReceivingConnectionHandler(host: String, port: Int, dataHandler: DataHandler) + extends ConnectionHandler(host, port, false) { + + override def ready(key: SelectionKey) { + changeInterest(key, SelectionKey.OP_READ) + } + + override def read(key: SelectionKey) { + try { + val channel = key.channel.asInstanceOf[SocketChannel] + val bytesRead = dataHandler.readDataToBuffer(channel.read) + if (bytesRead < 0) { + close(key) + } + } catch { + case e: IOException => { + logError("Error reading", e) + close(key) + } + } + } + } + + initLogging() + + val masterHost = System.getProperty("spark.master.host", "localhost") + val masterPort = System.getProperty("spark.master.port", "7078").toInt + + val akkaPath = "akka://spark@%s:%s/user/".format(masterHost, masterPort) + val sparkstreamScheduler = actorSystem.actorFor(akkaPath + "/SparkStreamScheduler") + val testStreamCoordinator = actorSystem.actorFor(akkaPath + "/TestStreamCoordinator") + + logInfo("Getting stream details from master " + masterHost + ":" + masterPort) + + val streamDetails = askActor[GotStreamDetails](testStreamCoordinator, GetStreamDetails) match { + case Some(details) => details + case None => throw new Exception("Could not get stream details") + } + logInfo("Stream details received: " + streamDetails) + + val inputName = streamDetails.name + val intervalDurationMillis = streamDetails.duration + val intervalDuration = Milliseconds(intervalDurationMillis) + val shortIntervalDuration = Milliseconds(System.getProperty("spark.stream.shortinterval", "500").toInt) + + val dataHandler = new DataHandler(inputName, intervalDuration, shortIntervalDuration, blockManager) + val connectionHandler = new ReceivingConnectionHandler("localhost", 9999, dataHandler) + + val timeout = 100 millis + + // Send a message to an actor and return an option with its reply, or None if this times out + def askActor[T](actor: ActorRef, message: Any): Option[T] = { + try { + val future = actor.ask(message)(timeout) + return Some(Await.result(future, timeout).asInstanceOf[T]) + } catch { + case e: Exception => + logInfo("Error communicating with " + actor, e) + return None + } + } + + override def run() { + connectionHandler.start() + dataHandler.start() + + var interval = Interval.currentInterval(intervalDuration) + var dataStarted = false + + + while(true) { + waitFor(interval.endTime) + /*logInfo("Woken up at " + System.currentTimeMillis + " for " + interval)*/ + dataHandler.getBucket(interval) match { + case Some(bucket) => { + logDebug("Found " + bucket + " for " + interval) + bucket.synchronized { + if (!bucket.ready) { + logDebug("Waiting for " + bucket) + bucket.wait() + logDebug("Wait over for " + bucket) + } + if (dataStarted || !bucket.empty) { + logDebug("Notifying " + bucket) + notifyScheduler(interval, bucket.blockIds) + dataStarted = true + } + bucket.blocks.clear() + dataHandler.clearBucket(interval) + } + } + case None => { + logDebug("Found none for " + interval) + if (dataStarted) { + logDebug("Notifying none") + notifyScheduler(interval, Array[String]()) + } + } + } + interval = interval.next + } + } + + def waitFor(time: Time) { + val currentTimeMillis = System.currentTimeMillis + val targetTimeMillis = time.asInstanceOf[LongTime].milliseconds + if (currentTimeMillis < targetTimeMillis) { + val sleepTime = (targetTimeMillis - currentTimeMillis) + Thread.sleep(sleepTime + 1) + } + } + + def notifyScheduler(interval: Interval, blockIds: Array[String]) { + try { + sparkstreamScheduler ! InputGenerated(inputName, interval, blockIds.toArray) + val time = interval.endTime.asInstanceOf[LongTime] + val delay = (System.currentTimeMillis - time.milliseconds) + logInfo("Notification delay for " + time + " is " + delay + " ms") + } catch { + case e: Exception => logError("Exception notifying scheduler at interval " + interval + ": " + e) + } + } +} + + +object TestStreamReceiver4 { + def main(args: Array[String]) { + val details = Array(("Sentences", 2000L)) + val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localHostName, 7078) + actorSystem.actorOf(Props(new TestStreamCoordinator(details)), name = "TestStreamCoordinator") + new TestStreamReceiver4(actorSystem, null).start() + } +} diff --git a/streaming/src/main/scala/spark/stream/Time.scala b/streaming/src/main/scala/spark/stream/Time.scala new file mode 100644 index 0000000000..25369dfee5 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/Time.scala @@ -0,0 +1,85 @@ +package spark.stream + +abstract case class Time { + + // basic operations that must be overridden + def copy(): Time + def zero: Time + def < (that: Time): Boolean + def += (that: Time): Time + def -= (that: Time): Time + def floor(that: Time): Time + def isMultipleOf(that: Time): Boolean + + // derived operations composed of basic operations + def + (that: Time) = this.copy() += that + def - (that: Time) = this.copy() -= that + def * (times: Int) = { + var count = 0 + var result = this.copy() + while (count < times) { + result += this + count += 1 + } + result + } + def <= (that: Time) = (this < that || this == that) + def > (that: Time) = !(this <= that) + def >= (that: Time) = !(this < that) + def isZero = (this == zero) + def toFormattedString = toString +} + +object Time { + def Milliseconds(milliseconds: Long) = LongTime(milliseconds) + + def zero = LongTime(0) +} + +case class LongTime(var milliseconds: Long) extends Time { + + override def copy() = LongTime(this.milliseconds) + + override def zero = LongTime(0) + + override def < (that: Time): Boolean = + (this.milliseconds < that.asInstanceOf[LongTime].milliseconds) + + override def += (that: Time): Time = { + this.milliseconds += that.asInstanceOf[LongTime].milliseconds + this + } + + override def -= (that: Time): Time = { + this.milliseconds -= that.asInstanceOf[LongTime].milliseconds + this + } + + override def floor(that: Time): Time = { + val t = that.asInstanceOf[LongTime].milliseconds + val m = this.milliseconds / t + LongTime(m.toLong * t) + } + + override def isMultipleOf(that: Time): Boolean = + (this.milliseconds % that.asInstanceOf[LongTime].milliseconds == 0) + + override def isZero = (this.milliseconds == 0) + + override def toString = (milliseconds.toString + "ms") + + override def toFormattedString = milliseconds.toString +} + +object Milliseconds { + def apply(milliseconds: Long) = LongTime(milliseconds) +} + +object Seconds { + def apply(seconds: Long) = LongTime(seconds * 1000) +} + +object Minutes { + def apply(minutes: Long) = LongTime(minutes * 60000) +} + diff --git a/streaming/src/main/scala/spark/stream/TopContentCount.scala b/streaming/src/main/scala/spark/stream/TopContentCount.scala new file mode 100644 index 0000000000..a8cca4e793 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/TopContentCount.scala @@ -0,0 +1,97 @@ +package spark.stream + +import SparkStreamContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting + +object TopContentCount { + + case class Event(val country: String, val content: String) + + object Event { + def create(string: String): Event = { + val parts = string.split(":") + new Event(parts(0), parts(1)) + } + } + + def main(args: Array[String]) { + + if (args.length < 2) { + println ("Usage: GrepCount2 <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "TopContentCount") + val sc = ssc.sc + val dummy = sc.parallelize(0 to 1000, 100).persist(StorageLevel.DISK_AND_MEMORY) + sc.runJob(dummy, (_: Iterator[Int]) => {}) + + + val numEventStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + val eventStrings = new UnifiedRDS( + (1 to numEventStreams).map(i => ssc.readTestStream("Events-" + i, 1000)).toArray + ) + + def parse(string: String) = { + val parts = string.split(":") + (parts(0), parts(1)) + } + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + + val events = eventStrings.map(x => parse(x)) + /*events.print*/ + + val parallelism = 8 + val counts_per_content_per_country = events + .map(x => (x, 1)) + .reduceByKey(_ + _) + /*.reduceByKeyAndWindow(add _, subtract _, Seconds(5), Seconds(1), parallelism)*/ + /*counts_per_content_per_country.print*/ + + /* + counts_per_content_per_country.persist( + StorageLevel.MEMORY_ONLY_DESER, + StorageLevel.MEMORY_ONLY_DESER_2, + Seconds(1) + )*/ + + val counts_per_country = counts_per_content_per_country + .map(x => (x._1._1, (x._1._2, x._2))) + .groupByKey() + counts_per_country.print + + + def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { + implicit val countOrdering = new Ordering[(String, Int)] { + override def compare(count1: (String, Int), count2: (String, Int)): Int = { + count2._2 - count1._2 + } + } + val array = data.toArray + Sorting.quickSort(array) + val taken = array.take(k) + taken + } + + val k = 10 + val topKContents_per_country = counts_per_country + .map(x => (x._1, topK(x._2, k))) + .map(x => (x._1, x._2.map(_.toString).reduceLeft(_ + ", " + _))) + + topKContents_per_country.print + + ssc.run + } +} + + + diff --git a/streaming/src/main/scala/spark/stream/TopKWordCount2.scala b/streaming/src/main/scala/spark/stream/TopKWordCount2.scala new file mode 100644 index 0000000000..7dd06dd5ee --- /dev/null +++ b/streaming/src/main/scala/spark/stream/TopKWordCount2.scala @@ -0,0 +1,103 @@ +package spark.stream + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting + +object TopKWordCount2 { + + def moreWarmup(sc: SparkContext) { + (0 until 20).foreach {i => + sc.parallelize(1 to 20000000, 500) + .map(_ % 100).map(_.toString) + .map(x => (x, 1)).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SparkStreamContext <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + GrepCount2.warmConnectionManagers(ssc.sc) + moreWarmup(ssc.sc) + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray + ) + + val words = sentences.flatMap(_.split(" ")) + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + + val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 10) + windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) + + def topK(data: Iterator[(String, Int)], k: Int): Iterator[(String, Int)] = { + val taken = new Array[(String, Int)](k) + + var i = 0 + var len = 0 + var done = false + var value: (String, Int) = null + var swap: (String, Int) = null + var count = 0 + + while(data.hasNext) { + value = data.next + count += 1 + println("count = " + count) + if (len == 0) { + taken(0) = value + len = 1 + } else if (len < k || value._2 > taken(len - 1)._2) { + if (len < k) { + len += 1 + } + taken(len - 1) = value + i = len - 1 + while(i > 0 && taken(i - 1)._2 < taken(i)._2) { + swap = taken(i) + taken(i) = taken(i-1) + taken(i - 1) = swap + i -= 1 + } + } + } + println("Took " + len + " out of " + count + " items") + return taken.toIterator + } + + val k = 10 + val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) + partialTopKWindowedCounts.foreachRDD(rdd => { + val collectedCounts = rdd.collect + println("Collected " + collectedCounts.size + " items") + topK(collectedCounts.toIterator, k).foreach(println) + }) + + /* + windowedCounts.filter(_ == null).foreachRDD(rdd => { + val count = rdd.count + println("# of nulls = " + count) + })*/ + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/stream/TopKWordCount2_Special.scala b/streaming/src/main/scala/spark/stream/TopKWordCount2_Special.scala new file mode 100644 index 0000000000..e9f3f914ae --- /dev/null +++ b/streaming/src/main/scala/spark/stream/TopKWordCount2_Special.scala @@ -0,0 +1,142 @@ +package spark.stream + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting +import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.mutable.Queue + +import java.lang.{Long => JLong} + +object TopKWordCount2_Special { + + def moreWarmup(sc: SparkContext) { + (0 until 20).foreach {i => + sc.parallelize(1 to 20000000, 500) + .map(_ % 100).map(_.toString) + .map(x => (x, 1)).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SparkStreamContext <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "TopKWordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + GrepCount2.warmConnectionManagers(ssc.sc) + /*moreWarmup(ssc.sc)*/ + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 500)).toArray + ) + + /*val words = sentences.flatMap(_.split(" "))*/ + + /*def add(v1: Int, v2: Int) = (v1 + v2) */ + /*def subtract(v1: Int, v2: Int) = (v1 - v2) */ + + def add(v1: JLong, v2: JLong) = (v1 + v2) + def subtract(v1: JLong, v2: JLong) = (v1 - v2) + + def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { + val map = new java.util.HashMap[String, JLong] + var i = 0 + var j = 0 + while (iter.hasNext) { + val s = iter.next() + i = 0 + while (i < s.length) { + j = i + while (j < s.length && s.charAt(j) != ' ') { + j += 1 + } + if (j > i) { + val w = s.substring(i, j) + val c = map.get(w) + if (c == null) { + map.put(w, 1) + } else { + map.put(w, c + 1) + } + } + i = j + while (i < s.length && s.charAt(i) == ' ') { + i += 1 + } + } + } + map.toIterator + } + + + val windowedCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Milliseconds(500), 10) + /*windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1))*/ + windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY, Milliseconds(500)) + + def topK(data: Iterator[(String, JLong)], k: Int): Iterator[(String, JLong)] = { + val taken = new Array[(String, JLong)](k) + + var i = 0 + var len = 0 + var done = false + var value: (String, JLong) = null + var swap: (String, JLong) = null + var count = 0 + + while(data.hasNext) { + value = data.next + count += 1 + println("count = " + count) + if (len == 0) { + taken(0) = value + len = 1 + } else if (len < k || value._2 > taken(len - 1)._2) { + if (len < k) { + len += 1 + } + taken(len - 1) = value + i = len - 1 + while(i > 0 && taken(i - 1)._2 < taken(i)._2) { + swap = taken(i) + taken(i) = taken(i-1) + taken(i - 1) = swap + i -= 1 + } + } + } + println("Took " + len + " out of " + count + " items") + return taken.toIterator + } + + val k = 50 + val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) + partialTopKWindowedCounts.foreachRDD(rdd => { + val collectedCounts = rdd.collect + println("Collected " + collectedCounts.size + " items") + topK(collectedCounts.toIterator, k).foreach(println) + }) + + /* + windowedCounts.filter(_ == null).foreachRDD(rdd => { + val count = rdd.count + println("# of nulls = " + count) + })*/ + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/stream/WindowedRDS.scala b/streaming/src/main/scala/spark/stream/WindowedRDS.scala new file mode 100644 index 0000000000..a2e7966edb --- /dev/null +++ b/streaming/src/main/scala/spark/stream/WindowedRDS.scala @@ -0,0 +1,68 @@ +package spark.stream + +import spark.stream.SparkStreamContext._ + +import spark.RDD +import spark.UnionRDD +import spark.SparkContext._ + +import scala.collection.mutable.ArrayBuffer + +class WindowedRDS[T: ClassManifest]( + parent: RDS[T], + _windowTime: Time, + _slideTime: Time) + extends RDS[T](parent.ssc) { + + if (!_windowTime.isMultipleOf(parent.slideTime)) + throw new Exception("The window duration of WindowedRDS (" + _slideTime + ") " + + "must be multiple of the slide duration of parent RDS (" + parent.slideTime + ")") + + if (!_slideTime.isMultipleOf(parent.slideTime)) + throw new Exception("The slide duration of WindowedRDS (" + _slideTime + ") " + + "must be multiple of the slide duration of parent RDS (" + parent.slideTime + ")") + + val allowPartialWindows = true + + override def dependencies = List(parent) + + def windowTime: Time = _windowTime + + override def slideTime: Time = _slideTime + + override def compute(validTime: Time): Option[RDD[T]] = { + val parentRDDs = new ArrayBuffer[RDD[T]]() + val windowEndTime = validTime.copy() + val windowStartTime = if (allowPartialWindows && windowEndTime - windowTime < parent.zeroTime) { + parent.zeroTime + } else { + windowEndTime - windowTime + } + + logInfo("Window = " + windowStartTime + " - " + windowEndTime) + logInfo("Parent.zeroTime = " + parent.zeroTime) + + if (windowStartTime >= parent.zeroTime) { + // Walk back through time, from the 'windowEndTime' to 'windowStartTime' + // and get all parent RDDs from the parent RDS + var t = windowEndTime + while (t > windowStartTime) { + parent.getOrCompute(t) match { + case Some(rdd) => parentRDDs += rdd + case None => throw new Exception("Could not generate parent RDD for time " + t) + } + t -= parent.slideTime + } + } + + // Do a union of all parent RDDs to generate the new RDD + if (parentRDDs.size > 0) { + Some(new UnionRDD(ssc.sc, parentRDDs)) + } else { + None + } + } +} + + + diff --git a/streaming/src/main/scala/spark/stream/WordCount.scala b/streaming/src/main/scala/spark/stream/WordCount.scala new file mode 100644 index 0000000000..af825e46a8 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/WordCount.scala @@ -0,0 +1,62 @@ +package spark.stream + +import SparkStreamContext._ + +import scala.util.Sorting + +import spark.SparkContext +import spark.storage.StorageLevel + +object WordCount { + var inputFile : String = null + var HDFS : String = null + var idealPartitions : Int = 0 + + def main (args: Array[String]) { + + if (args.length != 4) { + println ("Usage: WordCount ") + System.exit(1) + } + + HDFS = args(1) + inputFile = HDFS + args(2) + idealPartitions = args(3).toInt + println ("Input file: " + inputFile) + + val ssc = new SparkStreamContext(args(0), "WordCountWindow") + + SparkContext.idealPartitions = idealPartitions + SparkContext.inputFile = inputFile + + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) + //sentences.print + + val words = sentences.flatMap(_.split(" ")) + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + + //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(2000), + // System.getProperty("spark.default.parallelism", "1").toInt) + //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.DISK_AND_MEMORY_DESER_2, Seconds(5)) + //windowedCounts.print + + val parallelism = System.getProperty("spark.default.parallelism", "1").toInt + + //val localCounts = words.map(x => (x, 1)).reduceByKey(add _, parallelism) + //localCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(6)) + //val windowedCounts = localCounts.window(Seconds(30), Seconds(2)).reduceByKey(_ + _) + + val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(2), + parallelism) + windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(10)) + + //windowedCounts.print + windowedCounts.register + //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => print(x+ " "))) + //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => x)) + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/stream/WordCount1.scala b/streaming/src/main/scala/spark/stream/WordCount1.scala new file mode 100644 index 0000000000..501062b18d --- /dev/null +++ b/streaming/src/main/scala/spark/stream/WordCount1.scala @@ -0,0 +1,46 @@ +package spark.stream + +import SparkStreamContext._ + +import scala.util.Sorting + +import spark.SparkContext +import spark.storage.StorageLevel + +object WordCount1 { + var inputFile : String = null + var HDFS : String = null + var idealPartitions : Int = 0 + + def main (args: Array[String]) { + + if (args.length != 4) { + println ("Usage: WordCount ") + System.exit(1) + } + + HDFS = args(1) + inputFile = HDFS + args(2) + idealPartitions = args(3).toInt + println ("Input file: " + inputFile) + + val ssc = new SparkStreamContext(args(0), "WordCountWindow") + + SparkContext.idealPartitions = idealPartitions + SparkContext.inputFile = inputFile + + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) + //sentences.print + + val words = sentences.flatMap(_.split(" ")) + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + + val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 10) + windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) + windowedCounts.foreachRDD(_.collect) + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/stream/WordCount2.scala b/streaming/src/main/scala/spark/stream/WordCount2.scala new file mode 100644 index 0000000000..24324e891a --- /dev/null +++ b/streaming/src/main/scala/spark/stream/WordCount2.scala @@ -0,0 +1,55 @@ +package spark.stream + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting + +object WordCount2 { + + def moreWarmup(sc: SparkContext) { + (0 until 20).foreach {i => + sc.parallelize(1 to 20000000, 500) + .map(_ % 100).map(_.toString) + .map(x => (x, 1)).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SparkStreamContext <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + GrepCount2.warmConnectionManagers(ssc.sc) + /*moreWarmup(ssc.sc)*/ + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray + ) + + val words = sentences.flatMap(_.split(" ")) + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + + val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 6) + windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) + windowedCounts.foreachRDD(_.collect) + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/stream/WordCount2_Special.scala b/streaming/src/main/scala/spark/stream/WordCount2_Special.scala new file mode 100644 index 0000000000..c6b1aaa57e --- /dev/null +++ b/streaming/src/main/scala/spark/stream/WordCount2_Special.scala @@ -0,0 +1,94 @@ +package spark.stream + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.Queue +import scala.collection.JavaConversions.mapAsScalaMap + +import java.lang.{Long => JLong} +import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} + + +object WordCount2_ExtraFunctions { + + def add(v1: JLong, v2: JLong) = (v1 + v2) + + def subtract(v1: JLong, v2: JLong) = (v1 - v2) + + def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { + val map = new java.util.HashMap[String, JLong] + var i = 0 + var j = 0 + while (iter.hasNext) { + val s = iter.next() + i = 0 + while (i < s.length) { + j = i + while (j < s.length && s.charAt(j) != ' ') { + j += 1 + } + if (j > i) { + val w = s.substring(i, j) + val c = map.get(w) + if (c == null) { + map.put(w, 1) + } else { + map.put(w, c + 1) + } + } + i = j + while (i < s.length && s.charAt(i) == ' ') { + i += 1 + } + } + } + map.toIterator + } +} + +object WordCount2_Special { + + def moreWarmup(sc: SparkContext) { + (0 until 40).foreach {i => + sc.parallelize(1 to 20000000, 1000) + .map(_ % 1331).map(_.toString) + .mapPartitions(WordCount2_ExtraFunctions.splitAndCountPartitions).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SparkStreamContext <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + + GrepCount2.warmConnectionManagers(ssc.sc) + /*moreWarmup(ssc.sc)*/ + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 500)).toArray + ) + + val windowedCounts = sentences + .mapPartitions(WordCount2_ExtraFunctions.splitAndCountPartitions) + .reduceByKeyAndWindow(WordCount2_ExtraFunctions.add _, WordCount2_ExtraFunctions.subtract _, Seconds(10), Milliseconds(500), 10) + windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY, Milliseconds(500)) + windowedCounts.foreachRDD(_.collect) + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/stream/WordCount3.scala b/streaming/src/main/scala/spark/stream/WordCount3.scala new file mode 100644 index 0000000000..455a8c9dbf --- /dev/null +++ b/streaming/src/main/scala/spark/stream/WordCount3.scala @@ -0,0 +1,49 @@ +package spark.stream + +import SparkStreamContext._ + +import scala.util.Sorting + +object WordCount3 { + + def main (args: Array[String]) { + + if (args.length < 1) { + println ("Usage: SparkStreamContext []") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCount") + if (args.length > 1) { + ssc.setTempDir(args(1)) + } + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) + /*sentences.print*/ + + val words = sentences.flatMap(_.split(" ")) + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + + /*val windowedCounts = words.map(x => (x, 1)).window(Seconds(5), Seconds(1)).reduceByKey(add _, 1)*/ + val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(5), Seconds(1), 1) + /*windowedCounts.print */ + + def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { + implicit val countOrdering = new Ordering[(String, Int)] { + override def compare(count1: (String, Int), count2: (String, Int)): Int = { + count2._2 - count1._2 + } + } + val array = data.toArray + Sorting.quickSort(array) + array.take(k) + } + + val k = 10 + val topKWindowedCounts = windowedCounts.glom.flatMap(topK(_, k)).collect.flatMap(topK(_, k)) + topKWindowedCounts.print + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/stream/WordCountEc2.scala b/streaming/src/main/scala/spark/stream/WordCountEc2.scala new file mode 100644 index 0000000000..5b10026d7a --- /dev/null +++ b/streaming/src/main/scala/spark/stream/WordCountEc2.scala @@ -0,0 +1,41 @@ +package spark.stream + +import SparkStreamContext._ +import spark.SparkContext + +object WordCountEc2 { + var inputFile : String = null + var HDFS : String = null + var idealPartitions : Int = 0 + + def main (args: Array[String]) { + + if (args.length != 4) { + println ("Usage: SparkStreamContext ") + System.exit(1) + } + + HDFS = args(1) + inputFile = HDFS + args(2) + idealPartitions = args(3).toInt + println ("Input file: " + inputFile) + + SparkContext.idealPartitions = idealPartitions + SparkContext.inputFile = inputFile + + val ssc = new SparkStreamContext(args(0), "Test") + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) + /*sentences.foreach(println)*/ + + val words = sentences.flatMap(_.split(" ")) + /*words.foreach(println)*/ + + val counts = words.map(x => (x, 1)).reduceByKey(_ + _) + /*counts.foreach(println)*/ + + counts.foreachRDD(rdd => rdd.collect.foreach(x => x)) + /*counts.register*/ + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/stream/WordCountTrivialWindow.scala b/streaming/src/main/scala/spark/stream/WordCountTrivialWindow.scala new file mode 100644 index 0000000000..5469df71e9 --- /dev/null +++ b/streaming/src/main/scala/spark/stream/WordCountTrivialWindow.scala @@ -0,0 +1,51 @@ +package spark.stream + +import SparkStreamContext._ + +import scala.util.Sorting + +object WordCountTrivialWindow { + + def main (args: Array[String]) { + + if (args.length < 1) { + println ("Usage: SparkStreamContext []") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCountTrivialWindow") + if (args.length > 1) { + ssc.setTempDir(args(1)) + } + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) + /*sentences.print*/ + + val words = sentences.flatMap(_.split(" ")) + + /*val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 1)*/ + /*counts.print*/ + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + + val windowedCounts = words.map(x => (x, 1)).window(Seconds(5), Seconds(1)).reduceByKey(add _, 1) + /*windowedCounts.print */ + + def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { + implicit val countOrdering = new Ordering[(String, Int)] { + override def compare(count1: (String, Int), count2: (String, Int)): Int = { + count2._2 - count1._2 + } + } + val array = data.toArray + Sorting.quickSort(array) + array.take(k) + } + + val k = 10 + val topKWindowedCounts = windowedCounts.glom.flatMap(topK(_, k)).collect.flatMap(topK(_, k)) + topKWindowedCounts.print + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/stream/WordMax.scala b/streaming/src/main/scala/spark/stream/WordMax.scala new file mode 100644 index 0000000000..fc075e6d9d --- /dev/null +++ b/streaming/src/main/scala/spark/stream/WordMax.scala @@ -0,0 +1,64 @@ +package spark.stream + +import SparkStreamContext._ + +import scala.util.Sorting + +import spark.SparkContext +import spark.storage.StorageLevel + +object WordMax { + var inputFile : String = null + var HDFS : String = null + var idealPartitions : Int = 0 + + def main (args: Array[String]) { + + if (args.length != 4) { + println ("Usage: WordCount ") + System.exit(1) + } + + HDFS = args(1) + inputFile = HDFS + args(2) + idealPartitions = args(3).toInt + println ("Input file: " + inputFile) + + val ssc = new SparkStreamContext(args(0), "WordCountWindow") + + SparkContext.idealPartitions = idealPartitions + SparkContext.inputFile = inputFile + + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) + //sentences.print + + val words = sentences.flatMap(_.split(" ")) + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + def max(v1: Int, v2: Int) = (if (v1 > v2) v1 else v2) + + //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(2000), + // System.getProperty("spark.default.parallelism", "1").toInt) + //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.DISK_AND_MEMORY_DESER_2, Seconds(5)) + //windowedCounts.print + + val parallelism = System.getProperty("spark.default.parallelism", "1").toInt + + val localCounts = words.map(x => (x, 1)).reduceByKey(add _, parallelism) + //localCounts.persist(StorageLevel.MEMORY_ONLY_DESER) + localCounts.persist(StorageLevel.MEMORY_ONLY_DESER_2) + val windowedCounts = localCounts.window(Seconds(30), Seconds(2)).reduceByKey(max _) + + //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(2), + // parallelism) + //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(6)) + + //windowedCounts.print + windowedCounts.register + //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => print(x+ " "))) + //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => x)) + + ssc.run + } +} -- cgit v1.2.3 From fcee4153b92bdd66dd92820a2670b339f9f59c77 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 29 Jul 2012 13:35:22 -0700 Subject: Renamed stream package to streaming --- startTrigger.sh | 2 +- .../src/main/scala/spark/stream/BlockID.scala | 20 - .../scala/spark/stream/ConnectionHandler.scala | 157 ------ .../spark/stream/DumbTopKWordCount2_Special.scala | 138 ----- .../spark/stream/DumbWordCount2_Special.scala | 92 ---- .../scala/spark/stream/FileStreamReceiver.scala | 70 --- .../src/main/scala/spark/stream/GrepCount.scala | 39 -- .../src/main/scala/spark/stream/GrepCount2.scala | 113 ---- .../main/scala/spark/stream/GrepCountApprox.scala | 54 -- .../main/scala/spark/stream/IdealPerformance.scala | 36 -- .../src/main/scala/spark/stream/Interval.scala | 75 --- streaming/src/main/scala/spark/stream/Job.scala | 21 - .../src/main/scala/spark/stream/JobManager.scala | 112 ---- .../src/main/scala/spark/stream/JobManager2.scala | 37 -- .../scala/spark/stream/NetworkStreamReceiver.scala | 184 ------- streaming/src/main/scala/spark/stream/RDS.scala | 607 --------------------- .../scala/spark/stream/ReducedWindowedRDS.scala | 218 -------- .../src/main/scala/spark/stream/Scheduler.scala | 181 ------ .../stream/SenGeneratorForPerformanceTest.scala | 78 --- .../scala/spark/stream/SenderReceiverTest.scala | 63 --- .../scala/spark/stream/SentenceFileGenerator.scala | 92 ---- .../scala/spark/stream/SentenceGenerator.scala | 103 ---- .../src/main/scala/spark/stream/ShuffleTest.scala | 22 - .../main/scala/spark/stream/SimpleWordCount.scala | 30 - .../main/scala/spark/stream/SimpleWordCount2.scala | 51 -- .../spark/stream/SimpleWordCount2_Special.scala | 83 --- .../scala/spark/stream/SparkStreamContext.scala | 105 ---- .../main/scala/spark/stream/TestGenerator.scala | 107 ---- .../main/scala/spark/stream/TestGenerator2.scala | 119 ---- .../main/scala/spark/stream/TestGenerator4.scala | 244 --------- .../scala/spark/stream/TestInputBlockTracker.scala | 42 -- .../scala/spark/stream/TestStreamCoordinator.scala | 38 -- .../scala/spark/stream/TestStreamReceiver3.scala | 420 -------------- .../scala/spark/stream/TestStreamReceiver4.scala | 373 ------------- streaming/src/main/scala/spark/stream/Time.scala | 85 --- .../main/scala/spark/stream/TopContentCount.scala | 97 ---- .../main/scala/spark/stream/TopKWordCount2.scala | 103 ---- .../spark/stream/TopKWordCount2_Special.scala | 142 ----- .../src/main/scala/spark/stream/WindowedRDS.scala | 68 --- .../src/main/scala/spark/stream/WordCount.scala | 62 --- .../src/main/scala/spark/stream/WordCount1.scala | 46 -- .../src/main/scala/spark/stream/WordCount2.scala | 55 -- .../scala/spark/stream/WordCount2_Special.scala | 94 ---- .../src/main/scala/spark/stream/WordCount3.scala | 49 -- .../src/main/scala/spark/stream/WordCountEc2.scala | 41 -- .../spark/stream/WordCountTrivialWindow.scala | 51 -- .../src/main/scala/spark/stream/WordMax.scala | 64 --- .../src/main/scala/spark/streaming/BlockID.scala | 20 + .../scala/spark/streaming/ConnectionHandler.scala | 157 ++++++ .../streaming/DumbTopKWordCount2_Special.scala | 138 +++++ .../spark/streaming/DumbWordCount2_Special.scala | 92 ++++ .../scala/spark/streaming/FileStreamReceiver.scala | 70 +++ .../src/main/scala/spark/streaming/GrepCount.scala | 39 ++ .../main/scala/spark/streaming/GrepCount2.scala | 113 ++++ .../scala/spark/streaming/GrepCountApprox.scala | 54 ++ .../scala/spark/streaming/IdealPerformance.scala | 36 ++ .../src/main/scala/spark/streaming/Interval.scala | 75 +++ streaming/src/main/scala/spark/streaming/Job.scala | 21 + .../main/scala/spark/streaming/JobManager.scala | 112 ++++ .../main/scala/spark/streaming/JobManager2.scala | 37 ++ .../spark/streaming/NetworkStreamReceiver.scala | 184 +++++++ streaming/src/main/scala/spark/streaming/RDS.scala | 607 +++++++++++++++++++++ .../scala/spark/streaming/ReducedWindowedRDS.scala | 218 ++++++++ .../src/main/scala/spark/streaming/Scheduler.scala | 181 ++++++ .../streaming/SenGeneratorForPerformanceTest.scala | 78 +++ .../scala/spark/streaming/SenderReceiverTest.scala | 63 +++ .../spark/streaming/SentenceFileGenerator.scala | 92 ++++ .../scala/spark/streaming/SentenceGenerator.scala | 103 ++++ .../main/scala/spark/streaming/ShuffleTest.scala | 22 + .../scala/spark/streaming/SimpleWordCount.scala | 30 + .../scala/spark/streaming/SimpleWordCount2.scala | 51 ++ .../spark/streaming/SimpleWordCount2_Special.scala | 83 +++ .../scala/spark/streaming/SparkStreamContext.scala | 105 ++++ .../main/scala/spark/streaming/TestGenerator.scala | 107 ++++ .../scala/spark/streaming/TestGenerator2.scala | 119 ++++ .../scala/spark/streaming/TestGenerator4.scala | 244 +++++++++ .../spark/streaming/TestInputBlockTracker.scala | 42 ++ .../spark/streaming/TestStreamCoordinator.scala | 38 ++ .../spark/streaming/TestStreamReceiver3.scala | 420 ++++++++++++++ .../spark/streaming/TestStreamReceiver4.scala | 373 +++++++++++++ .../src/main/scala/spark/streaming/Time.scala | 85 +++ .../scala/spark/streaming/TopContentCount.scala | 97 ++++ .../scala/spark/streaming/TopKWordCount2.scala | 103 ++++ .../spark/streaming/TopKWordCount2_Special.scala | 142 +++++ .../main/scala/spark/streaming/WindowedRDS.scala | 68 +++ .../src/main/scala/spark/streaming/WordCount.scala | 62 +++ .../main/scala/spark/streaming/WordCount1.scala | 46 ++ .../main/scala/spark/streaming/WordCount2.scala | 55 ++ .../scala/spark/streaming/WordCount2_Special.scala | 94 ++++ .../main/scala/spark/streaming/WordCount3.scala | 49 ++ .../main/scala/spark/streaming/WordCountEc2.scala | 41 ++ .../spark/streaming/WordCountTrivialWindow.scala | 51 ++ .../src/main/scala/spark/streaming/WordMax.scala | 64 +++ 93 files changed, 5082 insertions(+), 5082 deletions(-) delete mode 100644 streaming/src/main/scala/spark/stream/BlockID.scala delete mode 100644 streaming/src/main/scala/spark/stream/ConnectionHandler.scala delete mode 100644 streaming/src/main/scala/spark/stream/DumbTopKWordCount2_Special.scala delete mode 100644 streaming/src/main/scala/spark/stream/DumbWordCount2_Special.scala delete mode 100644 streaming/src/main/scala/spark/stream/FileStreamReceiver.scala delete mode 100644 streaming/src/main/scala/spark/stream/GrepCount.scala delete mode 100644 streaming/src/main/scala/spark/stream/GrepCount2.scala delete mode 100644 streaming/src/main/scala/spark/stream/GrepCountApprox.scala delete mode 100644 streaming/src/main/scala/spark/stream/IdealPerformance.scala delete mode 100644 streaming/src/main/scala/spark/stream/Interval.scala delete mode 100644 streaming/src/main/scala/spark/stream/Job.scala delete mode 100644 streaming/src/main/scala/spark/stream/JobManager.scala delete mode 100644 streaming/src/main/scala/spark/stream/JobManager2.scala delete mode 100644 streaming/src/main/scala/spark/stream/NetworkStreamReceiver.scala delete mode 100644 streaming/src/main/scala/spark/stream/RDS.scala delete mode 100644 streaming/src/main/scala/spark/stream/ReducedWindowedRDS.scala delete mode 100644 streaming/src/main/scala/spark/stream/Scheduler.scala delete mode 100644 streaming/src/main/scala/spark/stream/SenGeneratorForPerformanceTest.scala delete mode 100644 streaming/src/main/scala/spark/stream/SenderReceiverTest.scala delete mode 100644 streaming/src/main/scala/spark/stream/SentenceFileGenerator.scala delete mode 100644 streaming/src/main/scala/spark/stream/SentenceGenerator.scala delete mode 100644 streaming/src/main/scala/spark/stream/ShuffleTest.scala delete mode 100644 streaming/src/main/scala/spark/stream/SimpleWordCount.scala delete mode 100644 streaming/src/main/scala/spark/stream/SimpleWordCount2.scala delete mode 100644 streaming/src/main/scala/spark/stream/SimpleWordCount2_Special.scala delete mode 100644 streaming/src/main/scala/spark/stream/SparkStreamContext.scala delete mode 100644 streaming/src/main/scala/spark/stream/TestGenerator.scala delete mode 100644 streaming/src/main/scala/spark/stream/TestGenerator2.scala delete mode 100644 streaming/src/main/scala/spark/stream/TestGenerator4.scala delete mode 100644 streaming/src/main/scala/spark/stream/TestInputBlockTracker.scala delete mode 100644 streaming/src/main/scala/spark/stream/TestStreamCoordinator.scala delete mode 100644 streaming/src/main/scala/spark/stream/TestStreamReceiver3.scala delete mode 100644 streaming/src/main/scala/spark/stream/TestStreamReceiver4.scala delete mode 100644 streaming/src/main/scala/spark/stream/Time.scala delete mode 100644 streaming/src/main/scala/spark/stream/TopContentCount.scala delete mode 100644 streaming/src/main/scala/spark/stream/TopKWordCount2.scala delete mode 100644 streaming/src/main/scala/spark/stream/TopKWordCount2_Special.scala delete mode 100644 streaming/src/main/scala/spark/stream/WindowedRDS.scala delete mode 100644 streaming/src/main/scala/spark/stream/WordCount.scala delete mode 100644 streaming/src/main/scala/spark/stream/WordCount1.scala delete mode 100644 streaming/src/main/scala/spark/stream/WordCount2.scala delete mode 100644 streaming/src/main/scala/spark/stream/WordCount2_Special.scala delete mode 100644 streaming/src/main/scala/spark/stream/WordCount3.scala delete mode 100644 streaming/src/main/scala/spark/stream/WordCountEc2.scala delete mode 100644 streaming/src/main/scala/spark/stream/WordCountTrivialWindow.scala delete mode 100644 streaming/src/main/scala/spark/stream/WordMax.scala create mode 100644 streaming/src/main/scala/spark/streaming/BlockID.scala create mode 100644 streaming/src/main/scala/spark/streaming/ConnectionHandler.scala create mode 100644 streaming/src/main/scala/spark/streaming/DumbTopKWordCount2_Special.scala create mode 100644 streaming/src/main/scala/spark/streaming/DumbWordCount2_Special.scala create mode 100644 streaming/src/main/scala/spark/streaming/FileStreamReceiver.scala create mode 100644 streaming/src/main/scala/spark/streaming/GrepCount.scala create mode 100644 streaming/src/main/scala/spark/streaming/GrepCount2.scala create mode 100644 streaming/src/main/scala/spark/streaming/GrepCountApprox.scala create mode 100644 streaming/src/main/scala/spark/streaming/IdealPerformance.scala create mode 100644 streaming/src/main/scala/spark/streaming/Interval.scala create mode 100644 streaming/src/main/scala/spark/streaming/Job.scala create mode 100644 streaming/src/main/scala/spark/streaming/JobManager.scala create mode 100644 streaming/src/main/scala/spark/streaming/JobManager2.scala create mode 100644 streaming/src/main/scala/spark/streaming/NetworkStreamReceiver.scala create mode 100644 streaming/src/main/scala/spark/streaming/RDS.scala create mode 100644 streaming/src/main/scala/spark/streaming/ReducedWindowedRDS.scala create mode 100644 streaming/src/main/scala/spark/streaming/Scheduler.scala create mode 100644 streaming/src/main/scala/spark/streaming/SenGeneratorForPerformanceTest.scala create mode 100644 streaming/src/main/scala/spark/streaming/SenderReceiverTest.scala create mode 100644 streaming/src/main/scala/spark/streaming/SentenceFileGenerator.scala create mode 100644 streaming/src/main/scala/spark/streaming/SentenceGenerator.scala create mode 100644 streaming/src/main/scala/spark/streaming/ShuffleTest.scala create mode 100644 streaming/src/main/scala/spark/streaming/SimpleWordCount.scala create mode 100644 streaming/src/main/scala/spark/streaming/SimpleWordCount2.scala create mode 100644 streaming/src/main/scala/spark/streaming/SimpleWordCount2_Special.scala create mode 100644 streaming/src/main/scala/spark/streaming/SparkStreamContext.scala create mode 100644 streaming/src/main/scala/spark/streaming/TestGenerator.scala create mode 100644 streaming/src/main/scala/spark/streaming/TestGenerator2.scala create mode 100644 streaming/src/main/scala/spark/streaming/TestGenerator4.scala create mode 100644 streaming/src/main/scala/spark/streaming/TestInputBlockTracker.scala create mode 100644 streaming/src/main/scala/spark/streaming/TestStreamCoordinator.scala create mode 100644 streaming/src/main/scala/spark/streaming/TestStreamReceiver3.scala create mode 100644 streaming/src/main/scala/spark/streaming/TestStreamReceiver4.scala create mode 100644 streaming/src/main/scala/spark/streaming/Time.scala create mode 100644 streaming/src/main/scala/spark/streaming/TopContentCount.scala create mode 100644 streaming/src/main/scala/spark/streaming/TopKWordCount2.scala create mode 100644 streaming/src/main/scala/spark/streaming/TopKWordCount2_Special.scala create mode 100644 streaming/src/main/scala/spark/streaming/WindowedRDS.scala create mode 100644 streaming/src/main/scala/spark/streaming/WordCount.scala create mode 100644 streaming/src/main/scala/spark/streaming/WordCount1.scala create mode 100644 streaming/src/main/scala/spark/streaming/WordCount2.scala create mode 100644 streaming/src/main/scala/spark/streaming/WordCount2_Special.scala create mode 100644 streaming/src/main/scala/spark/streaming/WordCount3.scala create mode 100644 streaming/src/main/scala/spark/streaming/WordCountEc2.scala create mode 100644 streaming/src/main/scala/spark/streaming/WordCountTrivialWindow.scala create mode 100644 streaming/src/main/scala/spark/streaming/WordMax.scala diff --git a/startTrigger.sh b/startTrigger.sh index 0afce91a3e..373dbda93e 100755 --- a/startTrigger.sh +++ b/startTrigger.sh @@ -1,3 +1,3 @@ #!/bin/bash -./run spark.stream.SentenceGenerator localhost 7078 sentences.txt 1 +./run spark.streaming.SentenceGenerator localhost 7078 sentences.txt 1 diff --git a/streaming/src/main/scala/spark/stream/BlockID.scala b/streaming/src/main/scala/spark/stream/BlockID.scala deleted file mode 100644 index a3fd046c9a..0000000000 --- a/streaming/src/main/scala/spark/stream/BlockID.scala +++ /dev/null @@ -1,20 +0,0 @@ -package spark.stream - -case class BlockID(sRds: String, sInterval: Interval, sPartition: Int) { - override def toString : String = ( - sRds + BlockID.sConnector + - sInterval.beginTime + BlockID.sConnector + - sInterval.endTime + BlockID.sConnector + - sPartition - ) -} - -object BlockID { - val sConnector = '-' - - def parse(name : String) = BlockID( - name.split(BlockID.sConnector)(0), - new Interval(name.split(BlockID.sConnector)(1).toLong, - name.split(BlockID.sConnector)(2).toLong), - name.split(BlockID.sConnector)(3).toInt) -} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/stream/ConnectionHandler.scala b/streaming/src/main/scala/spark/stream/ConnectionHandler.scala deleted file mode 100644 index 73b82b76b8..0000000000 --- a/streaming/src/main/scala/spark/stream/ConnectionHandler.scala +++ /dev/null @@ -1,157 +0,0 @@ -package spark.stream - -import spark.Logging - -import scala.collection.mutable.{ArrayBuffer, SynchronizedQueue} - -import java.net._ -import java.io._ -import java.nio._ -import java.nio.charset._ -import java.nio.channels._ -import java.nio.channels.spi._ - -abstract class ConnectionHandler(host: String, port: Int, connect: Boolean) -extends Thread with Logging { - - val selector = SelectorProvider.provider.openSelector() - val interestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] - - initLogging() - - override def run() { - try { - if (connect) { - connect() - } else { - listen() - } - - var interrupted = false - while(!interrupted) { - - preSelect() - - while(!interestChangeRequests.isEmpty) { - val (key, ops) = interestChangeRequests.dequeue - val lastOps = key.interestOps() - key.interestOps(ops) - - def intToOpStr(op: Int): String = { - val opStrs = new ArrayBuffer[String]() - if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" - if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" - if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" - if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" - if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " - } - - logTrace("Changed ops from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") - } - - selector.select() - interrupted = Thread.currentThread.isInterrupted - - val selectedKeys = selector.selectedKeys().iterator() - while (selectedKeys.hasNext) { - val key = selectedKeys.next.asInstanceOf[SelectionKey] - selectedKeys.remove() - if (key.isValid) { - if (key.isAcceptable) { - accept(key) - } else if (key.isConnectable) { - finishConnect(key) - } else if (key.isReadable) { - read(key) - } else if (key.isWritable) { - write(key) - } - } - } - } - } catch { - case e: Exception => { - logError("Error in select loop", e) - } - } - } - - def connect() { - val socketAddress = new InetSocketAddress(host, port) - val channel = SocketChannel.open() - channel.configureBlocking(false) - channel.socket.setReuseAddress(true) - channel.socket.setTcpNoDelay(true) - channel.connect(socketAddress) - channel.register(selector, SelectionKey.OP_CONNECT) - logInfo("Initiating connection to [" + socketAddress + "]") - } - - def listen() { - val channel = ServerSocketChannel.open() - channel.configureBlocking(false) - channel.socket.setReuseAddress(true) - channel.socket.setReceiveBufferSize(256 * 1024) - channel.socket.bind(new InetSocketAddress(port)) - channel.register(selector, SelectionKey.OP_ACCEPT) - logInfo("Listening on port " + port) - } - - def finishConnect(key: SelectionKey) { - try { - val channel = key.channel.asInstanceOf[SocketChannel] - val address = channel.socket.getRemoteSocketAddress - channel.finishConnect() - logInfo("Connected to [" + host + ":" + port + "]") - ready(key) - } catch { - case e: IOException => { - logError("Error finishing connect to " + host + ":" + port) - close(key) - } - } - } - - def accept(key: SelectionKey) { - try { - val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] - val channel = serverChannel.accept() - val address = channel.socket.getRemoteSocketAddress - channel.configureBlocking(false) - logInfo("Accepted connection from [" + address + "]") - ready(channel.register(selector, 0)) - } catch { - case e: IOException => { - logError("Error accepting connection", e) - } - } - } - - def changeInterest(key: SelectionKey, ops: Int) { - logTrace("Added request to change ops to " + ops) - interestChangeRequests += ((key, ops)) - } - - def ready(key: SelectionKey) - - def preSelect() { - } - - def read(key: SelectionKey) { - throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString) - } - - def write(key: SelectionKey) { - throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString) - } - - def close(key: SelectionKey) { - try { - key.channel.close() - key.cancel() - Thread.currentThread.interrupt - } catch { - case e: Exception => logError("Error closing connection", e) - } - } -} diff --git a/streaming/src/main/scala/spark/stream/DumbTopKWordCount2_Special.scala b/streaming/src/main/scala/spark/stream/DumbTopKWordCount2_Special.scala deleted file mode 100644 index bd43f44b1a..0000000000 --- a/streaming/src/main/scala/spark/stream/DumbTopKWordCount2_Special.scala +++ /dev/null @@ -1,138 +0,0 @@ -package spark.stream - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.JavaConversions.mapAsScalaMap -import scala.collection.mutable.Queue - -import java.lang.{Long => JLong} - -object DumbTopKWordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - - def add(v1: JLong, v2: JLong) = (v1 + v2) - def subtract(v1: JLong, v2: JLong) = (v1 - v2) - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - map.toIterator - } - - - val wordCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) - wordCounts.persist(StorageLevel.MEMORY_ONLY) - val windowedCounts = wordCounts.window(Seconds(10), Seconds(1)).reduceByKey(_ + _, 10) - - def topK(data: Iterator[(String, JLong)], k: Int): Iterator[(String, JLong)] = { - val taken = new Array[(String, JLong)](k) - - var i = 0 - var len = 0 - var done = false - var value: (String, JLong) = null - var swap: (String, JLong) = null - var count = 0 - - while(data.hasNext) { - value = data.next - count += 1 - /*println("count = " + count)*/ - if (len == 0) { - taken(0) = value - len = 1 - } else if (len < k || value._2 > taken(len - 1)._2) { - if (len < k) { - len += 1 - } - taken(len - 1) = value - i = len - 1 - while(i > 0 && taken(i - 1)._2 < taken(i)._2) { - swap = taken(i) - taken(i) = taken(i-1) - taken(i - 1) = swap - i -= 1 - } - } - } - println("Took " + len + " out of " + count + " items") - return taken.toIterator - } - - val k = 10 - val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) - partialTopKWindowedCounts.foreachRDD(rdd => { - val collectedCounts = rdd.collect - println("Collected " + collectedCounts.size + " items") - topK(collectedCounts.toIterator, k).foreach(println) - }) - - /* - windowedCounts.filter(_ == null).foreachRDD(rdd => { - val count = rdd.count - println("# of nulls = " + count) - })*/ - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/stream/DumbWordCount2_Special.scala b/streaming/src/main/scala/spark/stream/DumbWordCount2_Special.scala deleted file mode 100644 index 31d682348a..0000000000 --- a/streaming/src/main/scala/spark/stream/DumbWordCount2_Special.scala +++ /dev/null @@ -1,92 +0,0 @@ -package spark.stream - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.Queue -import scala.collection.JavaConversions.mapAsScalaMap - -import java.lang.{Long => JLong} -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} - -object DumbWordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - def add(v1: JLong, v2: JLong) = (v1 + v2) - def subtract(v1: JLong, v2: JLong) = (v1 - v2) - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - - map.toIterator - } - - val wordCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) - wordCounts.persist(StorageLevel.MEMORY_ONLY) - val windowedCounts = wordCounts.window(Seconds(10), Seconds(1)).reduceByKey(_ + _, 10) - windowedCounts.foreachRDD(_.collect) - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/stream/FileStreamReceiver.scala b/streaming/src/main/scala/spark/stream/FileStreamReceiver.scala deleted file mode 100644 index 026254d6e1..0000000000 --- a/streaming/src/main/scala/spark/stream/FileStreamReceiver.scala +++ /dev/null @@ -1,70 +0,0 @@ -package spark.stream - -import spark.Logging - -import scala.collection.mutable.HashSet -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import org.apache.hadoop.fs._ -import org.apache.hadoop.conf._ -import org.apache.hadoop.io._ -import org.apache.hadoop.mapred._ -import org.apache.hadoop.util._ - -class FileStreamReceiver ( - inputName: String, - rootDirectory: String, - intervalDuration: Long) - extends Logging { - - val pollInterval = 100 - val sparkstreamScheduler = { - val host = System.getProperty("spark.master.host") - val port = System.getProperty("spark.master.port").toInt + 1 - RemoteActor.select(Node(host, port), 'SparkStreamScheduler) - } - val directory = new Path(rootDirectory) - val fs = directory.getFileSystem(new Configuration()) - val files = new HashSet[String]() - var time: Long = 0 - - def start() { - fs.mkdirs(directory) - files ++= getFiles() - - actor { - logInfo("Monitoring directory - " + rootDirectory) - while(true) { - testFiles(getFiles()) - Thread.sleep(pollInterval) - } - } - } - - def getFiles(): Iterable[String] = { - fs.listStatus(directory).map(_.getPath.toString) - } - - def testFiles(fileList: Iterable[String]) { - fileList.foreach(file => { - if (!files.contains(file)) { - if (!file.endsWith("_tmp")) { - notifyFile(file) - } - files += file - } - }) - } - - def notifyFile(file: String) { - logInfo("Notifying file " + file) - time += intervalDuration - val interval = Interval(LongTime(time), LongTime(time + intervalDuration)) - sparkstreamScheduler ! InputGenerated(inputName, interval, file) - } -} - - diff --git a/streaming/src/main/scala/spark/stream/GrepCount.scala b/streaming/src/main/scala/spark/stream/GrepCount.scala deleted file mode 100644 index 45b90d4837..0000000000 --- a/streaming/src/main/scala/spark/stream/GrepCount.scala +++ /dev/null @@ -1,39 +0,0 @@ -package spark.stream - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object GrepCount { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: GrepCount ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "GrepCount") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - //sentences.print - val matching = sentences.filter(_.contains("light")) - matching.foreachRDD(rdd => println(rdd.count)) - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/stream/GrepCount2.scala b/streaming/src/main/scala/spark/stream/GrepCount2.scala deleted file mode 100644 index 4eb65ba906..0000000000 --- a/streaming/src/main/scala/spark/stream/GrepCount2.scala +++ /dev/null @@ -1,113 +0,0 @@ -package spark.stream - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkEnv -import spark.SparkContext -import spark.storage.StorageLevel -import spark.network.Message -import spark.network.ConnectionManagerId - -import java.nio.ByteBuffer - -object GrepCount2 { - - def startSparkEnvs(sc: SparkContext) { - - val dummy = sc.parallelize(0 to 1000, 100).persist(StorageLevel.DISK_AND_MEMORY) - sc.runJob(dummy, (_: Iterator[Int]) => {}) - - println("SparkEnvs started") - Thread.sleep(1000) - /*sc.runJob(sc.parallelize(0 to 1000, 100), (_: Iterator[Int]) => {})*/ - } - - def warmConnectionManagers(sc: SparkContext) { - val slaveConnManagerIds = sc.parallelize(0 to 100, 100).map( - i => SparkEnv.get.connectionManager.id).collect().distinct - println("\nSlave ConnectionManagerIds") - slaveConnManagerIds.foreach(println) - println - - Thread.sleep(1000) - val numSlaves = slaveConnManagerIds.size - val count = 3 - val size = 5 * 1024 * 1024 - val iterations = (500 * 1024 * 1024 / (numSlaves * size)).toInt - println("count = " + count + ", size = " + size + ", iterations = " + iterations) - - (0 until count).foreach(i => { - val resultStrs = sc.parallelize(0 until numSlaves, numSlaves).map(i => { - val connManager = SparkEnv.get.connectionManager - val thisConnManagerId = connManager.id - /*connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - println("Received [" + msg + "] from [" + id + "]") - None - })*/ - - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - val futures = (0 until iterations).map(i => { - slaveConnManagerIds.filter(_ != thisConnManagerId).map(slaveConnManagerId => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - println("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") - connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) - }) - }).flatMap(x => x) - val results = futures.map(f => f()) - val finishTime = System.currentTimeMillis - - - val mb = size * results.size / 1024.0 / 1024.0 - val ms = finishTime - startTime - - val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s" - println(resultStr) - System.gc() - resultStr - }).collect() - - println("---------------------") - println("Run " + i) - resultStrs.foreach(println) - println("---------------------") - }) - } - - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: GrepCount2 <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "GrepCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - /*startSparkEnvs(ssc.sc)*/ - warmConnectionManagers(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-"+i, 500)).toArray - ) - - val matching = sentences.filter(_.contains("light")) - matching.foreachRDD(rdd => println(rdd.count)) - - ssc.run - } -} - - - - diff --git a/streaming/src/main/scala/spark/stream/GrepCountApprox.scala b/streaming/src/main/scala/spark/stream/GrepCountApprox.scala deleted file mode 100644 index a4be2cc936..0000000000 --- a/streaming/src/main/scala/spark/stream/GrepCountApprox.scala +++ /dev/null @@ -1,54 +0,0 @@ -package spark.stream - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object GrepCountApprox { - var inputFile : String = null - var hdfs : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 5) { - println ("Usage: GrepCountApprox ") - System.exit(1) - } - - hdfs = args(1) - inputFile = hdfs + args(2) - idealPartitions = args(3).toInt - val timeout = args(4).toLong - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "GrepCount") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - ssc.setTempDir(hdfs + "/tmp") - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - //sentences.print - val matching = sentences.filter(_.contains("light")) - var i = 0 - val startTime = System.currentTimeMillis - matching.foreachRDD { rdd => - val myNum = i - val result = rdd.countApprox(timeout) - val initialTime = (System.currentTimeMillis - startTime) / 1000.0 - printf("APPROX\t%.2f\t%d\tinitial\t%.1f\t%.1f\n", initialTime, myNum, result.initialValue.mean, - result.initialValue.high - result.initialValue.low) - result.onComplete { r => - val finalTime = (System.currentTimeMillis - startTime) / 1000.0 - printf("APPROX\t%.2f\t%d\tfinal\t%.1f\t0.0\t%.1f\n", finalTime, myNum, r.mean, finalTime - initialTime) - } - i += 1 - } - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/stream/IdealPerformance.scala b/streaming/src/main/scala/spark/stream/IdealPerformance.scala deleted file mode 100644 index 589fb2def0..0000000000 --- a/streaming/src/main/scala/spark/stream/IdealPerformance.scala +++ /dev/null @@ -1,36 +0,0 @@ -package spark.stream - -import scala.collection.mutable.Map - -object IdealPerformance { - val base: String = "The medium researcher counts around the pinched troop The empire breaks " + - "Matei Matei announces HY with a theorem " - - def main (args: Array[String]) { - val sentences: String = base * 100000 - - for (i <- 1 to 30) { - val start = System.nanoTime - - val words = sentences.split(" ") - - val pairs = words.map(word => (word, 1)) - - val counts = Map[String, Int]() - - println("Job " + i + " position A at " + (System.nanoTime - start) / 1e9) - - pairs.foreach((pair) => { - var t = counts.getOrElse(pair._1, 0) - counts(pair._1) = t + pair._2 - }) - println("Job " + i + " position B at " + (System.nanoTime - start) / 1e9) - - for ((word, count) <- counts) { - print(word + " " + count + "; ") - } - println - println("Job " + i + " finished in " + (System.nanoTime - start) / 1e9) - } - } -} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/stream/Interval.scala b/streaming/src/main/scala/spark/stream/Interval.scala deleted file mode 100644 index 08d0ed95b4..0000000000 --- a/streaming/src/main/scala/spark/stream/Interval.scala +++ /dev/null @@ -1,75 +0,0 @@ -package spark.stream - -case class Interval (val beginTime: Time, val endTime: Time) { - - def this(beginMs: Long, endMs: Long) = this(new LongTime(beginMs), new LongTime(endMs)) - - def duration(): Time = endTime - beginTime - - def += (time: Time) { - beginTime += time - endTime += time - this - } - - def + (time: Time): Interval = { - new Interval(beginTime + time, endTime + time) - } - - def < (that: Interval): Boolean = { - if (this.duration != that.duration) { - throw new Exception("Comparing two intervals with different durations [" + this + ", " + that + "]") - } - this.endTime < that.endTime - } - - def <= (that: Interval) = (this < that || this == that) - - def > (that: Interval) = !(this <= that) - - def >= (that: Interval) = !(this < that) - - def next(): Interval = { - this + (endTime - beginTime) - } - - def isZero() = (beginTime.isZero && endTime.isZero) - - def toFormattedString = beginTime.toFormattedString + "-" + endTime.toFormattedString - - override def toString = "[" + beginTime + ", " + endTime + "]" -} - -object Interval { - - /* - implicit def longTupleToInterval (longTuple: (Long, Long)) = - Interval(longTuple._1, longTuple._2) - - implicit def intTupleToInterval (intTuple: (Int, Int)) = - Interval(intTuple._1, intTuple._2) - - implicit def string2Interval (str: String): Interval = { - val parts = str.split(",") - if (parts.length == 1) - return Interval.zero - return Interval (parts(0).toInt, parts(1).toInt) - } - - def getInterval (timeMs: Long, intervalDurationMs: Long): Interval = { - val intervalBeginMs = timeMs / intervalDurationMs * intervalDurationMs - Interval(intervalBeginMs, intervalBeginMs + intervalDurationMs) - } - */ - - def zero() = new Interval (Time.zero, Time.zero) - - def currentInterval(intervalDuration: LongTime): Interval = { - val time = LongTime(System.currentTimeMillis) - val intervalBegin = time.floor(intervalDuration) - Interval(intervalBegin, intervalBegin + intervalDuration) - } - -} - - diff --git a/streaming/src/main/scala/spark/stream/Job.scala b/streaming/src/main/scala/spark/stream/Job.scala deleted file mode 100644 index bfdd5db645..0000000000 --- a/streaming/src/main/scala/spark/stream/Job.scala +++ /dev/null @@ -1,21 +0,0 @@ -package spark.stream - -class Job(val time: Time, func: () => _) { - val id = Job.getNewId() - - def run() { - func() - } - - override def toString = "SparkStream Job " + id + ":" + time -} - -object Job { - var lastId = 1 - - def getNewId() = synchronized { - lastId += 1 - lastId - } -} - diff --git a/streaming/src/main/scala/spark/stream/JobManager.scala b/streaming/src/main/scala/spark/stream/JobManager.scala deleted file mode 100644 index 5ea80b92aa..0000000000 --- a/streaming/src/main/scala/spark/stream/JobManager.scala +++ /dev/null @@ -1,112 +0,0 @@ -package spark.stream - -import spark.SparkEnv -import spark.Logging - -import scala.collection.mutable.PriorityQueue -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ -import scala.actors.scheduler.ResizableThreadPoolScheduler -import scala.actors.scheduler.ForkJoinScheduler - -sealed trait JobManagerMessage -case class RunJob(job: Job) extends JobManagerMessage -case class JobCompleted(handlerId: Int) extends JobManagerMessage - -class JobHandler(ssc: SparkStreamContext, val id: Int) extends DaemonActor with Logging { - - var busy = false - - def act() { - loop { - receive { - case job: Job => { - SparkEnv.set(ssc.env) - try { - logInfo("Starting " + job) - job.run() - logInfo("Finished " + job) - if (job.time.isInstanceOf[LongTime]) { - val longTime = job.time.asInstanceOf[LongTime] - logInfo("Total pushing + skew + processing delay for " + longTime + " is " + - (System.currentTimeMillis - longTime.milliseconds) / 1000.0 + " s") - } - } catch { - case e: Exception => logError("SparkStream job failed", e) - } - busy = false - reply(JobCompleted(id)) - } - } - } - } -} - -class JobManager(ssc: SparkStreamContext, numThreads: Int = 2) extends DaemonActor with Logging { - - implicit private val jobOrdering = new Ordering[Job] { - override def compare(job1: Job, job2: Job): Int = { - if (job1.time < job2.time) { - return 1 - } else if (job2.time < job1.time) { - return -1 - } else { - return 0 - } - } - } - - private val jobs = new PriorityQueue[Job]() - private val handlers = (0 until numThreads).map(i => new JobHandler(ssc, i)) - - def act() { - handlers.foreach(_.start) - loop { - receive { - case RunJob(job) => { - jobs += job - logInfo("Job " + job + " submitted") - runJob() - } - case JobCompleted(handlerId) => { - runJob() - } - } - } - } - - def runJob(): Unit = { - logInfo("Attempting to allocate job ") - if (jobs.size > 0) { - handlers.find(!_.busy).foreach(handler => { - val job = jobs.dequeue - logInfo("Allocating job " + job + " to handler " + handler.id) - handler.busy = true - handler ! job - }) - } - } -} - -object JobManager { - def main(args: Array[String]) { - val ssc = new SparkStreamContext("local[4]", "JobManagerTest") - val jobManager = new JobManager(ssc) - jobManager.start() - - val t = System.currentTimeMillis - for (i <- 1 to 10) { - jobManager ! RunJob(new Job( - LongTime(i), - () => { - Thread.sleep(500) - println("Job " + i + " took " + (System.currentTimeMillis - t) + " ms") - } - )) - } - Thread.sleep(6000) - } -} - diff --git a/streaming/src/main/scala/spark/stream/JobManager2.scala b/streaming/src/main/scala/spark/stream/JobManager2.scala deleted file mode 100644 index b69653b9a4..0000000000 --- a/streaming/src/main/scala/spark/stream/JobManager2.scala +++ /dev/null @@ -1,37 +0,0 @@ -package spark.stream - -import spark.{Logging, SparkEnv} -import java.util.concurrent.Executors - - -class JobManager2(ssc: SparkStreamContext, numThreads: Int = 1) extends Logging { - - class JobHandler(ssc: SparkStreamContext, job: Job) extends Runnable { - def run() { - SparkEnv.set(ssc.env) - try { - logInfo("Starting " + job) - job.run() - logInfo("Finished " + job) - if (job.time.isInstanceOf[LongTime]) { - val longTime = job.time.asInstanceOf[LongTime] - logInfo("Total notification + skew + processing delay for " + longTime + " is " + - (System.currentTimeMillis - longTime.milliseconds) / 1000.0 + " s") - if (System.getProperty("spark.stream.distributed", "false") == "true") { - TestInputBlockTracker.setEndTime(job.time) - } - } - } catch { - case e: Exception => logError("SparkStream job failed", e) - } - } - } - - initLogging() - - val jobExecutor = Executors.newFixedThreadPool(numThreads) - - def runJob(job: Job) { - jobExecutor.execute(new JobHandler(ssc, job)) - } -} diff --git a/streaming/src/main/scala/spark/stream/NetworkStreamReceiver.scala b/streaming/src/main/scala/spark/stream/NetworkStreamReceiver.scala deleted file mode 100644 index 8be46cc927..0000000000 --- a/streaming/src/main/scala/spark/stream/NetworkStreamReceiver.scala +++ /dev/null @@ -1,184 +0,0 @@ -package spark.stream - -import spark.Logging -import spark.storage.StorageLevel - -import scala.math._ -import scala.collection.mutable.{Queue, HashMap, ArrayBuffer} -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import java.io.BufferedWriter -import java.io.OutputStreamWriter - -import org.apache.hadoop.fs._ -import org.apache.hadoop.conf._ -import org.apache.hadoop.io._ -import org.apache.hadoop.mapred._ -import org.apache.hadoop.util._ - -/*import akka.actor.Actor._*/ - -class NetworkStreamReceiver[T: ClassManifest] ( - inputName: String, - intervalDuration: Time, - splitId: Int, - ssc: SparkStreamContext, - tempDirectory: String) - extends DaemonActor - with Logging { - - /** - * Assume all data coming in has non-decreasing timestamp. - */ - final class Inbox[T: ClassManifest] (intervalDuration: Time) { - var currentBucket: (Interval, ArrayBuffer[T]) = null - val filledBuckets = new Queue[(Interval, ArrayBuffer[T])]() - - def += (tuple: (Time, T)) = addTuple(tuple) - - def addTuple(tuple: (Time, T)) { - val (time, data) = tuple - val interval = getInterval (time) - - filledBuckets.synchronized { - if (currentBucket == null) { - currentBucket = (interval, new ArrayBuffer[T]()) - } - - if (interval != currentBucket._1) { - filledBuckets += currentBucket - currentBucket = (interval, new ArrayBuffer[T]()) - } - - currentBucket._2 += data - } - } - - def getInterval(time: Time): Interval = { - val intervalBegin = time.floor(intervalDuration) - Interval (intervalBegin, intervalBegin + intervalDuration) - } - - def hasFilledBuckets(): Boolean = { - filledBuckets.synchronized { - return filledBuckets.size > 0 - } - } - - def popFilledBucket(): (Interval, ArrayBuffer[T]) = { - filledBuckets.synchronized { - if (filledBuckets.size == 0) { - return null - } - return filledBuckets.dequeue() - } - } - } - - val inbox = new Inbox[T](intervalDuration) - lazy val sparkstreamScheduler = { - val host = System.getProperty("spark.master.host") - val port = System.getProperty("spark.master.port").toInt - val url = "akka://spark@%s:%s/user/SparkStreamScheduler".format(host, port) - ssc.actorSystem.actorFor(url) - } - /*sparkstreamScheduler ! Test()*/ - - val intervalDurationMillis = intervalDuration.asInstanceOf[LongTime].milliseconds - val useBlockManager = true - - initLogging() - - override def act() { - // register the InputReceiver - val port = 7078 - RemoteActor.alive(port) - RemoteActor.register(Symbol("NetworkStreamReceiver-"+inputName), self) - logInfo("Registered actor on port " + port) - - loop { - reactWithin (getSleepTime) { - case TIMEOUT => - flushInbox() - case data => - val t = data.asInstanceOf[T] - inbox += (getTimeFromData(t), t) - } - } - } - - def getSleepTime(): Long = { - (System.currentTimeMillis / intervalDurationMillis + 1) * - intervalDurationMillis - System.currentTimeMillis - } - - def getTimeFromData(data: T): Time = { - LongTime(System.currentTimeMillis) - } - - def flushInbox() { - while (inbox.hasFilledBuckets) { - inbox.synchronized { - val (interval, data) = inbox.popFilledBucket() - val dataArray = data.toArray - logInfo("Received " + dataArray.length + " items at interval " + interval) - val reference = { - if (useBlockManager) { - writeToBlockManager(dataArray, interval) - } else { - writeToDisk(dataArray, interval) - } - } - if (reference != null) { - logInfo("Notifying scheduler") - sparkstreamScheduler ! InputGenerated(inputName, interval, reference.toString) - } - } - } - } - - def writeToDisk(data: Array[T], interval: Interval): String = { - try { - // TODO(Haoyuan): For current test, the following writing to file lines could be - // commented. - val fs = new Path(tempDirectory).getFileSystem(new Configuration()) - val inputDir = new Path( - tempDirectory, - inputName + "-" + interval.toFormattedString) - val inputFile = new Path(inputDir, "part-" + splitId) - logInfo("Writing to file " + inputFile) - if (System.getProperty("spark.fake", "false") != "true") { - val writer = new BufferedWriter(new OutputStreamWriter(fs.create(inputFile, true))) - data.foreach(x => writer.write(x.toString + "\n")) - writer.close() - } else { - logInfo("Fake file") - } - inputFile.toString - }catch { - case e: Exception => - logError("Exception writing to file at interval " + interval + ": " + e.getMessage, e) - null - } - } - - def writeToBlockManager(data: Array[T], interval: Interval): String = { - try{ - val blockId = inputName + "-" + interval.toFormattedString + "-" + splitId - if (System.getProperty("spark.fake", "false") != "true") { - logInfo("Writing as block " + blockId ) - ssc.env.blockManager.put(blockId.toString, data.toIterator, StorageLevel.DISK_AND_MEMORY) - } else { - logInfo("Fake block") - } - blockId - } catch { - case e: Exception => - logError("Exception writing to block manager at interval " + interval + ": " + e.getMessage, e) - null - } - } -} diff --git a/streaming/src/main/scala/spark/stream/RDS.scala b/streaming/src/main/scala/spark/stream/RDS.scala deleted file mode 100644 index b83181b0d1..0000000000 --- a/streaming/src/main/scala/spark/stream/RDS.scala +++ /dev/null @@ -1,607 +0,0 @@ -package spark.stream - -import spark.stream.SparkStreamContext._ - -import spark.RDD -import spark.BlockRDD -import spark.UnionRDD -import spark.Logging -import spark.SparkContext -import spark.SparkContext._ -import spark.storage.StorageLevel - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap - -import java.net.InetSocketAddress - -abstract class RDS[T: ClassManifest] (@transient val ssc: SparkStreamContext) -extends Logging with Serializable { - - initLogging() - - /* ---------------------------------------------- */ - /* Methods that must be implemented by subclasses */ - /* ---------------------------------------------- */ - - // Time by which the window slides in this RDS - def slideTime: Time - - // List of parent RDSs on which this RDS depends on - def dependencies: List[RDS[_]] - - // Key method that computes RDD for a valid time - def compute (validTime: Time): Option[RDD[T]] - - /* --------------------------------------- */ - /* Other general fields and methods of RDS */ - /* --------------------------------------- */ - - // Variable to store the RDDs generated earlier in time - @transient private val generatedRDDs = new HashMap[Time, RDD[T]] () - - // Variable to be set to the first time seen by the RDS (effective time zero) - private[stream] var zeroTime: Time = null - - // Variable to specify storage level - private var storageLevel: StorageLevel = StorageLevel.NONE - - // Checkpoint level and checkpoint interval - private var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint - private var checkpointInterval: Time = null - - // Change this RDD's storage level - def persist( - storageLevel: StorageLevel, - checkpointLevel: StorageLevel, - checkpointInterval: Time): RDS[T] = { - if (this.storageLevel != StorageLevel.NONE && this.storageLevel != storageLevel) { - // TODO: not sure this is necessary for RDSes - throw new UnsupportedOperationException( - "Cannot change storage level of an RDS after it was already assigned a level") - } - this.storageLevel = storageLevel - this.checkpointLevel = checkpointLevel - this.checkpointInterval = checkpointInterval - this - } - - def persist(newLevel: StorageLevel): RDS[T] = persist(newLevel, StorageLevel.NONE, null) - - // Turn on the default caching level for this RDD - def persist(): RDS[T] = persist(StorageLevel.MEMORY_ONLY_DESER) - - // Turn on the default caching level for this RDD - def cache(): RDS[T] = persist() - - def isInitialized = (zeroTime != null) - - // This method initializes the RDS by setting the "zero" time, based on which - // the validity of future times is calculated. This method also recursively initializes - // its parent RDSs. - def initialize(firstInterval: Interval) { - if (zeroTime == null) { - zeroTime = firstInterval.beginTime - } - logInfo(this + " initialized") - dependencies.foreach(_.initialize(firstInterval)) - } - - // This method checks whether the 'time' is valid wrt slideTime for generating RDD - private def isTimeValid (time: Time): Boolean = { - if (!isInitialized) - throw new Exception (this.toString + " has not been initialized") - if ((time - zeroTime).isMultipleOf(slideTime)) { - true - } else { - false - } - } - - // This method either retrieves a precomputed RDD of this RDS, - // or computes the RDD (if the time is valid) - def getOrCompute(time: Time): Option[RDD[T]] = { - - // if RDD was already generated, then retrieve it from HashMap - generatedRDDs.get(time) match { - - // If an RDD was already generated and is being reused, then - // probably all RDDs in this RDS will be reused and hence should be cached - case Some(oldRDD) => Some(oldRDD) - - // if RDD was not generated, and if the time is valid - // (based on sliding time of this RDS), then generate the RDD - case None => - if (isTimeValid(time)) { - compute(time) match { - case Some(newRDD) => - if (System.getProperty("spark.fake", "false") != "true" || - newRDD.getStorageLevel == StorageLevel.NONE) { - if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { - newRDD.persist(checkpointLevel) - logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) - } else if (storageLevel != StorageLevel.NONE) { - newRDD.persist(storageLevel) - logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) - } - } - generatedRDDs.put(time.copy(), newRDD) - Some(newRDD) - case None => - None - } - } else { - None - } - } - } - - // This method generates a SparkStream job for the given time - // and may require to be overriden by subclasses - def generateJob(time: Time): Option[Job] = { - getOrCompute(time) match { - case Some(rdd) => { - val jobFunc = () => { - val emptyFunc = { (iterator: Iterator[T]) => {} } - ssc.sc.runJob(rdd, emptyFunc) - } - Some(new Job(time, jobFunc)) - } - case None => None - } - } - - /* -------------- */ - /* RDS operations */ - /* -------------- */ - - def map[U: ClassManifest](mapFunc: T => U) = new MappedRDS(this, ssc.sc.clean(mapFunc)) - - def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]) = - new FlatMappedRDS(this, ssc.sc.clean(flatMapFunc)) - - def filter(filterFunc: T => Boolean) = new FilteredRDS(this, filterFunc) - - def glom() = new GlommedRDS(this) - - def mapPartitions[U: ClassManifest](mapPartFunc: Iterator[T] => Iterator[U]) = - new MapPartitionedRDS(this, ssc.sc.clean(mapPartFunc)) - - def reduce(reduceFunc: (T, T) => T) = this.map(x => (1, x)).reduceByKey(reduceFunc, 1).map(_._2) - - def count() = this.map(_ => 1).reduce(_ + _) - - def collect() = this.map(x => (1, x)).groupByKey(1).map(_._2) - - def foreach(foreachFunc: T => Unit) = { - val newrds = new PerElementForEachRDS(this, ssc.sc.clean(foreachFunc)) - ssc.registerOutputStream(newrds) - newrds - } - - def foreachRDD(foreachFunc: RDD[T] => Unit) = { - val newrds = new PerRDDForEachRDS(this, ssc.sc.clean(foreachFunc)) - ssc.registerOutputStream(newrds) - newrds - } - - def print() = { - def foreachFunc = (rdd: RDD[T], time: Time) => { - val first11 = rdd.take(11) - println ("-------------------------------------------") - println ("Time: " + time) - println ("-------------------------------------------") - first11.take(10).foreach(println) - if (first11.size > 10) println("...") - println() - } - val newrds = new PerRDDForEachRDS(this, ssc.sc.clean(foreachFunc)) - ssc.registerOutputStream(newrds) - newrds - } - - def window(windowTime: Time, slideTime: Time) = new WindowedRDS(this, windowTime, slideTime) - - def batch(batchTime: Time) = window(batchTime, batchTime) - - def reduceByWindow(reduceFunc: (T, T) => T, windowTime: Time, slideTime: Time) = - this.window(windowTime, slideTime).reduce(reduceFunc) - - def reduceByWindow( - reduceFunc: (T, T) => T, - invReduceFunc: (T, T) => T, - windowTime: Time, - slideTime: Time) = { - this.map(x => (1, x)) - .reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowTime, slideTime, 1) - .map(_._2) - } - - def countByWindow(windowTime: Time, slideTime: Time) = { - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - this.map(_ => 1).reduceByWindow(add _, subtract _, windowTime, slideTime) - } - - def union(that: RDS[T]) = new UnifiedRDS(Array(this, that)) - - def register() = ssc.registerOutputStream(this) -} - - -class PairRDSFunctions[K: ClassManifest, V: ClassManifest](rds: RDS[(K,V)]) -extends Serializable { - - def ssc = rds.ssc - - /* ---------------------------------- */ - /* RDS operations for key-value pairs */ - /* ---------------------------------- */ - - def groupByKey(numPartitions: Int = 0): ShuffledRDS[K, V, ArrayBuffer[V]] = { - def createCombiner(v: V) = ArrayBuffer[V](v) - def mergeValue(c: ArrayBuffer[V], v: V) = (c += v) - def mergeCombiner(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = (c1 ++ c2) - combineByKey[ArrayBuffer[V]](createCombiner, mergeValue, mergeCombiner, numPartitions) - } - - def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int = 0): ShuffledRDS[K, V, V] = { - val cleanedReduceFunc = ssc.sc.clean(reduceFunc) - combineByKey[V]((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, numPartitions) - } - - private def combineByKey[C: ClassManifest]( - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiner: (C, C) => C, - numPartitions: Int) : ShuffledRDS[K, V, C] = { - new ShuffledRDS[K, V, C](rds, createCombiner, mergeValue, mergeCombiner, numPartitions) - } - - def groupByKeyAndWindow( - windowTime: Time, - slideTime: Time, - numPartitions: Int = 0): ShuffledRDS[K, V, ArrayBuffer[V]] = { - rds.window(windowTime, slideTime).groupByKey(numPartitions) - } - - def reduceByKeyAndWindow( - reduceFunc: (V, V) => V, - windowTime: Time, - slideTime: Time, - numPartitions: Int = 0): ShuffledRDS[K, V, V] = { - rds.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), numPartitions) - } - - // This method is the efficient sliding window reduce operation, - // which requires the specification of an inverse reduce function, - // so that new elements introduced in the window can be "added" using - // reduceFunc to the previous window's result and old elements can be - // "subtracted using invReduceFunc. - def reduceByKeyAndWindow( - reduceFunc: (V, V) => V, - invReduceFunc: (V, V) => V, - windowTime: Time, - slideTime: Time, - numPartitions: Int): ReducedWindowedRDS[K, V] = { - - new ReducedWindowedRDS[K, V]( - rds, - ssc.sc.clean(reduceFunc), - ssc.sc.clean(invReduceFunc), - windowTime, - slideTime, - numPartitions) - } -} - - -abstract class InputRDS[T: ClassManifest] ( - val inputName: String, - val batchDuration: Time, - ssc: SparkStreamContext) -extends RDS[T](ssc) { - - override def dependencies = List() - - override def slideTime = batchDuration - - def setReference(time: Time, reference: AnyRef) -} - - -class FileInputRDS( - val fileInputName: String, - val directory: String, - ssc: SparkStreamContext) -extends InputRDS[String](fileInputName, LongTime(1000), ssc) { - - @transient val generatedFiles = new HashMap[Time,String] - - // TODO(Haoyuan): This is for the performance test. - @transient - val rdd = ssc.sc.textFile(SparkContext.inputFile, - SparkContext.idealPartitions).asInstanceOf[RDD[String]] - - override def compute(validTime: Time): Option[RDD[String]] = { - generatedFiles.get(validTime) match { - case Some(file) => - logInfo("Reading from file " + file + " for time " + validTime) - // Some(ssc.sc.textFile(file).asInstanceOf[RDD[String]]) - // The following line is for HDFS performance test. Sould comment out the above line. - Some(rdd) - case None => - throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") - None - } - } - - def setReference(time: Time, reference: AnyRef) { - generatedFiles += ((time, reference.toString)) - logInfo("Reference added for time " + time + " - " + reference.toString) - } -} - -class NetworkInputRDS[T: ClassManifest]( - val networkInputName: String, - val addresses: Array[InetSocketAddress], - batchDuration: Time, - ssc: SparkStreamContext) -extends InputRDS[T](networkInputName, batchDuration, ssc) { - - - // TODO(Haoyuan): This is for the performance test. - @transient var rdd: RDD[T] = null - - if (System.getProperty("spark.fake", "false") == "true") { - logInfo("Running initial count to cache fake RDD") - rdd = ssc.sc.textFile(SparkContext.inputFile, - SparkContext.idealPartitions).asInstanceOf[RDD[T]] - val fakeCacheLevel = System.getProperty("spark.fake.cache", "") - if (fakeCacheLevel == "MEMORY_ONLY_2") { - rdd.persist(StorageLevel.MEMORY_ONLY_2) - } else if (fakeCacheLevel == "MEMORY_ONLY_DESER_2") { - rdd.persist(StorageLevel.MEMORY_ONLY_2) - } else if (fakeCacheLevel != "") { - logError("Invalid fake cache level: " + fakeCacheLevel) - System.exit(1) - } - rdd.count() - } - - @transient val references = new HashMap[Time,String] - - override def compute(validTime: Time): Option[RDD[T]] = { - if (System.getProperty("spark.fake", "false") == "true") { - logInfo("Returning fake RDD at " + validTime) - return Some(rdd) - } - references.get(validTime) match { - case Some(reference) => - if (reference.startsWith("file") || reference.startsWith("hdfs")) { - logInfo("Reading from file " + reference + " for time " + validTime) - Some(ssc.sc.textFile(reference).asInstanceOf[RDD[T]]) - } else { - logInfo("Getting from BlockManager " + reference + " for time " + validTime) - Some(new BlockRDD(ssc.sc, Array(reference))) - } - case None => - throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") - None - } - } - - def setReference(time: Time, reference: AnyRef) { - references += ((time, reference.toString)) - logInfo("Reference added for time " + time + " - " + reference.toString) - } -} - - -class TestInputRDS( - val testInputName: String, - batchDuration: Time, - ssc: SparkStreamContext) -extends InputRDS[String](testInputName, batchDuration, ssc) { - - @transient val references = new HashMap[Time,Array[String]] - - override def compute(validTime: Time): Option[RDD[String]] = { - references.get(validTime) match { - case Some(reference) => - Some(new BlockRDD[String](ssc.sc, reference)) - case None => - throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") - None - } - } - - def setReference(time: Time, reference: AnyRef) { - references += ((time, reference.asInstanceOf[Array[String]])) - } -} - - -class MappedRDS[T: ClassManifest, U: ClassManifest] ( - parent: RDS[T], - mapFunc: T => U) -extends RDS[U](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(_.map[U](mapFunc)) - } -} - - -class FlatMappedRDS[T: ClassManifest, U: ClassManifest]( - parent: RDS[T], - flatMapFunc: T => Traversable[U]) -extends RDS[U](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(_.flatMap(flatMapFunc)) - } -} - - -class FilteredRDS[T: ClassManifest](parent: RDS[T], filterFunc: T => Boolean) -extends RDS[T](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[T]] = { - parent.getOrCompute(validTime).map(_.filter(filterFunc)) - } -} - -class MapPartitionedRDS[T: ClassManifest, U: ClassManifest]( - parent: RDS[T], - mapPartFunc: Iterator[T] => Iterator[U]) -extends RDS[U](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc)) - } -} - -class GlommedRDS[T: ClassManifest](parent: RDS[T]) extends RDS[Array[T]](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[Array[T]]] = { - parent.getOrCompute(validTime).map(_.glom()) - } -} - - -class ShuffledRDS[K: ClassManifest, V: ClassManifest, C: ClassManifest]( - parent: RDS[(K,V)], - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiner: (C, C) => C, - numPartitions: Int) - extends RDS [(K,C)] (parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[(K,C)]] = { - parent.getOrCompute(validTime) match { - case Some(rdd) => - val newrdd = { - if (numPartitions > 0) { - rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner, numPartitions) - } else { - rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner) - } - } - Some(newrdd) - case None => None - } - } -} - - -class UnifiedRDS[T: ClassManifest](parents: Array[RDS[T]]) -extends RDS[T](parents(0).ssc) { - - if (parents.length == 0) { - throw new IllegalArgumentException("Empty array of parents") - } - - if (parents.map(_.ssc).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different SparkStreamContexts") - } - - if (parents.map(_.slideTime).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different slide times") - } - - override def dependencies = parents.toList - - override def slideTime: Time = parents(0).slideTime - - override def compute(validTime: Time): Option[RDD[T]] = { - val rdds = new ArrayBuffer[RDD[T]]() - parents.map(_.getOrCompute(validTime)).foreach(_ match { - case Some(rdd) => rdds += rdd - case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime) - }) - if (rdds.size > 0) { - Some(new UnionRDD(ssc.sc, rdds)) - } else { - None - } - } -} - - -class PerElementForEachRDS[T: ClassManifest] ( - parent: RDS[T], - foreachFunc: T => Unit) -extends RDS[Unit](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[Unit]] = None - - override def generateJob(time: Time): Option[Job] = { - parent.getOrCompute(time) match { - case Some(rdd) => - val jobFunc = () => { - val sparkJobFunc = { - (iterator: Iterator[T]) => iterator.foreach(foreachFunc) - } - ssc.sc.runJob(rdd, sparkJobFunc) - } - Some(new Job(time, jobFunc)) - case None => None - } - } -} - - -class PerRDDForEachRDS[T: ClassManifest] ( - parent: RDS[T], - foreachFunc: (RDD[T], Time) => Unit) -extends RDS[Unit](parent.ssc) { - - def this(parent: RDS[T], altForeachFunc: (RDD[T]) => Unit) = - this(parent, (rdd: RDD[T], time: Time) => altForeachFunc(rdd)) - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[Unit]] = None - - override def generateJob(time: Time): Option[Job] = { - parent.getOrCompute(time) match { - case Some(rdd) => - val jobFunc = () => { - foreachFunc(rdd, time) - } - Some(new Job(time, jobFunc)) - case None => None - } - } -} diff --git a/streaming/src/main/scala/spark/stream/ReducedWindowedRDS.scala b/streaming/src/main/scala/spark/stream/ReducedWindowedRDS.scala deleted file mode 100644 index d47654ccb9..0000000000 --- a/streaming/src/main/scala/spark/stream/ReducedWindowedRDS.scala +++ /dev/null @@ -1,218 +0,0 @@ -package spark.stream - -import spark.stream.SparkStreamContext._ - -import spark.RDD -import spark.UnionRDD -import spark.CoGroupedRDD -import spark.HashPartitioner -import spark.SparkContext._ -import spark.storage.StorageLevel - -import scala.collection.mutable.ArrayBuffer - -class ReducedWindowedRDS[K: ClassManifest, V: ClassManifest]( - parent: RDS[(K, V)], - reduceFunc: (V, V) => V, - invReduceFunc: (V, V) => V, - _windowTime: Time, - _slideTime: Time, - numPartitions: Int) -extends RDS[(K,V)](parent.ssc) { - - if (!_windowTime.isMultipleOf(parent.slideTime)) - throw new Exception("The window duration of ReducedWindowedRDS (" + _slideTime + ") " + - "must be multiple of the slide duration of parent RDS (" + parent.slideTime + ")") - - if (!_slideTime.isMultipleOf(parent.slideTime)) - throw new Exception("The slide duration of ReducedWindowedRDS (" + _slideTime + ") " + - "must be multiple of the slide duration of parent RDS (" + parent.slideTime + ")") - - val reducedRDS = parent.reduceByKey(reduceFunc, numPartitions) - val allowPartialWindows = true - //reducedRDS.persist(StorageLevel.MEMORY_ONLY_DESER_2) - - override def dependencies = List(reducedRDS) - - def windowTime: Time = _windowTime - - override def slideTime: Time = _slideTime - - override def persist( - storageLevel: StorageLevel, - checkpointLevel: StorageLevel, - checkpointInterval: Time): RDS[(K,V)] = { - super.persist(storageLevel, checkpointLevel, checkpointInterval) - reducedRDS.persist(storageLevel, checkpointLevel, checkpointInterval) - } - - override def compute(validTime: Time): Option[RDD[(K, V)]] = { - - - // Notation: - // _____________________________ - // | previous window _________|___________________ - // |___________________| current window | --------------> Time - // |_____________________________| - // - // |________ _________| |________ _________| - // | | - // V V - // old time steps new time steps - // - def getAdjustedWindow(endTime: Time, windowTime: Time): Interval = { - val beginTime = - if (allowPartialWindows && endTime - windowTime < parent.zeroTime) { - parent.zeroTime - } else { - endTime - windowTime - } - Interval(beginTime, endTime) - } - - val currentTime = validTime.copy - val currentWindow = getAdjustedWindow(currentTime, windowTime) - val previousWindow = getAdjustedWindow(currentTime - slideTime, windowTime) - - logInfo("Current window = " + currentWindow) - logInfo("Previous window = " + previousWindow) - logInfo("Parent.zeroTime = " + parent.zeroTime) - - if (allowPartialWindows) { - if (currentTime - slideTime == parent.zeroTime) { - reducedRDS.getOrCompute(currentTime) match { - case Some(rdd) => return Some(rdd) - case None => throw new Exception("Could not get first reduced RDD for time " + currentTime) - } - } - } else { - if (previousWindow.beginTime < parent.zeroTime) { - if (currentWindow.beginTime < parent.zeroTime) { - return None - } else { - // If this is the first feasible window, then generate reduced value in the naive manner - val reducedRDDs = new ArrayBuffer[RDD[(K, V)]]() - var t = currentWindow.endTime - while (t > currentWindow.beginTime) { - reducedRDS.getOrCompute(t) match { - case Some(rdd) => reducedRDDs += rdd - case None => throw new Exception("Could not get reduced RDD for time " + t) - } - t -= reducedRDS.slideTime - } - if (reducedRDDs.size == 0) { - throw new Exception("Could not generate the first RDD for time " + validTime) - } - return Some(new UnionRDD(ssc.sc, reducedRDDs).reduceByKey(reduceFunc, numPartitions)) - } - } - } - - // Get the RDD of the reduced value of the previous window - val previousWindowRDD = getOrCompute(previousWindow.endTime) match { - case Some(rdd) => rdd.asInstanceOf[RDD[(_, _)]] - case None => throw new Exception("Could not get previous RDD for time " + previousWindow.endTime) - } - - val oldRDDs = new ArrayBuffer[RDD[(_, _)]]() - val newRDDs = new ArrayBuffer[RDD[(_, _)]]() - - // Get the RDDs of the reduced values in "old time steps" - var t = currentWindow.beginTime - while (t > previousWindow.beginTime) { - reducedRDS.getOrCompute(t) match { - case Some(rdd) => oldRDDs += rdd.asInstanceOf[RDD[(_, _)]] - case None => throw new Exception("Could not get old reduced RDD for time " + t) - } - t -= reducedRDS.slideTime - } - - // Get the RDDs of the reduced values in "new time steps" - t = currentWindow.endTime - while (t > previousWindow.endTime) { - reducedRDS.getOrCompute(t) match { - case Some(rdd) => newRDDs += rdd.asInstanceOf[RDD[(_, _)]] - case None => throw new Exception("Could not get new reduced RDD for time " + t) - } - t -= reducedRDS.slideTime - } - - val partitioner = new HashPartitioner(numPartitions) - val allRDDs = new ArrayBuffer[RDD[(_, _)]]() - allRDDs += previousWindowRDD - allRDDs ++= oldRDDs - allRDDs ++= newRDDs - - - val numOldRDDs = oldRDDs.size - val numNewRDDs = newRDDs.size - logInfo("Generated numOldRDDs = " + numOldRDDs + ", numNewRDDs = " + numNewRDDs) - logInfo("Generating CoGroupedRDD with " + allRDDs.size + " RDDs") - val newRDD = new CoGroupedRDD[K](allRDDs.toSeq, partitioner).asInstanceOf[RDD[(K,Seq[Seq[V]])]].map(x => { - val (key, value) = x - logDebug("value.size = " + value.size + ", numOldRDDs = " + numOldRDDs + ", numNewRDDs = " + numNewRDDs) - if (value.size != 1 + numOldRDDs + numNewRDDs) { - throw new Exception("Number of groups not odd!") - } - - // old values = reduced values of the "old time steps" that are eliminated from current window - // new values = reduced values of the "new time steps" that are introduced to the current window - // previous value = reduced value of the previous window - - /*val numOldValues = (value.size - 1) / 2*/ - // Getting reduced values "old time steps" - val oldValues = - (0 until numOldRDDs).map(i => value(1 + i)).filter(_.size > 0).map(x => x(0)) - // Getting reduced values "new time steps" - val newValues = - (0 until numNewRDDs).map(i => value(1 + numOldRDDs + i)).filter(_.size > 0).map(x => x(0)) - - // If reduced value for the key does not exist in previous window, it should not exist in "old time steps" - if (value(0).size == 0 && oldValues.size != 0) { - throw new Exception("Unexpected: Key exists in old reduced values but not in previous reduced values") - } - - // For the key, at least one of "old time steps", "new time steps" and previous window should have reduced values - if (value(0).size == 0 && oldValues.size == 0 && newValues.size == 0) { - throw new Exception("Unexpected: Key does not exist in any of old, new, or previour reduced values") - } - - // Logic to generate the final reduced value for current window: - // - // If previous window did not have reduced value for the key - // Then, return reduced value of "new time steps" as the final value - // Else, reduced value exists in previous window - // If "old" time steps did not have reduced value for the key - // Then, reduce previous window's reduced value with that of "new time steps" for final value - // Else, reduced values exists in "old time steps" - // If "new values" did not have reduced value for the key - // Then, inverse-reduce "old values" from previous window's reduced value for final value - // Else, all 3 values exist, combine all of them together - // - logDebug("# old values = " + oldValues.size + ", # new values = " + newValues) - val finalValue = { - if (value(0).size == 0) { - newValues.reduce(reduceFunc) - } else { - val prevValue = value(0)(0) - logDebug("prev value = " + prevValue) - if (oldValues.size == 0) { - // assuming newValue.size > 0 (all 3 cannot be zero, as checked earlier) - val temp = newValues.reduce(reduceFunc) - reduceFunc(prevValue, temp) - } else if (newValues.size == 0) { - invReduceFunc(prevValue, oldValues.reduce(reduceFunc)) - } else { - val tempValue = invReduceFunc(prevValue, oldValues.reduce(reduceFunc)) - reduceFunc(tempValue, newValues.reduce(reduceFunc)) - } - } - } - (key, finalValue) - }) - //newRDD.persist(StorageLevel.MEMORY_ONLY_DESER_2) - Some(newRDD) - } -} - - diff --git a/streaming/src/main/scala/spark/stream/Scheduler.scala b/streaming/src/main/scala/spark/stream/Scheduler.scala deleted file mode 100644 index 38946fef11..0000000000 --- a/streaming/src/main/scala/spark/stream/Scheduler.scala +++ /dev/null @@ -1,181 +0,0 @@ -package spark.stream - -import spark.SparkEnv -import spark.Logging - -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.collection.mutable.ArrayBuffer - -import akka.actor._ -import akka.actor.Actor -import akka.actor.Actor._ -import akka.util.duration._ - -sealed trait SchedulerMessage -case class InputGenerated(inputName: String, interval: Interval, reference: AnyRef = null) extends SchedulerMessage -case class Test extends SchedulerMessage - -class Scheduler( - ssc: SparkStreamContext, - inputRDSs: Array[InputRDS[_]], - outputRDSs: Array[RDS[_]]) -extends Actor with Logging { - - class InputState (inputNames: Array[String]) { - val inputsLeft = new HashSet[String]() - inputsLeft ++= inputNames - - val startTime = System.currentTimeMillis - - def delay() = System.currentTimeMillis - startTime - - def addGeneratedInput(inputName: String) = inputsLeft -= inputName - - def areAllInputsGenerated() = (inputsLeft.size == 0) - - override def toString(): String = { - val left = if (inputsLeft.size == 0) "" else inputsLeft.reduceLeft(_ + ", " + _) - return "Inputs left = [ " + left + " ]" - } - } - - - initLogging() - - val inputNames = inputRDSs.map(_.inputName).toArray - val inputStates = new HashMap[Interval, InputState]() - val currentJobs = System.getProperty("spark.stream.currentJobs", "1").toInt - val jobManager = new JobManager2(ssc, currentJobs) - - // TODO(Haoyuan): The following line is for performance test only. - var cnt: Int = System.getProperty("spark.stream.fake.cnt", "60").toInt - var lastInterval: Interval = null - - - /*remote.register("SparkStreamScheduler", actorOf[Scheduler])*/ - logInfo("Registered actor on port ") - - /*jobManager.start()*/ - startStreamReceivers() - - def receive = { - case InputGenerated(inputName, interval, reference) => { - addGeneratedInput(inputName, interval, reference) - } - case Test() => logInfo("TEST PASSED") - } - - def addGeneratedInput(inputName: String, interval: Interval, reference: AnyRef = null) { - logInfo("Input " + inputName + " generated for interval " + interval) - inputStates.get(interval) match { - case None => inputStates.put(interval, new InputState(inputNames)) - case _ => - } - inputStates(interval).addGeneratedInput(inputName) - - inputRDSs.filter(_.inputName == inputName).foreach(inputRDS => { - inputRDS.setReference(interval.endTime, reference) - if (inputRDS.isInstanceOf[TestInputRDS]) { - TestInputBlockTracker.addBlocks(interval.endTime, reference) - } - } - ) - - def getNextInterval(): Option[Interval] = { - logDebug("Last interval is " + lastInterval) - val readyIntervals = inputStates.filter(_._2.areAllInputsGenerated).keys - /*inputState.foreach(println) */ - logDebug("InputState has " + inputStates.size + " intervals, " + readyIntervals.size + " ready intervals") - return readyIntervals.find(lastInterval == null || _.beginTime == lastInterval.endTime) - } - - var nextInterval = getNextInterval() - var count = 0 - while(nextInterval.isDefined) { - val inputState = inputStates.get(nextInterval.get).get - generateRDDsForInterval(nextInterval.get) - logInfo("Skew delay for " + nextInterval.get.endTime + " is " + (inputState.delay / 1000.0) + " s") - inputStates.remove(nextInterval.get) - lastInterval = nextInterval.get - nextInterval = getNextInterval() - count += 1 - /*if (nextInterval.size == 0 && inputState.size > 0) { - logDebug("Next interval not ready, pending intervals " + inputState.size) - }*/ - } - logDebug("RDDs generated for " + count + " intervals") - - /* - if (inputState(interval).areAllInputsGenerated) { - generateRDDsForInterval(interval) - lastInterval = interval - inputState.remove(interval) - } else { - logInfo("All inputs not generated for interval " + interval) - } - */ - } - - def generateRDDsForInterval (interval: Interval) { - logInfo("Generating RDDs for interval " + interval) - outputRDSs.foreach(outputRDS => { - if (!outputRDS.isInitialized) outputRDS.initialize(interval) - outputRDS.generateJob(interval.endTime) match { - case Some(job) => submitJob(job) - case None => - } - } - ) - // TODO(Haoyuan): This comment is for performance test only. - if (System.getProperty("spark.fake", "false") == "true" || System.getProperty("spark.stream.fake", "false") == "true") { - cnt -= 1 - if (cnt <= 0) { - logInfo("My time is up! " + cnt) - System.exit(1) - } - } - } - - def submitJob(job: Job) { - logInfo("Submitting " + job + " to JobManager") - /*jobManager ! RunJob(job)*/ - jobManager.runJob(job) - } - - def startStreamReceivers() { - val testStreamReceiverNames = new ArrayBuffer[(String, Long)]() - inputRDSs.foreach (inputRDS => { - inputRDS match { - case fileInputRDS: FileInputRDS => { - val fileStreamReceiver = new FileStreamReceiver( - fileInputRDS.inputName, - fileInputRDS.directory, - fileInputRDS.batchDuration.asInstanceOf[LongTime].milliseconds) - fileStreamReceiver.start() - } - case networkInputRDS: NetworkInputRDS[_] => { - val networkStreamReceiver = new NetworkStreamReceiver( - networkInputRDS.inputName, - networkInputRDS.batchDuration, - 0, - ssc, - if (ssc.tempDir == null) null else ssc.tempDir.toString) - networkStreamReceiver.start() - } - case testInputRDS: TestInputRDS => { - testStreamReceiverNames += - ((testInputRDS.inputName, testInputRDS.batchDuration.asInstanceOf[LongTime].milliseconds)) - } - } - }) - if (testStreamReceiverNames.size > 0) { - /*val testStreamCoordinator = new TestStreamCoordinator(testStreamReceiverNames.toArray)*/ - /*testStreamCoordinator.start()*/ - val actor = ssc.actorSystem.actorOf( - Props(new TestStreamCoordinator(testStreamReceiverNames.toArray)), - name = "TestStreamCoordinator") - } - } -} - diff --git a/streaming/src/main/scala/spark/stream/SenGeneratorForPerformanceTest.scala b/streaming/src/main/scala/spark/stream/SenGeneratorForPerformanceTest.scala deleted file mode 100644 index 74fd54072f..0000000000 --- a/streaming/src/main/scala/spark/stream/SenGeneratorForPerformanceTest.scala +++ /dev/null @@ -1,78 +0,0 @@ -package spark.stream - -import scala.util.Random -import scala.io.Source -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import java.net.InetSocketAddress - -/*import akka.actor.Actor._*/ -/*import akka.actor.ActorRef*/ - - -object SenGeneratorForPerformanceTest { - - def printUsage () { - println ("Usage: SentenceGenerator []") - System.exit(0) - } - - def main (args: Array[String]) { - if (args.length < 3) { - printUsage - } - - val inputManagerIP = args(0) - val inputManagerPort = args(1).toInt - val sentenceFile = args(2) - val sentencesPerSecond = { - if (args.length > 3) args(3).toInt - else 10 - } - - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n") - source.close () - - try { - /*val inputManager = remote.actorFor("InputReceiver-Sentences",*/ - /* inputManagerIP, inputManagerPort)*/ - val inputManager = select(Node(inputManagerIP, inputManagerPort), Symbol("InputReceiver-Sentences")) - val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 - val random = new Random () - println ("Sending " + sentencesPerSecond + " sentences per second to " + inputManagerIP + ":" + inputManagerPort) - var lastPrintTime = System.currentTimeMillis() - var count = 0 - - while (true) { - /*if (!inputManager.tryTell (lines (random.nextInt (lines.length))))*/ - /*throw new Exception ("disconnected")*/ -// inputManager ! lines (random.nextInt (lines.length)) - for (i <- 0 to sentencesPerSecond) inputManager ! lines (0) - println(System.currentTimeMillis / 1000 + " s") -/* count += 1 - - if (System.currentTimeMillis - lastPrintTime >= 1000) { - println (count + " sentences sent last second") - count = 0 - lastPrintTime = System.currentTimeMillis - } - - Thread.sleep (sleepBetweenSentences.toLong) -*/ - val currentMs = System.currentTimeMillis / 1000; - Thread.sleep ((currentMs * 1000 + 1000) - System.currentTimeMillis) - } - } catch { - case e: Exception => - /*Thread.sleep (1000)*/ - } - } -} - - - - diff --git a/streaming/src/main/scala/spark/stream/SenderReceiverTest.scala b/streaming/src/main/scala/spark/stream/SenderReceiverTest.scala deleted file mode 100644 index 69879b621c..0000000000 --- a/streaming/src/main/scala/spark/stream/SenderReceiverTest.scala +++ /dev/null @@ -1,63 +0,0 @@ -package spark.stream -import java.net.{Socket, ServerSocket} -import java.io.{ByteArrayOutputStream, DataOutputStream, DataInputStream, BufferedInputStream} - -object Receiver { - def main(args: Array[String]) { - val port = args(0).toInt - val lsocket = new ServerSocket(port) - println("Listening on port " + port ) - while(true) { - val socket = lsocket.accept() - (new Thread() { - override def run() { - val buffer = new Array[Byte](100000) - var count = 0 - val time = System.currentTimeMillis - try { - val is = new DataInputStream(new BufferedInputStream(socket.getInputStream)) - var loop = true - var string: String = null - while((string = is.readUTF) != null) { - count += 28 - } - } catch { - case e: Exception => e.printStackTrace - } - val timeTaken = System.currentTimeMillis - time - val tput = (count / 1024.0) / (timeTaken / 1000.0) - println("Data = " + count + " bytes\nTime = " + timeTaken + " ms\nTput = " + tput + " KB/s") - } - }).start() - } - } - -} - -object Sender { - - def main(args: Array[String]) { - try { - val host = args(0) - val port = args(1).toInt - val size = args(2).toInt - - val byteStream = new ByteArrayOutputStream() - val stringDataStream = new DataOutputStream(byteStream) - (0 until size).foreach(_ => stringDataStream.writeUTF("abcdedfghijklmnopqrstuvwxy")) - val bytes = byteStream.toByteArray() - println("Generated array of " + bytes.length + " bytes") - - /*val bytes = new Array[Byte](size)*/ - val socket = new Socket(host, port) - val os = socket.getOutputStream - os.write(bytes) - os.flush - socket.close() - - } catch { - case e: Exception => e.printStackTrace - } - } -} - diff --git a/streaming/src/main/scala/spark/stream/SentenceFileGenerator.scala b/streaming/src/main/scala/spark/stream/SentenceFileGenerator.scala deleted file mode 100644 index 9aa441d9bb..0000000000 --- a/streaming/src/main/scala/spark/stream/SentenceFileGenerator.scala +++ /dev/null @@ -1,92 +0,0 @@ -package spark.stream - -import spark._ - -import scala.collection.mutable.ArrayBuffer -import scala.util.Random -import scala.io.Source - -import java.net.InetSocketAddress - -import org.apache.hadoop.fs._ -import org.apache.hadoop.conf._ -import org.apache.hadoop.io._ -import org.apache.hadoop.mapred._ -import org.apache.hadoop.util._ - -object SentenceFileGenerator { - - def printUsage () { - println ("Usage: SentenceFileGenerator <# partitions> []") - System.exit(0) - } - - def main (args: Array[String]) { - if (args.length < 4) { - printUsage - } - - val master = args(0) - val fs = new Path(args(1)).getFileSystem(new Configuration()) - val targetDirectory = new Path(args(1)).makeQualified(fs) - val numPartitions = args(2).toInt - val sentenceFile = args(3) - val sentencesPerSecond = { - if (args.length > 4) args(4).toInt - else 10 - } - - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n").toArray - source.close () - println("Read " + lines.length + " lines from file " + sentenceFile) - - val sentences = { - val buffer = ArrayBuffer[String]() - val random = new Random() - var i = 0 - while (i < sentencesPerSecond) { - buffer += lines(random.nextInt(lines.length)) - i += 1 - } - buffer.toArray - } - println("Generated " + sentences.length + " sentences") - - val sc = new SparkContext(master, "SentenceFileGenerator") - val sentencesRDD = sc.parallelize(sentences, numPartitions) - - val tempDirectory = new Path(targetDirectory, "_tmp") - - fs.mkdirs(targetDirectory) - fs.mkdirs(tempDirectory) - - var saveTimeMillis = System.currentTimeMillis - try { - while (true) { - val newDir = new Path(targetDirectory, "Sentences-" + saveTimeMillis) - val tmpNewDir = new Path(tempDirectory, "Sentences-" + saveTimeMillis) - println("Writing to file " + newDir) - sentencesRDD.saveAsTextFile(tmpNewDir.toString) - fs.rename(tmpNewDir, newDir) - saveTimeMillis += 1000 - val sleepTimeMillis = { - val currentTimeMillis = System.currentTimeMillis - if (saveTimeMillis < currentTimeMillis) { - 0 - } else { - saveTimeMillis - currentTimeMillis - } - } - println("Sleeping for " + sleepTimeMillis + " ms") - Thread.sleep(sleepTimeMillis) - } - } catch { - case e: Exception => - } - } -} - - - - diff --git a/streaming/src/main/scala/spark/stream/SentenceGenerator.scala b/streaming/src/main/scala/spark/stream/SentenceGenerator.scala deleted file mode 100644 index ef66e66047..0000000000 --- a/streaming/src/main/scala/spark/stream/SentenceGenerator.scala +++ /dev/null @@ -1,103 +0,0 @@ -package spark.stream - -import scala.util.Random -import scala.io.Source -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import java.net.InetSocketAddress - - -object SentenceGenerator { - - def printUsage { - println ("Usage: SentenceGenerator []") - System.exit(0) - } - - def generateRandomSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { - val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 - val random = new Random () - - try { - var lastPrintTime = System.currentTimeMillis() - var count = 0 - while(true) { - streamReceiver ! lines(random.nextInt(lines.length)) - count += 1 - if (System.currentTimeMillis - lastPrintTime >= 1000) { - println (count + " sentences sent last second") - count = 0 - lastPrintTime = System.currentTimeMillis - } - Thread.sleep(sleepBetweenSentences.toLong) - } - } catch { - case e: Exception => - } - } - - def generateSameSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { - try { - val numSentences = if (sentencesPerSecond <= 0) { - lines.length - } else { - sentencesPerSecond - } - var nextSendingTime = System.currentTimeMillis() - val pingInterval = if (System.getenv("INTERVAL") != null) { - System.getenv("INTERVAL").toInt - } else { - 2000 - } - while(true) { - (0 until numSentences).foreach(i => { - streamReceiver ! lines(i % lines.length) - }) - println ("Sent " + numSentences + " sentences") - nextSendingTime += pingInterval - val sleepTime = nextSendingTime - System.currentTimeMillis - if (sleepTime > 0) { - println ("Sleeping for " + sleepTime + " ms") - Thread.sleep(sleepTime) - } - } - } catch { - case e: Exception => - } - } - - def main(args: Array[String]) { - if (args.length < 3) { - printUsage - } - - val generateRandomly = false - - val streamReceiverIP = args(0) - val streamReceiverPort = args(1).toInt - val sentenceFile = args(2) - val sentencesPerSecond = if (args.length > 3) args(3).toInt else 10 - val sentenceInputName = if (args.length > 4) args(4) else "Sentences" - - println("Sending " + sentencesPerSecond + " sentences per second to " + - streamReceiverIP + ":" + streamReceiverPort + "/NetworkStreamReceiver-" + sentenceInputName) - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n") - source.close () - - val streamReceiver = select( - Node(streamReceiverIP, streamReceiverPort), - Symbol("NetworkStreamReceiver-" + sentenceInputName)) - if (generateRandomly) { - generateRandomSentences(lines, sentencesPerSecond, streamReceiver) - } else { - generateSameSentences(lines, sentencesPerSecond, streamReceiver) - } - } -} - - - diff --git a/streaming/src/main/scala/spark/stream/ShuffleTest.scala b/streaming/src/main/scala/spark/stream/ShuffleTest.scala deleted file mode 100644 index 5ad56f6777..0000000000 --- a/streaming/src/main/scala/spark/stream/ShuffleTest.scala +++ /dev/null @@ -1,22 +0,0 @@ -package spark.stream -import spark.SparkContext -import SparkContext._ - -object ShuffleTest { - def main(args: Array[String]) { - - if (args.length < 1) { - println ("Usage: ShuffleTest ") - System.exit(1) - } - - val sc = new spark.SparkContext(args(0), "ShuffleTest") - val rdd = sc.parallelize(1 to 1000, 500).cache - - def time(f: => Unit) { val start = System.nanoTime; f; println((System.nanoTime - start) * 1.0e-6) } - - time { for (i <- 0 until 50) time { rdd.map(x => (x % 100, x)).reduceByKey(_ + _, 10).count } } - System.exit(0) - } -} - diff --git a/streaming/src/main/scala/spark/stream/SimpleWordCount.scala b/streaming/src/main/scala/spark/stream/SimpleWordCount.scala deleted file mode 100644 index c53fe35f44..0000000000 --- a/streaming/src/main/scala/spark/stream/SimpleWordCount.scala +++ /dev/null @@ -1,30 +0,0 @@ -package spark.stream - -import SparkStreamContext._ - -import scala.util.Sorting - -object SimpleWordCount { - - def main (args: Array[String]) { - - if (args.length < 1) { - println ("Usage: SparkStreamContext []") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount") - if (args.length > 1) { - ssc.setTempDir(args(1)) - } - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) - /*sentences.print*/ - - val words = sentences.flatMap(_.split(" ")) - - val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 1) - counts.print - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/stream/SimpleWordCount2.scala b/streaming/src/main/scala/spark/stream/SimpleWordCount2.scala deleted file mode 100644 index 1a2c67cd4d..0000000000 --- a/streaming/src/main/scala/spark/stream/SimpleWordCount2.scala +++ /dev/null @@ -1,51 +0,0 @@ -package spark.stream - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import scala.util.Sorting - -object SimpleWordCount2 { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SimpleWordCount2 <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "SimpleWordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - - val words = sentences.flatMap(_.split(" ")) - - val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 10) - counts.foreachRDD(_.collect()) - /*words.foreachRDD(_.countByValue())*/ - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/stream/SimpleWordCount2_Special.scala b/streaming/src/main/scala/spark/stream/SimpleWordCount2_Special.scala deleted file mode 100644 index 9003a5dbb3..0000000000 --- a/streaming/src/main/scala/spark/stream/SimpleWordCount2_Special.scala +++ /dev/null @@ -1,83 +0,0 @@ -package spark.stream - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import scala.collection.JavaConversions.mapAsScalaMap -import scala.util.Sorting -import java.lang.{Long => JLong} - -object SimpleWordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SimpleWordCount2 <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "SimpleWordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 400)).toArray - ) - - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - map.toIterator - } - - - /*val words = sentences.flatMap(_.split(" "))*/ - /*val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 10)*/ - val counts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) - counts.foreachRDD(_.collect()) - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/stream/SparkStreamContext.scala b/streaming/src/main/scala/spark/stream/SparkStreamContext.scala deleted file mode 100644 index 0e65196e46..0000000000 --- a/streaming/src/main/scala/spark/stream/SparkStreamContext.scala +++ /dev/null @@ -1,105 +0,0 @@ -package spark.stream - -import spark.SparkContext -import spark.SparkEnv -import spark.Utils -import spark.Logging - -import scala.collection.mutable.ArrayBuffer - -import java.net.InetSocketAddress -import java.io.IOException -import java.util.UUID - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.conf.Configuration - -import akka.actor._ -import akka.actor.Actor -import akka.util.duration._ - -class SparkStreamContext ( - master: String, - frameworkName: String, - val sparkHome: String = null, - val jars: Seq[String] = Nil) - extends Logging { - - initLogging() - - val sc = new SparkContext(master, frameworkName, sparkHome, jars) - val env = SparkEnv.get - val actorSystem = env.actorSystem - - @transient val inputRDSs = new ArrayBuffer[InputRDS[_]]() - @transient val outputRDSs = new ArrayBuffer[RDS[_]]() - - var tempDirRoot: String = null - var tempDir: Path = null - - def readNetworkStream[T: ClassManifest]( - name: String, - addresses: Array[InetSocketAddress], - batchDuration: Time): RDS[T] = { - - val inputRDS = new NetworkInputRDS[T](name, addresses, batchDuration, this) - inputRDSs += inputRDS - inputRDS - } - - def readNetworkStream[T: ClassManifest]( - name: String, - addresses: Array[String], - batchDuration: Long): RDS[T] = { - - def stringToInetSocketAddress (str: String): InetSocketAddress = { - val parts = str.split(":") - if (parts.length != 2) { - throw new IllegalArgumentException ("Address format error") - } - new InetSocketAddress(parts(0), parts(1).toInt) - } - - readNetworkStream( - name, - addresses.map(stringToInetSocketAddress).toArray, - LongTime(batchDuration)) - } - - def readFileStream(name: String, directory: String): RDS[String] = { - val path = new Path(directory) - val fs = path.getFileSystem(new Configuration()) - val qualPath = path.makeQualified(fs) - val inputRDS = new FileInputRDS(name, qualPath.toString, this) - inputRDSs += inputRDS - inputRDS - } - - def readTestStream(name: String, batchDuration: Long): RDS[String] = { - val inputRDS = new TestInputRDS(name, LongTime(batchDuration), this) - inputRDSs += inputRDS - inputRDS - } - - def registerOutputStream (outputRDS: RDS[_]) { - outputRDSs += outputRDS - } - - def setTempDir(dir: String) { - tempDirRoot = dir - } - - def run () { - val ctxt = this - val actor = actorSystem.actorOf( - Props(new Scheduler(ctxt, inputRDSs.toArray, outputRDSs.toArray)), - name = "SparkStreamScheduler") - logInfo("Registered actor") - actorSystem.awaitTermination() - } -} - -object SparkStreamContext { - implicit def rdsToPairRdsFunctions [K: ClassManifest, V: ClassManifest] (rds: RDS[(K,V)]) = - new PairRDSFunctions (rds) -} diff --git a/streaming/src/main/scala/spark/stream/TestGenerator.scala b/streaming/src/main/scala/spark/stream/TestGenerator.scala deleted file mode 100644 index 738ce17452..0000000000 --- a/streaming/src/main/scala/spark/stream/TestGenerator.scala +++ /dev/null @@ -1,107 +0,0 @@ -package spark.stream - -import scala.util.Random -import scala.io.Source -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import java.net.InetSocketAddress - - -object TestGenerator { - - def printUsage { - println ("Usage: SentenceGenerator []") - System.exit(0) - } - /* - def generateRandomSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { - val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 - val random = new Random () - - try { - var lastPrintTime = System.currentTimeMillis() - var count = 0 - while(true) { - streamReceiver ! lines(random.nextInt(lines.length)) - count += 1 - if (System.currentTimeMillis - lastPrintTime >= 1000) { - println (count + " sentences sent last second") - count = 0 - lastPrintTime = System.currentTimeMillis - } - Thread.sleep(sleepBetweenSentences.toLong) - } - } catch { - case e: Exception => - } - }*/ - - def generateSameSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { - try { - val numSentences = if (sentencesPerSecond <= 0) { - lines.length - } else { - sentencesPerSecond - } - val sentences = lines.take(numSentences).toArray - - var nextSendingTime = System.currentTimeMillis() - val sendAsArray = true - while(true) { - if (sendAsArray) { - println("Sending as array") - streamReceiver !? sentences - } else { - println("Sending individually") - sentences.foreach(sentence => { - streamReceiver !? sentence - }) - } - println ("Sent " + numSentences + " sentences in " + (System.currentTimeMillis - nextSendingTime) + " ms") - nextSendingTime += 1000 - val sleepTime = nextSendingTime - System.currentTimeMillis - if (sleepTime > 0) { - println ("Sleeping for " + sleepTime + " ms") - Thread.sleep(sleepTime) - } - } - } catch { - case e: Exception => - } - } - - def main(args: Array[String]) { - if (args.length < 3) { - printUsage - } - - val generateRandomly = false - - val streamReceiverIP = args(0) - val streamReceiverPort = args(1).toInt - val sentenceFile = args(2) - val sentencesPerSecond = if (args.length > 3) args(3).toInt else 10 - val sentenceInputName = if (args.length > 4) args(4) else "Sentences" - - println("Sending " + sentencesPerSecond + " sentences per second to " + - streamReceiverIP + ":" + streamReceiverPort + "/NetworkStreamReceiver-" + sentenceInputName) - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n") - source.close () - - val streamReceiver = select( - Node(streamReceiverIP, streamReceiverPort), - Symbol("NetworkStreamReceiver-" + sentenceInputName)) - if (generateRandomly) { - /*generateRandomSentences(lines, sentencesPerSecond, streamReceiver)*/ - } else { - generateSameSentences(lines, sentencesPerSecond, streamReceiver) - } - } -} - - - diff --git a/streaming/src/main/scala/spark/stream/TestGenerator2.scala b/streaming/src/main/scala/spark/stream/TestGenerator2.scala deleted file mode 100644 index ceb4730e72..0000000000 --- a/streaming/src/main/scala/spark/stream/TestGenerator2.scala +++ /dev/null @@ -1,119 +0,0 @@ -package spark.stream - -import scala.util.Random -import scala.io.Source -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import java.io.{DataOutputStream, ByteArrayOutputStream, DataInputStream} -import java.net.Socket - -object TestGenerator2 { - - def printUsage { - println ("Usage: SentenceGenerator []") - System.exit(0) - } - - def sendSentences(streamReceiverHost: String, streamReceiverPort: Int, numSentences: Int, bytes: Array[Byte], intervalTime: Long){ - try { - println("Connecting to " + streamReceiverHost + ":" + streamReceiverPort) - val socket = new Socket(streamReceiverHost, streamReceiverPort) - - println("Sending " + numSentences+ " sentences / " + (bytes.length / 1024.0 / 1024.0) + " MB per " + intervalTime + " ms to " + streamReceiverHost + ":" + streamReceiverPort ) - val currentTime = System.currentTimeMillis - var targetTime = (currentTime / intervalTime + 1).toLong * intervalTime - Thread.sleep(targetTime - currentTime) - - while(true) { - val startTime = System.currentTimeMillis() - println("Sending at " + startTime + " ms with delay of " + (startTime - targetTime) + " ms") - val socketOutputStream = socket.getOutputStream - val parts = 10 - (0 until parts).foreach(i => { - val partStartTime = System.currentTimeMillis - - val offset = (i * bytes.length / parts).toInt - val len = math.min(((i + 1) * bytes.length / parts).toInt - offset, bytes.length) - socketOutputStream.write(bytes, offset, len) - socketOutputStream.flush() - val partFinishTime = System.currentTimeMillis - println("Sending part " + i + " of " + len + " bytes took " + (partFinishTime - partStartTime) + " ms") - val sleepTime = math.max(0, 1000 / parts - (partFinishTime - partStartTime) - 1) - Thread.sleep(sleepTime) - }) - - socketOutputStream.flush() - /*val socketInputStream = new DataInputStream(socket.getInputStream)*/ - /*val reply = socketInputStream.readUTF()*/ - val finishTime = System.currentTimeMillis() - println ("Sent " + bytes.length + " bytes in " + (finishTime - startTime) + " ms for interval [" + targetTime + ", " + (targetTime + intervalTime) + "]") - /*println("Received = " + reply)*/ - targetTime = targetTime + intervalTime - val sleepTime = (targetTime - finishTime) + 10 - if (sleepTime > 0) { - println("Sleeping for " + sleepTime + " ms") - Thread.sleep(sleepTime) - } else { - println("############################") - println("###### Skipping sleep ######") - println("############################") - } - } - } catch { - case e: Exception => println(e) - } - println("Stopped sending") - } - - def main(args: Array[String]) { - if (args.length < 4) { - printUsage - } - - val streamReceiverHost = args(0) - val streamReceiverPort = args(1).toInt - val sentenceFile = args(2) - val intervalTime = args(3).toLong - val sentencesPerInterval = if (args.length > 4) args(4).toInt else 0 - - println("Reading the file " + sentenceFile) - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n") - source.close() - - val numSentences = if (sentencesPerInterval <= 0) { - lines.length - } else { - sentencesPerInterval - } - - println("Generating sentences") - val sentences: Array[String] = if (numSentences <= lines.length) { - lines.take(numSentences).toArray - } else { - (0 until numSentences).map(i => lines(i % lines.length)).toArray - } - - println("Converting to byte array") - val byteStream = new ByteArrayOutputStream() - val stringDataStream = new DataOutputStream(byteStream) - /*stringDataStream.writeInt(sentences.size)*/ - sentences.foreach(stringDataStream.writeUTF) - val bytes = byteStream.toByteArray() - stringDataStream.close() - println("Generated array of " + bytes.length + " bytes") - - /*while(true) { */ - sendSentences(streamReceiverHost, streamReceiverPort, numSentences, bytes, intervalTime) - /*println("Sleeping for 5 seconds")*/ - /*Thread.sleep(5000)*/ - /*System.gc()*/ - /*}*/ - } -} - - - diff --git a/streaming/src/main/scala/spark/stream/TestGenerator4.scala b/streaming/src/main/scala/spark/stream/TestGenerator4.scala deleted file mode 100644 index edeb969d7c..0000000000 --- a/streaming/src/main/scala/spark/stream/TestGenerator4.scala +++ /dev/null @@ -1,244 +0,0 @@ -package spark.stream - -import spark.Logging - -import scala.util.Random -import scala.io.Source -import scala.collection.mutable.{ArrayBuffer, Queue} - -import java.net._ -import java.io._ -import java.nio._ -import java.nio.charset._ -import java.nio.channels._ - -import it.unimi.dsi.fastutil.io._ - -class TestGenerator4(targetHost: String, targetPort: Int, sentenceFile: String, intervalDuration: Long, sentencesPerInterval: Int) -extends Logging { - - class SendingConnectionHandler(host: String, port: Int, generator: TestGenerator4) - extends ConnectionHandler(host, port, true) { - - val buffers = new ArrayBuffer[ByteBuffer] - val newBuffers = new Queue[ByteBuffer] - var activeKey: SelectionKey = null - - def send(buffer: ByteBuffer) { - logDebug("Sending: " + buffer) - newBuffers.synchronized { - newBuffers.enqueue(buffer) - } - selector.wakeup() - buffer.synchronized { - buffer.wait() - } - } - - override def ready(key: SelectionKey) { - logDebug("Ready") - activeKey = key - val channel = key.channel.asInstanceOf[SocketChannel] - channel.register(selector, SelectionKey.OP_WRITE) - generator.startSending() - } - - override def preSelect() { - newBuffers.synchronized { - while(!newBuffers.isEmpty) { - val buffer = newBuffers.dequeue - buffers += buffer - logDebug("Added: " + buffer) - changeInterest(activeKey, SelectionKey.OP_WRITE) - } - } - } - - override def write(key: SelectionKey) { - try { - /*while(true) {*/ - val channel = key.channel.asInstanceOf[SocketChannel] - if (buffers.size > 0) { - val buffer = buffers(0) - val newBuffer = buffer.slice() - newBuffer.limit(math.min(newBuffer.remaining, 32768)) - val bytesWritten = channel.write(newBuffer) - buffer.position(buffer.position + bytesWritten) - if (bytesWritten == 0) return - if (buffer.remaining == 0) { - buffers -= buffer - buffer.synchronized { - buffer.notify() - } - } - /*changeInterest(key, SelectionKey.OP_WRITE)*/ - } else { - changeInterest(key, 0) - } - /*}*/ - } catch { - case e: IOException => { - if (e.toString.contains("pipe") || e.toString.contains("reset")) { - logError("Connection broken") - } else { - logError("Connection error", e) - } - close(key) - } - } - } - - override def close(key: SelectionKey) { - buffers.clear() - super.close(key) - } - } - - initLogging() - - val connectionHandler = new SendingConnectionHandler(targetHost, targetPort, this) - var sendingThread: Thread = null - var sendCount = 0 - val sendBatches = 5 - - def run() { - logInfo("Connection handler started") - connectionHandler.start() - connectionHandler.join() - if (sendingThread != null && !sendingThread.isInterrupted) { - sendingThread.interrupt - } - logInfo("Connection handler stopped") - } - - def startSending() { - sendingThread = new Thread() { - override def run() { - logInfo("STARTING TO SEND") - sendSentences() - logInfo("SENDING STOPPED AFTER " + sendCount) - connectionHandler.interrupt() - } - } - sendingThread.start() - } - - def stopSending() { - sendingThread.interrupt() - } - - def sendSentences() { - logInfo("Reading the file " + sentenceFile) - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n") - source.close() - - val numSentences = if (sentencesPerInterval <= 0) { - lines.length - } else { - sentencesPerInterval - } - - logInfo("Generating sentence buffer") - val sentences: Array[String] = if (numSentences <= lines.length) { - lines.take(numSentences).toArray - } else { - (0 until numSentences).map(i => lines(i % lines.length)).toArray - } - - /* - val sentences: Array[String] = if (numSentences <= lines.length) { - lines.take((numSentences / sendBatches).toInt).toArray - } else { - (0 until (numSentences/sendBatches)).map(i => lines(i % lines.length)).toArray - }*/ - - - val serializer = new spark.KryoSerializer().newInstance() - val byteStream = new FastByteArrayOutputStream(100 * 1024 * 1024) - serializer.serializeStream(byteStream).writeAll(sentences.toIterator.asInstanceOf[Iterator[Any]]).close() - byteStream.trim() - val sentenceBuffer = ByteBuffer.wrap(byteStream.array) - - logInfo("Sending " + numSentences+ " sentences / " + sentenceBuffer.limit + " bytes per " + intervalDuration + " ms to " + targetHost + ":" + targetPort ) - val currentTime = System.currentTimeMillis - var targetTime = (currentTime / intervalDuration + 1).toLong * intervalDuration - Thread.sleep(targetTime - currentTime) - - val totalBytes = sentenceBuffer.limit - - while(true) { - val batchesInCurrentInterval = sendBatches // if (sendCount < 10) 1 else sendBatches - - val startTime = System.currentTimeMillis() - logDebug("Sending # " + sendCount + " at " + startTime + " ms with delay of " + (startTime - targetTime) + " ms") - - (0 until batchesInCurrentInterval).foreach(i => { - try { - val position = (i * totalBytes / sendBatches).toInt - val limit = if (i == sendBatches - 1) { - totalBytes - } else { - ((i + 1) * totalBytes / sendBatches).toInt - 1 - } - - val partStartTime = System.currentTimeMillis - sentenceBuffer.limit(limit) - connectionHandler.send(sentenceBuffer) - val partFinishTime = System.currentTimeMillis - val sleepTime = math.max(0, intervalDuration / sendBatches - (partFinishTime - partStartTime) - 1) - Thread.sleep(sleepTime) - - } catch { - case ie: InterruptedException => return - case e: Exception => e.printStackTrace() - } - }) - sentenceBuffer.rewind() - - val finishTime = System.currentTimeMillis() - /*logInfo ("Sent " + sentenceBuffer.limit + " bytes in " + (finishTime - startTime) + " ms")*/ - targetTime = targetTime + intervalDuration //+ (if (sendCount < 3) 1000 else 0) - - val sleepTime = (targetTime - finishTime) + 20 - if (sleepTime > 0) { - logInfo("Sleeping for " + sleepTime + " ms") - Thread.sleep(sleepTime) - } else { - logInfo("###### Skipping sleep ######") - } - if (Thread.currentThread.isInterrupted) { - return - } - sendCount += 1 - } - } -} - -object TestGenerator4 { - def printUsage { - println("Usage: TestGenerator4 []") - System.exit(0) - } - - def main(args: Array[String]) { - println("GENERATOR STARTED") - if (args.length < 4) { - printUsage - } - - - val streamReceiverHost = args(0) - val streamReceiverPort = args(1).toInt - val sentenceFile = args(2) - val intervalDuration = args(3).toLong - val sentencesPerInterval = if (args.length > 4) args(4).toInt else 0 - - while(true) { - val generator = new TestGenerator4(streamReceiverHost, streamReceiverPort, sentenceFile, intervalDuration, sentencesPerInterval) - generator.run() - Thread.sleep(2000) - } - println("GENERATOR STOPPED") - } -} diff --git a/streaming/src/main/scala/spark/stream/TestInputBlockTracker.scala b/streaming/src/main/scala/spark/stream/TestInputBlockTracker.scala deleted file mode 100644 index da3b964407..0000000000 --- a/streaming/src/main/scala/spark/stream/TestInputBlockTracker.scala +++ /dev/null @@ -1,42 +0,0 @@ -package spark.stream -import spark.Logging -import scala.collection.mutable.{ArrayBuffer, HashMap} - -object TestInputBlockTracker extends Logging { - initLogging() - val allBlockIds = new HashMap[Time, ArrayBuffer[String]]() - - def addBlocks(intervalEndTime: Time, reference: AnyRef) { - allBlockIds.getOrElseUpdate(intervalEndTime, new ArrayBuffer[String]()) ++= reference.asInstanceOf[Array[String]] - } - - def setEndTime(intervalEndTime: Time) { - try { - val endTime = System.currentTimeMillis - allBlockIds.get(intervalEndTime) match { - case Some(blockIds) => { - val numBlocks = blockIds.size - var totalDelay = 0d - blockIds.foreach(blockId => { - val inputTime = getInputTime(blockId) - val delay = (endTime - inputTime) / 1000.0 - totalDelay += delay - logInfo("End-to-end delay for block " + blockId + " is " + delay + " s") - }) - logInfo("Average end-to-end delay for time " + intervalEndTime + " is " + (totalDelay / numBlocks) + " s") - allBlockIds -= intervalEndTime - } - case None => throw new Exception("Unexpected") - } - } catch { - case e: Exception => logError(e.toString) - } - } - - def getInputTime(blockId: String): Long = { - val parts = blockId.split("-") - /*logInfo(blockId + " -> " + parts(4)) */ - parts(4).toLong - } -} - diff --git a/streaming/src/main/scala/spark/stream/TestStreamCoordinator.scala b/streaming/src/main/scala/spark/stream/TestStreamCoordinator.scala deleted file mode 100644 index add166fbd9..0000000000 --- a/streaming/src/main/scala/spark/stream/TestStreamCoordinator.scala +++ /dev/null @@ -1,38 +0,0 @@ -package spark.stream - -import spark.Logging - -import akka.actor._ -import akka.actor.Actor -import akka.actor.Actor._ - -sealed trait TestStreamCoordinatorMessage -case class GetStreamDetails extends TestStreamCoordinatorMessage -case class GotStreamDetails(name: String, duration: Long) extends TestStreamCoordinatorMessage -case class TestStarted extends TestStreamCoordinatorMessage - -class TestStreamCoordinator(streamDetails: Array[(String, Long)]) extends Actor with Logging { - - var index = 0 - - initLogging() - - logInfo("Created") - - def receive = { - case TestStarted => { - sender ! "OK" - } - - case GetStreamDetails => { - val streamDetail = if (index >= streamDetails.length) null else streamDetails(index) - sender ! GotStreamDetails(streamDetail._1, streamDetail._2) - index += 1 - if (streamDetail != null) { - logInfo("Allocated " + streamDetail._1 + " (" + index + "/" + streamDetails.length + ")" ) - } - } - } - -} - diff --git a/streaming/src/main/scala/spark/stream/TestStreamReceiver3.scala b/streaming/src/main/scala/spark/stream/TestStreamReceiver3.scala deleted file mode 100644 index 9cc342040b..0000000000 --- a/streaming/src/main/scala/spark/stream/TestStreamReceiver3.scala +++ /dev/null @@ -1,420 +0,0 @@ -package spark.stream - -import spark._ -import spark.storage._ -import spark.util.AkkaUtils - -import scala.math._ -import scala.collection.mutable.{Queue, HashMap, ArrayBuffer, SynchronizedMap} - -import akka.actor._ -import akka.actor.Actor -import akka.dispatch._ -import akka.pattern.ask -import akka.util.duration._ - -import java.io.DataInputStream -import java.io.BufferedInputStream -import java.net.Socket -import java.net.ServerSocket -import java.util.LinkedHashMap - -import org.apache.hadoop.fs._ -import org.apache.hadoop.conf._ -import org.apache.hadoop.io._ -import org.apache.hadoop.mapred._ -import org.apache.hadoop.util._ - -import spark.Utils - - -class TestStreamReceiver3(actorSystem: ActorSystem, blockManager: BlockManager) -extends Thread with Logging { - - - class DataHandler( - inputName: String, - longIntervalDuration: LongTime, - shortIntervalDuration: LongTime, - blockManager: BlockManager - ) - extends Logging { - - class Block(var id: String, var shortInterval: Interval) { - val data = ArrayBuffer[String]() - var pushed = false - def longInterval = getLongInterval(shortInterval) - def empty() = (data.size == 0) - def += (str: String) = (data += str) - override def toString() = "Block " + id - } - - class Bucket(val longInterval: Interval) { - val blocks = new ArrayBuffer[Block]() - var filled = false - def += (block: Block) = blocks += block - def empty() = (blocks.size == 0) - def ready() = (filled && !blocks.exists(! _.pushed)) - def blockIds() = blocks.map(_.id).toArray - override def toString() = "Bucket [" + longInterval + ", " + blocks.size + " blocks]" - } - - initLogging() - - val shortIntervalDurationMillis = shortIntervalDuration.asInstanceOf[LongTime].milliseconds - val longIntervalDurationMillis = longIntervalDuration.asInstanceOf[LongTime].milliseconds - - var currentBlock: Block = null - var currentBucket: Bucket = null - - val blocksForPushing = new Queue[Block]() - val buckets = new HashMap[Interval, Bucket]() with SynchronizedMap[Interval, Bucket] - - val blockUpdatingThread = new Thread() { override def run() { keepUpdatingCurrentBlock() } } - val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } - - def start() { - blockUpdatingThread.start() - blockPushingThread.start() - } - - def += (data: String) = addData(data) - - def addData(data: String) { - if (currentBlock == null) { - updateCurrentBlock() - } - currentBlock.synchronized { - currentBlock += data - } - } - - def getShortInterval(time: Time): Interval = { - val intervalBegin = time.floor(shortIntervalDuration) - Interval(intervalBegin, intervalBegin + shortIntervalDuration) - } - - def getLongInterval(shortInterval: Interval): Interval = { - val intervalBegin = shortInterval.beginTime.floor(longIntervalDuration) - Interval(intervalBegin, intervalBegin + longIntervalDuration) - } - - def updateCurrentBlock() { - /*logInfo("Updating current block")*/ - val currentTime: LongTime = LongTime(System.currentTimeMillis) - val shortInterval = getShortInterval(currentTime) - val longInterval = getLongInterval(shortInterval) - - def createBlock(reuseCurrentBlock: Boolean = false) { - val newBlockId = inputName + "-" + longInterval.toFormattedString + "-" + currentBucket.blocks.size - if (!reuseCurrentBlock) { - val newBlock = new Block(newBlockId, shortInterval) - /*logInfo("Created " + currentBlock)*/ - currentBlock = newBlock - } else { - currentBlock.shortInterval = shortInterval - currentBlock.id = newBlockId - } - } - - def createBucket() { - val newBucket = new Bucket(longInterval) - buckets += ((longInterval, newBucket)) - currentBucket = newBucket - /*logInfo("Created " + currentBucket + ", " + buckets.size + " buckets")*/ - } - - if (currentBlock == null || currentBucket == null) { - createBucket() - currentBucket.synchronized { - createBlock() - } - return - } - - currentBlock.synchronized { - var reuseCurrentBlock = false - - if (shortInterval != currentBlock.shortInterval) { - if (!currentBlock.empty) { - blocksForPushing.synchronized { - blocksForPushing += currentBlock - blocksForPushing.notifyAll() - } - } - - currentBucket.synchronized { - if (currentBlock.empty) { - reuseCurrentBlock = true - } else { - currentBucket += currentBlock - } - - if (longInterval != currentBucket.longInterval) { - currentBucket.filled = true - if (currentBucket.ready) { - currentBucket.notifyAll() - } - createBucket() - } - } - - createBlock(reuseCurrentBlock) - } - } - } - - def pushBlock(block: Block) { - try{ - if (blockManager != null) { - logInfo("Pushing block") - val startTime = System.currentTimeMillis - - val bytes = blockManager.dataSerialize(block.data.toIterator) - val finishTime = System.currentTimeMillis - logInfo(block + " serialization delay is " + (finishTime - startTime) / 1000.0 + " s") - - blockManager.putBytes(block.id.toString, bytes, StorageLevel.DISK_AND_MEMORY_2) - /*blockManager.putBytes(block.id.toString, bytes, StorageLevel.DISK_AND_MEMORY_DESER_2)*/ - /*blockManager.put(block.id.toString, block.data.toIterator, StorageLevel.DISK_AND_MEMORY_DESER)*/ - /*blockManager.put(block.id.toString, block.data.toIterator, StorageLevel.DISK_AND_MEMORY)*/ - val finishTime1 = System.currentTimeMillis - logInfo(block + " put delay is " + (finishTime1 - startTime) / 1000.0 + " s") - } else { - logWarning(block + " not put as block manager is null") - } - } catch { - case e: Exception => logError("Exception writing " + block + " to blockmanager" , e) - } - } - - def getBucket(longInterval: Interval): Option[Bucket] = { - buckets.get(longInterval) - } - - def clearBucket(longInterval: Interval) { - buckets.remove(longInterval) - } - - def keepUpdatingCurrentBlock() { - logInfo("Thread to update current block started") - while(true) { - updateCurrentBlock() - val currentTimeMillis = System.currentTimeMillis - val sleepTimeMillis = (currentTimeMillis / shortIntervalDurationMillis + 1) * - shortIntervalDurationMillis - currentTimeMillis + 1 - Thread.sleep(sleepTimeMillis) - } - } - - def keepPushingBlocks() { - var loop = true - logInfo("Thread to push blocks started") - while(loop) { - val block = blocksForPushing.synchronized { - if (blocksForPushing.size == 0) { - blocksForPushing.wait() - } - blocksForPushing.dequeue - } - pushBlock(block) - block.pushed = true - block.data.clear() - - val bucket = buckets(block.longInterval) - bucket.synchronized { - if (bucket.ready) { - bucket.notifyAll() - } - } - } - } - } - - - class ConnectionListener(port: Int, dataHandler: DataHandler) - extends Thread with Logging { - initLogging() - override def run { - try { - val listener = new ServerSocket(port) - logInfo("Listening on port " + port) - while (true) { - new ConnectionHandler(listener.accept(), dataHandler).start(); - } - listener.close() - } catch { - case e: Exception => logError("", e); - } - } - } - - class ConnectionHandler(socket: Socket, dataHandler: DataHandler) extends Thread with Logging { - initLogging() - override def run { - logInfo("New connection from " + socket.getInetAddress() + ":" + socket.getPort) - val bytes = new Array[Byte](100 * 1024 * 1024) - try { - - val inputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream, 1024 * 1024)) - /*val inputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream))*/ - var str: String = null - str = inputStream.readUTF - while(str != null) { - dataHandler += str - str = inputStream.readUTF() - } - - /* - var loop = true - while(loop) { - val numRead = inputStream.read(bytes) - if (numRead < 0) { - loop = false - } - inbox += ((LongTime(SystemTime.currentTimeMillis), "test")) - }*/ - - inputStream.close() - } catch { - case e => logError("Error receiving data", e) - } - socket.close() - } - } - - initLogging() - - val masterHost = System.getProperty("spark.master.host") - val masterPort = System.getProperty("spark.master.port").toInt - - val akkaPath = "akka://spark@%s:%s/user/".format(masterHost, masterPort) - val sparkstreamScheduler = actorSystem.actorFor(akkaPath + "/SparkStreamScheduler") - val testStreamCoordinator = actorSystem.actorFor(akkaPath + "/TestStreamCoordinator") - - logInfo("Getting stream details from master " + masterHost + ":" + masterPort) - - val timeout = 50 millis - - var started = false - while (!started) { - askActor[String](testStreamCoordinator, TestStarted) match { - case Some(str) => { - started = true - logInfo("TestStreamCoordinator started") - } - case None => { - logInfo("TestStreamCoordinator not started yet") - Thread.sleep(200) - } - } - } - - val streamDetails = askActor[GotStreamDetails](testStreamCoordinator, GetStreamDetails) match { - case Some(details) => details - case None => throw new Exception("Could not get stream details") - } - logInfo("Stream details received: " + streamDetails) - - val inputName = streamDetails.name - val intervalDurationMillis = streamDetails.duration - val intervalDuration = LongTime(intervalDurationMillis) - - val dataHandler = new DataHandler( - inputName, - intervalDuration, - LongTime(TestStreamReceiver3.SHORT_INTERVAL_MILLIS), - blockManager) - - val connListener = new ConnectionListener(TestStreamReceiver3.PORT, dataHandler) - - // Send a message to an actor and return an option with its reply, or None if this times out - def askActor[T](actor: ActorRef, message: Any): Option[T] = { - try { - val future = actor.ask(message)(timeout) - return Some(Await.result(future, timeout).asInstanceOf[T]) - } catch { - case e: Exception => - logInfo("Error communicating with " + actor, e) - return None - } - } - - override def run() { - connListener.start() - dataHandler.start() - - var interval = Interval.currentInterval(intervalDuration) - var dataStarted = false - - while(true) { - waitFor(interval.endTime) - logInfo("Woken up at " + System.currentTimeMillis + " for " + interval) - dataHandler.getBucket(interval) match { - case Some(bucket) => { - logInfo("Found " + bucket + " for " + interval) - bucket.synchronized { - if (!bucket.ready) { - logInfo("Waiting for " + bucket) - bucket.wait() - logInfo("Wait over for " + bucket) - } - if (dataStarted || !bucket.empty) { - logInfo("Notifying " + bucket) - notifyScheduler(interval, bucket.blockIds) - dataStarted = true - } - bucket.blocks.clear() - dataHandler.clearBucket(interval) - } - } - case None => { - logInfo("Found none for " + interval) - if (dataStarted) { - logInfo("Notifying none") - notifyScheduler(interval, Array[String]()) - } - } - } - interval = interval.next - } - } - - def waitFor(time: Time) { - val currentTimeMillis = System.currentTimeMillis - val targetTimeMillis = time.asInstanceOf[LongTime].milliseconds - if (currentTimeMillis < targetTimeMillis) { - val sleepTime = (targetTimeMillis - currentTimeMillis) - Thread.sleep(sleepTime + 1) - } - } - - def notifyScheduler(interval: Interval, blockIds: Array[String]) { - try { - sparkstreamScheduler ! InputGenerated(inputName, interval, blockIds.toArray) - val time = interval.endTime.asInstanceOf[LongTime] - val delay = (System.currentTimeMillis - time.milliseconds) / 1000.0 - logInfo("Pushing delay for " + time + " is " + delay + " s") - } catch { - case _ => logError("Exception notifying scheduler at interval " + interval) - } - } -} - -object TestStreamReceiver3 { - - val PORT = 9999 - val SHORT_INTERVAL_MILLIS = 100 - - def main(args: Array[String]) { - System.setProperty("spark.master.host", Utils.localHostName) - System.setProperty("spark.master.port", "7078") - val details = Array(("Sentences", 2000L)) - val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localHostName, 7078) - actorSystem.actorOf(Props(new TestStreamCoordinator(details)), name = "TestStreamCoordinator") - new TestStreamReceiver3(actorSystem, null).start() - } -} - - - diff --git a/streaming/src/main/scala/spark/stream/TestStreamReceiver4.scala b/streaming/src/main/scala/spark/stream/TestStreamReceiver4.scala deleted file mode 100644 index e7bef75391..0000000000 --- a/streaming/src/main/scala/spark/stream/TestStreamReceiver4.scala +++ /dev/null @@ -1,373 +0,0 @@ -package spark.stream - -import spark._ -import spark.storage._ -import spark.util.AkkaUtils - -import scala.math._ -import scala.collection.mutable.{Queue, HashMap, ArrayBuffer, SynchronizedMap} - -import java.io._ -import java.nio._ -import java.nio.charset._ -import java.nio.channels._ -import java.util.concurrent.Executors - -import akka.actor._ -import akka.actor.Actor -import akka.dispatch._ -import akka.pattern.ask -import akka.util.duration._ - -class TestStreamReceiver4(actorSystem: ActorSystem, blockManager: BlockManager) -extends Thread with Logging { - - class DataHandler( - inputName: String, - longIntervalDuration: LongTime, - shortIntervalDuration: LongTime, - blockManager: BlockManager - ) - extends Logging { - - class Block(val id: String, val shortInterval: Interval, val buffer: ByteBuffer) { - var pushed = false - def longInterval = getLongInterval(shortInterval) - override def toString() = "Block " + id - } - - class Bucket(val longInterval: Interval) { - val blocks = new ArrayBuffer[Block]() - var filled = false - def += (block: Block) = blocks += block - def empty() = (blocks.size == 0) - def ready() = (filled && !blocks.exists(! _.pushed)) - def blockIds() = blocks.map(_.id).toArray - override def toString() = "Bucket [" + longInterval + ", " + blocks.size + " blocks]" - } - - initLogging() - - val syncOnLastShortInterval = true - - val shortIntervalDurationMillis = shortIntervalDuration.asInstanceOf[LongTime].milliseconds - val longIntervalDurationMillis = longIntervalDuration.asInstanceOf[LongTime].milliseconds - - val buffer = ByteBuffer.allocateDirect(100 * 1024 * 1024) - var currentShortInterval = Interval.currentInterval(shortIntervalDuration) - - val blocksForPushing = new Queue[Block]() - val buckets = new HashMap[Interval, Bucket]() with SynchronizedMap[Interval, Bucket] - - val bufferProcessingThread = new Thread() { override def run() { keepProcessingBuffers() } } - val blockPushingExecutor = Executors.newFixedThreadPool(5) - - - def start() { - buffer.clear() - if (buffer.remaining == 0) { - throw new Exception("Buffer initialization error") - } - bufferProcessingThread.start() - } - - def readDataToBuffer(func: ByteBuffer => Int): Int = { - buffer.synchronized { - if (buffer.remaining == 0) { - logInfo("Received first data for interval " + currentShortInterval) - } - func(buffer) - } - } - - def getLongInterval(shortInterval: Interval): Interval = { - val intervalBegin = shortInterval.beginTime.floor(longIntervalDuration) - Interval(intervalBegin, intervalBegin + longIntervalDuration) - } - - def processBuffer() { - - def readInt(buffer: ByteBuffer): Int = { - var offset = 0 - var result = 0 - while (offset < 32) { - val b = buffer.get() - result |= ((b & 0x7F) << offset) - if ((b & 0x80) == 0) { - return result - } - offset += 7 - } - throw new Exception("Malformed zigzag-encoded integer") - } - - val currentLongInterval = getLongInterval(currentShortInterval) - val startTime = System.currentTimeMillis - val newBuffer: ByteBuffer = buffer.synchronized { - buffer.flip() - if (buffer.remaining == 0) { - buffer.clear() - null - } else { - logDebug("Processing interval " + currentShortInterval + " with delay of " + (System.currentTimeMillis - startTime) + " ms") - val startTime1 = System.currentTimeMillis - var loop = true - var count = 0 - while(loop) { - buffer.mark() - try { - val len = readInt(buffer) - buffer.position(buffer.position + len) - count += 1 - } catch { - case e: Exception => { - buffer.reset() - loop = false - } - } - } - val bytesToCopy = buffer.position - val newBuf = ByteBuffer.allocate(bytesToCopy) - buffer.position(0) - newBuf.put(buffer.slice().limit(bytesToCopy).asInstanceOf[ByteBuffer]) - newBuf.flip() - buffer.position(bytesToCopy) - buffer.compact() - newBuf - } - } - - if (newBuffer != null) { - val bucket = buckets.getOrElseUpdate(currentLongInterval, new Bucket(currentLongInterval)) - bucket.synchronized { - val newBlockId = inputName + "-" + currentLongInterval.toFormattedString + "-" + currentShortInterval.toFormattedString - val newBlock = new Block(newBlockId, currentShortInterval, newBuffer) - if (syncOnLastShortInterval) { - bucket += newBlock - } - logDebug("Created " + newBlock + " with " + newBuffer.remaining + " bytes, creation delay is " + (System.currentTimeMillis - currentShortInterval.endTime.asInstanceOf[LongTime].milliseconds) / 1000.0 + " s" ) - blockPushingExecutor.execute(new Runnable() { def run() { pushAndNotifyBlock(newBlock) } }) - } - } - - val newShortInterval = Interval.currentInterval(shortIntervalDuration) - val newLongInterval = getLongInterval(newShortInterval) - - if (newLongInterval != currentLongInterval) { - buckets.get(currentLongInterval) match { - case Some(bucket) => { - bucket.synchronized { - bucket.filled = true - if (bucket.ready) { - bucket.notifyAll() - } - } - } - case None => - } - buckets += ((newLongInterval, new Bucket(newLongInterval))) - } - - currentShortInterval = newShortInterval - } - - def pushBlock(block: Block) { - try{ - if (blockManager != null) { - val startTime = System.currentTimeMillis - logInfo(block + " put start delay is " + (startTime - block.shortInterval.endTime.asInstanceOf[LongTime].milliseconds) + " ms") - /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY)*/ - /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_2)*/ - blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.MEMORY_ONLY_2) - /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.MEMORY_ONLY)*/ - /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_DESER)*/ - /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_DESER_2)*/ - val finishTime = System.currentTimeMillis - logInfo(block + " put delay is " + (finishTime - startTime) + " ms") - } else { - logWarning(block + " not put as block manager is null") - } - } catch { - case e: Exception => logError("Exception writing " + block + " to blockmanager" , e) - } - } - - def getBucket(longInterval: Interval): Option[Bucket] = { - buckets.get(longInterval) - } - - def clearBucket(longInterval: Interval) { - buckets.remove(longInterval) - } - - def keepProcessingBuffers() { - logInfo("Thread to process buffers started") - while(true) { - processBuffer() - val currentTimeMillis = System.currentTimeMillis - val sleepTimeMillis = (currentTimeMillis / shortIntervalDurationMillis + 1) * - shortIntervalDurationMillis - currentTimeMillis + 1 - Thread.sleep(sleepTimeMillis) - } - } - - def pushAndNotifyBlock(block: Block) { - pushBlock(block) - block.pushed = true - val bucket = if (syncOnLastShortInterval) { - buckets(block.longInterval) - } else { - var longInterval = block.longInterval - while(!buckets.contains(longInterval)) { - logWarning("Skipping bucket of " + longInterval + " for " + block) - longInterval = longInterval.next - } - val chosenBucket = buckets(longInterval) - logDebug("Choosing bucket of " + longInterval + " for " + block) - chosenBucket += block - chosenBucket - } - - bucket.synchronized { - if (bucket.ready) { - bucket.notifyAll() - } - } - - } - } - - - class ReceivingConnectionHandler(host: String, port: Int, dataHandler: DataHandler) - extends ConnectionHandler(host, port, false) { - - override def ready(key: SelectionKey) { - changeInterest(key, SelectionKey.OP_READ) - } - - override def read(key: SelectionKey) { - try { - val channel = key.channel.asInstanceOf[SocketChannel] - val bytesRead = dataHandler.readDataToBuffer(channel.read) - if (bytesRead < 0) { - close(key) - } - } catch { - case e: IOException => { - logError("Error reading", e) - close(key) - } - } - } - } - - initLogging() - - val masterHost = System.getProperty("spark.master.host", "localhost") - val masterPort = System.getProperty("spark.master.port", "7078").toInt - - val akkaPath = "akka://spark@%s:%s/user/".format(masterHost, masterPort) - val sparkstreamScheduler = actorSystem.actorFor(akkaPath + "/SparkStreamScheduler") - val testStreamCoordinator = actorSystem.actorFor(akkaPath + "/TestStreamCoordinator") - - logInfo("Getting stream details from master " + masterHost + ":" + masterPort) - - val streamDetails = askActor[GotStreamDetails](testStreamCoordinator, GetStreamDetails) match { - case Some(details) => details - case None => throw new Exception("Could not get stream details") - } - logInfo("Stream details received: " + streamDetails) - - val inputName = streamDetails.name - val intervalDurationMillis = streamDetails.duration - val intervalDuration = Milliseconds(intervalDurationMillis) - val shortIntervalDuration = Milliseconds(System.getProperty("spark.stream.shortinterval", "500").toInt) - - val dataHandler = new DataHandler(inputName, intervalDuration, shortIntervalDuration, blockManager) - val connectionHandler = new ReceivingConnectionHandler("localhost", 9999, dataHandler) - - val timeout = 100 millis - - // Send a message to an actor and return an option with its reply, or None if this times out - def askActor[T](actor: ActorRef, message: Any): Option[T] = { - try { - val future = actor.ask(message)(timeout) - return Some(Await.result(future, timeout).asInstanceOf[T]) - } catch { - case e: Exception => - logInfo("Error communicating with " + actor, e) - return None - } - } - - override def run() { - connectionHandler.start() - dataHandler.start() - - var interval = Interval.currentInterval(intervalDuration) - var dataStarted = false - - - while(true) { - waitFor(interval.endTime) - /*logInfo("Woken up at " + System.currentTimeMillis + " for " + interval)*/ - dataHandler.getBucket(interval) match { - case Some(bucket) => { - logDebug("Found " + bucket + " for " + interval) - bucket.synchronized { - if (!bucket.ready) { - logDebug("Waiting for " + bucket) - bucket.wait() - logDebug("Wait over for " + bucket) - } - if (dataStarted || !bucket.empty) { - logDebug("Notifying " + bucket) - notifyScheduler(interval, bucket.blockIds) - dataStarted = true - } - bucket.blocks.clear() - dataHandler.clearBucket(interval) - } - } - case None => { - logDebug("Found none for " + interval) - if (dataStarted) { - logDebug("Notifying none") - notifyScheduler(interval, Array[String]()) - } - } - } - interval = interval.next - } - } - - def waitFor(time: Time) { - val currentTimeMillis = System.currentTimeMillis - val targetTimeMillis = time.asInstanceOf[LongTime].milliseconds - if (currentTimeMillis < targetTimeMillis) { - val sleepTime = (targetTimeMillis - currentTimeMillis) - Thread.sleep(sleepTime + 1) - } - } - - def notifyScheduler(interval: Interval, blockIds: Array[String]) { - try { - sparkstreamScheduler ! InputGenerated(inputName, interval, blockIds.toArray) - val time = interval.endTime.asInstanceOf[LongTime] - val delay = (System.currentTimeMillis - time.milliseconds) - logInfo("Notification delay for " + time + " is " + delay + " ms") - } catch { - case e: Exception => logError("Exception notifying scheduler at interval " + interval + ": " + e) - } - } -} - - -object TestStreamReceiver4 { - def main(args: Array[String]) { - val details = Array(("Sentences", 2000L)) - val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localHostName, 7078) - actorSystem.actorOf(Props(new TestStreamCoordinator(details)), name = "TestStreamCoordinator") - new TestStreamReceiver4(actorSystem, null).start() - } -} diff --git a/streaming/src/main/scala/spark/stream/Time.scala b/streaming/src/main/scala/spark/stream/Time.scala deleted file mode 100644 index 25369dfee5..0000000000 --- a/streaming/src/main/scala/spark/stream/Time.scala +++ /dev/null @@ -1,85 +0,0 @@ -package spark.stream - -abstract case class Time { - - // basic operations that must be overridden - def copy(): Time - def zero: Time - def < (that: Time): Boolean - def += (that: Time): Time - def -= (that: Time): Time - def floor(that: Time): Time - def isMultipleOf(that: Time): Boolean - - // derived operations composed of basic operations - def + (that: Time) = this.copy() += that - def - (that: Time) = this.copy() -= that - def * (times: Int) = { - var count = 0 - var result = this.copy() - while (count < times) { - result += this - count += 1 - } - result - } - def <= (that: Time) = (this < that || this == that) - def > (that: Time) = !(this <= that) - def >= (that: Time) = !(this < that) - def isZero = (this == zero) - def toFormattedString = toString -} - -object Time { - def Milliseconds(milliseconds: Long) = LongTime(milliseconds) - - def zero = LongTime(0) -} - -case class LongTime(var milliseconds: Long) extends Time { - - override def copy() = LongTime(this.milliseconds) - - override def zero = LongTime(0) - - override def < (that: Time): Boolean = - (this.milliseconds < that.asInstanceOf[LongTime].milliseconds) - - override def += (that: Time): Time = { - this.milliseconds += that.asInstanceOf[LongTime].milliseconds - this - } - - override def -= (that: Time): Time = { - this.milliseconds -= that.asInstanceOf[LongTime].milliseconds - this - } - - override def floor(that: Time): Time = { - val t = that.asInstanceOf[LongTime].milliseconds - val m = this.milliseconds / t - LongTime(m.toLong * t) - } - - override def isMultipleOf(that: Time): Boolean = - (this.milliseconds % that.asInstanceOf[LongTime].milliseconds == 0) - - override def isZero = (this.milliseconds == 0) - - override def toString = (milliseconds.toString + "ms") - - override def toFormattedString = milliseconds.toString -} - -object Milliseconds { - def apply(milliseconds: Long) = LongTime(milliseconds) -} - -object Seconds { - def apply(seconds: Long) = LongTime(seconds * 1000) -} - -object Minutes { - def apply(minutes: Long) = LongTime(minutes * 60000) -} - diff --git a/streaming/src/main/scala/spark/stream/TopContentCount.scala b/streaming/src/main/scala/spark/stream/TopContentCount.scala deleted file mode 100644 index a8cca4e793..0000000000 --- a/streaming/src/main/scala/spark/stream/TopContentCount.scala +++ /dev/null @@ -1,97 +0,0 @@ -package spark.stream - -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting - -object TopContentCount { - - case class Event(val country: String, val content: String) - - object Event { - def create(string: String): Event = { - val parts = string.split(":") - new Event(parts(0), parts(1)) - } - } - - def main(args: Array[String]) { - - if (args.length < 2) { - println ("Usage: GrepCount2 <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "TopContentCount") - val sc = ssc.sc - val dummy = sc.parallelize(0 to 1000, 100).persist(StorageLevel.DISK_AND_MEMORY) - sc.runJob(dummy, (_: Iterator[Int]) => {}) - - - val numEventStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - val eventStrings = new UnifiedRDS( - (1 to numEventStreams).map(i => ssc.readTestStream("Events-" + i, 1000)).toArray - ) - - def parse(string: String) = { - val parts = string.split(":") - (parts(0), parts(1)) - } - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val events = eventStrings.map(x => parse(x)) - /*events.print*/ - - val parallelism = 8 - val counts_per_content_per_country = events - .map(x => (x, 1)) - .reduceByKey(_ + _) - /*.reduceByKeyAndWindow(add _, subtract _, Seconds(5), Seconds(1), parallelism)*/ - /*counts_per_content_per_country.print*/ - - /* - counts_per_content_per_country.persist( - StorageLevel.MEMORY_ONLY_DESER, - StorageLevel.MEMORY_ONLY_DESER_2, - Seconds(1) - )*/ - - val counts_per_country = counts_per_content_per_country - .map(x => (x._1._1, (x._1._2, x._2))) - .groupByKey() - counts_per_country.print - - - def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { - implicit val countOrdering = new Ordering[(String, Int)] { - override def compare(count1: (String, Int), count2: (String, Int)): Int = { - count2._2 - count1._2 - } - } - val array = data.toArray - Sorting.quickSort(array) - val taken = array.take(k) - taken - } - - val k = 10 - val topKContents_per_country = counts_per_country - .map(x => (x._1, topK(x._2, k))) - .map(x => (x._1, x._2.map(_.toString).reduceLeft(_ + ", " + _))) - - topKContents_per_country.print - - ssc.run - } -} - - - diff --git a/streaming/src/main/scala/spark/stream/TopKWordCount2.scala b/streaming/src/main/scala/spark/stream/TopKWordCount2.scala deleted file mode 100644 index 7dd06dd5ee..0000000000 --- a/streaming/src/main/scala/spark/stream/TopKWordCount2.scala +++ /dev/null @@ -1,103 +0,0 @@ -package spark.stream - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting - -object TopKWordCount2 { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 10) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) - - def topK(data: Iterator[(String, Int)], k: Int): Iterator[(String, Int)] = { - val taken = new Array[(String, Int)](k) - - var i = 0 - var len = 0 - var done = false - var value: (String, Int) = null - var swap: (String, Int) = null - var count = 0 - - while(data.hasNext) { - value = data.next - count += 1 - println("count = " + count) - if (len == 0) { - taken(0) = value - len = 1 - } else if (len < k || value._2 > taken(len - 1)._2) { - if (len < k) { - len += 1 - } - taken(len - 1) = value - i = len - 1 - while(i > 0 && taken(i - 1)._2 < taken(i)._2) { - swap = taken(i) - taken(i) = taken(i-1) - taken(i - 1) = swap - i -= 1 - } - } - } - println("Took " + len + " out of " + count + " items") - return taken.toIterator - } - - val k = 10 - val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) - partialTopKWindowedCounts.foreachRDD(rdd => { - val collectedCounts = rdd.collect - println("Collected " + collectedCounts.size + " items") - topK(collectedCounts.toIterator, k).foreach(println) - }) - - /* - windowedCounts.filter(_ == null).foreachRDD(rdd => { - val count = rdd.count - println("# of nulls = " + count) - })*/ - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/stream/TopKWordCount2_Special.scala b/streaming/src/main/scala/spark/stream/TopKWordCount2_Special.scala deleted file mode 100644 index e9f3f914ae..0000000000 --- a/streaming/src/main/scala/spark/stream/TopKWordCount2_Special.scala +++ /dev/null @@ -1,142 +0,0 @@ -package spark.stream - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.JavaConversions.mapAsScalaMap -import scala.collection.mutable.Queue - -import java.lang.{Long => JLong} - -object TopKWordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "TopKWordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - /*moreWarmup(ssc.sc)*/ - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 500)).toArray - ) - - /*val words = sentences.flatMap(_.split(" "))*/ - - /*def add(v1: Int, v2: Int) = (v1 + v2) */ - /*def subtract(v1: Int, v2: Int) = (v1 - v2) */ - - def add(v1: JLong, v2: JLong) = (v1 + v2) - def subtract(v1: JLong, v2: JLong) = (v1 - v2) - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - map.toIterator - } - - - val windowedCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Milliseconds(500), 10) - /*windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1))*/ - windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY, Milliseconds(500)) - - def topK(data: Iterator[(String, JLong)], k: Int): Iterator[(String, JLong)] = { - val taken = new Array[(String, JLong)](k) - - var i = 0 - var len = 0 - var done = false - var value: (String, JLong) = null - var swap: (String, JLong) = null - var count = 0 - - while(data.hasNext) { - value = data.next - count += 1 - println("count = " + count) - if (len == 0) { - taken(0) = value - len = 1 - } else if (len < k || value._2 > taken(len - 1)._2) { - if (len < k) { - len += 1 - } - taken(len - 1) = value - i = len - 1 - while(i > 0 && taken(i - 1)._2 < taken(i)._2) { - swap = taken(i) - taken(i) = taken(i-1) - taken(i - 1) = swap - i -= 1 - } - } - } - println("Took " + len + " out of " + count + " items") - return taken.toIterator - } - - val k = 50 - val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) - partialTopKWindowedCounts.foreachRDD(rdd => { - val collectedCounts = rdd.collect - println("Collected " + collectedCounts.size + " items") - topK(collectedCounts.toIterator, k).foreach(println) - }) - - /* - windowedCounts.filter(_ == null).foreachRDD(rdd => { - val count = rdd.count - println("# of nulls = " + count) - })*/ - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/stream/WindowedRDS.scala b/streaming/src/main/scala/spark/stream/WindowedRDS.scala deleted file mode 100644 index a2e7966edb..0000000000 --- a/streaming/src/main/scala/spark/stream/WindowedRDS.scala +++ /dev/null @@ -1,68 +0,0 @@ -package spark.stream - -import spark.stream.SparkStreamContext._ - -import spark.RDD -import spark.UnionRDD -import spark.SparkContext._ - -import scala.collection.mutable.ArrayBuffer - -class WindowedRDS[T: ClassManifest]( - parent: RDS[T], - _windowTime: Time, - _slideTime: Time) - extends RDS[T](parent.ssc) { - - if (!_windowTime.isMultipleOf(parent.slideTime)) - throw new Exception("The window duration of WindowedRDS (" + _slideTime + ") " + - "must be multiple of the slide duration of parent RDS (" + parent.slideTime + ")") - - if (!_slideTime.isMultipleOf(parent.slideTime)) - throw new Exception("The slide duration of WindowedRDS (" + _slideTime + ") " + - "must be multiple of the slide duration of parent RDS (" + parent.slideTime + ")") - - val allowPartialWindows = true - - override def dependencies = List(parent) - - def windowTime: Time = _windowTime - - override def slideTime: Time = _slideTime - - override def compute(validTime: Time): Option[RDD[T]] = { - val parentRDDs = new ArrayBuffer[RDD[T]]() - val windowEndTime = validTime.copy() - val windowStartTime = if (allowPartialWindows && windowEndTime - windowTime < parent.zeroTime) { - parent.zeroTime - } else { - windowEndTime - windowTime - } - - logInfo("Window = " + windowStartTime + " - " + windowEndTime) - logInfo("Parent.zeroTime = " + parent.zeroTime) - - if (windowStartTime >= parent.zeroTime) { - // Walk back through time, from the 'windowEndTime' to 'windowStartTime' - // and get all parent RDDs from the parent RDS - var t = windowEndTime - while (t > windowStartTime) { - parent.getOrCompute(t) match { - case Some(rdd) => parentRDDs += rdd - case None => throw new Exception("Could not generate parent RDD for time " + t) - } - t -= parent.slideTime - } - } - - // Do a union of all parent RDDs to generate the new RDD - if (parentRDDs.size > 0) { - Some(new UnionRDD(ssc.sc, parentRDDs)) - } else { - None - } - } -} - - - diff --git a/streaming/src/main/scala/spark/stream/WordCount.scala b/streaming/src/main/scala/spark/stream/WordCount.scala deleted file mode 100644 index af825e46a8..0000000000 --- a/streaming/src/main/scala/spark/stream/WordCount.scala +++ /dev/null @@ -1,62 +0,0 @@ -package spark.stream - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object WordCount { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: WordCount ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "WordCountWindow") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) - //sentences.print - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(2000), - // System.getProperty("spark.default.parallelism", "1").toInt) - //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.DISK_AND_MEMORY_DESER_2, Seconds(5)) - //windowedCounts.print - - val parallelism = System.getProperty("spark.default.parallelism", "1").toInt - - //val localCounts = words.map(x => (x, 1)).reduceByKey(add _, parallelism) - //localCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(6)) - //val windowedCounts = localCounts.window(Seconds(30), Seconds(2)).reduceByKey(_ + _) - - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(2), - parallelism) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(10)) - - //windowedCounts.print - windowedCounts.register - //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => print(x+ " "))) - //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => x)) - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/stream/WordCount1.scala b/streaming/src/main/scala/spark/stream/WordCount1.scala deleted file mode 100644 index 501062b18d..0000000000 --- a/streaming/src/main/scala/spark/stream/WordCount1.scala +++ /dev/null @@ -1,46 +0,0 @@ -package spark.stream - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object WordCount1 { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: WordCount ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "WordCountWindow") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - //sentences.print - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 10) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) - windowedCounts.foreachRDD(_.collect) - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/stream/WordCount2.scala b/streaming/src/main/scala/spark/stream/WordCount2.scala deleted file mode 100644 index 24324e891a..0000000000 --- a/streaming/src/main/scala/spark/stream/WordCount2.scala +++ /dev/null @@ -1,55 +0,0 @@ -package spark.stream - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting - -object WordCount2 { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - /*moreWarmup(ssc.sc)*/ - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 6) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) - windowedCounts.foreachRDD(_.collect) - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/stream/WordCount2_Special.scala b/streaming/src/main/scala/spark/stream/WordCount2_Special.scala deleted file mode 100644 index c6b1aaa57e..0000000000 --- a/streaming/src/main/scala/spark/stream/WordCount2_Special.scala +++ /dev/null @@ -1,94 +0,0 @@ -package spark.stream - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.Queue -import scala.collection.JavaConversions.mapAsScalaMap - -import java.lang.{Long => JLong} -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} - - -object WordCount2_ExtraFunctions { - - def add(v1: JLong, v2: JLong) = (v1 + v2) - - def subtract(v1: JLong, v2: JLong) = (v1 - v2) - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - map.toIterator - } -} - -object WordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 40).foreach {i => - sc.parallelize(1 to 20000000, 1000) - .map(_ % 1331).map(_.toString) - .mapPartitions(WordCount2_ExtraFunctions.splitAndCountPartitions).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - - GrepCount2.warmConnectionManagers(ssc.sc) - /*moreWarmup(ssc.sc)*/ - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 500)).toArray - ) - - val windowedCounts = sentences - .mapPartitions(WordCount2_ExtraFunctions.splitAndCountPartitions) - .reduceByKeyAndWindow(WordCount2_ExtraFunctions.add _, WordCount2_ExtraFunctions.subtract _, Seconds(10), Milliseconds(500), 10) - windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY, Milliseconds(500)) - windowedCounts.foreachRDD(_.collect) - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/stream/WordCount3.scala b/streaming/src/main/scala/spark/stream/WordCount3.scala deleted file mode 100644 index 455a8c9dbf..0000000000 --- a/streaming/src/main/scala/spark/stream/WordCount3.scala +++ /dev/null @@ -1,49 +0,0 @@ -package spark.stream - -import SparkStreamContext._ - -import scala.util.Sorting - -object WordCount3 { - - def main (args: Array[String]) { - - if (args.length < 1) { - println ("Usage: SparkStreamContext []") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount") - if (args.length > 1) { - ssc.setTempDir(args(1)) - } - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - /*sentences.print*/ - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - /*val windowedCounts = words.map(x => (x, 1)).window(Seconds(5), Seconds(1)).reduceByKey(add _, 1)*/ - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(5), Seconds(1), 1) - /*windowedCounts.print */ - - def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { - implicit val countOrdering = new Ordering[(String, Int)] { - override def compare(count1: (String, Int), count2: (String, Int)): Int = { - count2._2 - count1._2 - } - } - val array = data.toArray - Sorting.quickSort(array) - array.take(k) - } - - val k = 10 - val topKWindowedCounts = windowedCounts.glom.flatMap(topK(_, k)).collect.flatMap(topK(_, k)) - topKWindowedCounts.print - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/stream/WordCountEc2.scala b/streaming/src/main/scala/spark/stream/WordCountEc2.scala deleted file mode 100644 index 5b10026d7a..0000000000 --- a/streaming/src/main/scala/spark/stream/WordCountEc2.scala +++ /dev/null @@ -1,41 +0,0 @@ -package spark.stream - -import SparkStreamContext._ -import spark.SparkContext - -object WordCountEc2 { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: SparkStreamContext ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val ssc = new SparkStreamContext(args(0), "Test") - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - /*sentences.foreach(println)*/ - - val words = sentences.flatMap(_.split(" ")) - /*words.foreach(println)*/ - - val counts = words.map(x => (x, 1)).reduceByKey(_ + _) - /*counts.foreach(println)*/ - - counts.foreachRDD(rdd => rdd.collect.foreach(x => x)) - /*counts.register*/ - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/stream/WordCountTrivialWindow.scala b/streaming/src/main/scala/spark/stream/WordCountTrivialWindow.scala deleted file mode 100644 index 5469df71e9..0000000000 --- a/streaming/src/main/scala/spark/stream/WordCountTrivialWindow.scala +++ /dev/null @@ -1,51 +0,0 @@ -package spark.stream - -import SparkStreamContext._ - -import scala.util.Sorting - -object WordCountTrivialWindow { - - def main (args: Array[String]) { - - if (args.length < 1) { - println ("Usage: SparkStreamContext []") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCountTrivialWindow") - if (args.length > 1) { - ssc.setTempDir(args(1)) - } - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - /*sentences.print*/ - - val words = sentences.flatMap(_.split(" ")) - - /*val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 1)*/ - /*counts.print*/ - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val windowedCounts = words.map(x => (x, 1)).window(Seconds(5), Seconds(1)).reduceByKey(add _, 1) - /*windowedCounts.print */ - - def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { - implicit val countOrdering = new Ordering[(String, Int)] { - override def compare(count1: (String, Int), count2: (String, Int)): Int = { - count2._2 - count1._2 - } - } - val array = data.toArray - Sorting.quickSort(array) - array.take(k) - } - - val k = 10 - val topKWindowedCounts = windowedCounts.glom.flatMap(topK(_, k)).collect.flatMap(topK(_, k)) - topKWindowedCounts.print - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/stream/WordMax.scala b/streaming/src/main/scala/spark/stream/WordMax.scala deleted file mode 100644 index fc075e6d9d..0000000000 --- a/streaming/src/main/scala/spark/stream/WordMax.scala +++ /dev/null @@ -1,64 +0,0 @@ -package spark.stream - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object WordMax { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: WordCount ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "WordCountWindow") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) - //sentences.print - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - def max(v1: Int, v2: Int) = (if (v1 > v2) v1 else v2) - - //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(2000), - // System.getProperty("spark.default.parallelism", "1").toInt) - //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.DISK_AND_MEMORY_DESER_2, Seconds(5)) - //windowedCounts.print - - val parallelism = System.getProperty("spark.default.parallelism", "1").toInt - - val localCounts = words.map(x => (x, 1)).reduceByKey(add _, parallelism) - //localCounts.persist(StorageLevel.MEMORY_ONLY_DESER) - localCounts.persist(StorageLevel.MEMORY_ONLY_DESER_2) - val windowedCounts = localCounts.window(Seconds(30), Seconds(2)).reduceByKey(max _) - - //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(2), - // parallelism) - //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(6)) - - //windowedCounts.print - windowedCounts.register - //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => print(x+ " "))) - //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => x)) - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/BlockID.scala b/streaming/src/main/scala/spark/streaming/BlockID.scala new file mode 100644 index 0000000000..16aacfda18 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/BlockID.scala @@ -0,0 +1,20 @@ +package spark.streaming + +case class BlockID(sRds: String, sInterval: Interval, sPartition: Int) { + override def toString : String = ( + sRds + BlockID.sConnector + + sInterval.beginTime + BlockID.sConnector + + sInterval.endTime + BlockID.sConnector + + sPartition + ) +} + +object BlockID { + val sConnector = '-' + + def parse(name : String) = BlockID( + name.split(BlockID.sConnector)(0), + new Interval(name.split(BlockID.sConnector)(1).toLong, + name.split(BlockID.sConnector)(2).toLong), + name.split(BlockID.sConnector)(3).toInt) +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/ConnectionHandler.scala b/streaming/src/main/scala/spark/streaming/ConnectionHandler.scala new file mode 100644 index 0000000000..a4f454632f --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/ConnectionHandler.scala @@ -0,0 +1,157 @@ +package spark.streaming + +import spark.Logging + +import scala.collection.mutable.{ArrayBuffer, SynchronizedQueue} + +import java.net._ +import java.io._ +import java.nio._ +import java.nio.charset._ +import java.nio.channels._ +import java.nio.channels.spi._ + +abstract class ConnectionHandler(host: String, port: Int, connect: Boolean) +extends Thread with Logging { + + val selector = SelectorProvider.provider.openSelector() + val interestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] + + initLogging() + + override def run() { + try { + if (connect) { + connect() + } else { + listen() + } + + var interrupted = false + while(!interrupted) { + + preSelect() + + while(!interestChangeRequests.isEmpty) { + val (key, ops) = interestChangeRequests.dequeue + val lastOps = key.interestOps() + key.interestOps(ops) + + def intToOpStr(op: Int): String = { + val opStrs = new ArrayBuffer[String]() + if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" + if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" + if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" + if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" + if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " + } + + logTrace("Changed ops from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") + } + + selector.select() + interrupted = Thread.currentThread.isInterrupted + + val selectedKeys = selector.selectedKeys().iterator() + while (selectedKeys.hasNext) { + val key = selectedKeys.next.asInstanceOf[SelectionKey] + selectedKeys.remove() + if (key.isValid) { + if (key.isAcceptable) { + accept(key) + } else if (key.isConnectable) { + finishConnect(key) + } else if (key.isReadable) { + read(key) + } else if (key.isWritable) { + write(key) + } + } + } + } + } catch { + case e: Exception => { + logError("Error in select loop", e) + } + } + } + + def connect() { + val socketAddress = new InetSocketAddress(host, port) + val channel = SocketChannel.open() + channel.configureBlocking(false) + channel.socket.setReuseAddress(true) + channel.socket.setTcpNoDelay(true) + channel.connect(socketAddress) + channel.register(selector, SelectionKey.OP_CONNECT) + logInfo("Initiating connection to [" + socketAddress + "]") + } + + def listen() { + val channel = ServerSocketChannel.open() + channel.configureBlocking(false) + channel.socket.setReuseAddress(true) + channel.socket.setReceiveBufferSize(256 * 1024) + channel.socket.bind(new InetSocketAddress(port)) + channel.register(selector, SelectionKey.OP_ACCEPT) + logInfo("Listening on port " + port) + } + + def finishConnect(key: SelectionKey) { + try { + val channel = key.channel.asInstanceOf[SocketChannel] + val address = channel.socket.getRemoteSocketAddress + channel.finishConnect() + logInfo("Connected to [" + host + ":" + port + "]") + ready(key) + } catch { + case e: IOException => { + logError("Error finishing connect to " + host + ":" + port) + close(key) + } + } + } + + def accept(key: SelectionKey) { + try { + val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] + val channel = serverChannel.accept() + val address = channel.socket.getRemoteSocketAddress + channel.configureBlocking(false) + logInfo("Accepted connection from [" + address + "]") + ready(channel.register(selector, 0)) + } catch { + case e: IOException => { + logError("Error accepting connection", e) + } + } + } + + def changeInterest(key: SelectionKey, ops: Int) { + logTrace("Added request to change ops to " + ops) + interestChangeRequests += ((key, ops)) + } + + def ready(key: SelectionKey) + + def preSelect() { + } + + def read(key: SelectionKey) { + throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString) + } + + def write(key: SelectionKey) { + throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString) + } + + def close(key: SelectionKey) { + try { + key.channel.close() + key.cancel() + Thread.currentThread.interrupt + } catch { + case e: Exception => logError("Error closing connection", e) + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/DumbTopKWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/DumbTopKWordCount2_Special.scala new file mode 100644 index 0000000000..2ca72da79f --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/DumbTopKWordCount2_Special.scala @@ -0,0 +1,138 @@ +package spark.streaming + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting +import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.mutable.Queue + +import java.lang.{Long => JLong} + +object DumbTopKWordCount2_Special { + + def moreWarmup(sc: SparkContext) { + (0 until 20).foreach {i => + sc.parallelize(1 to 20000000, 500) + .map(_ % 100).map(_.toString) + .map(x => (x, 1)).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SparkStreamContext <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + GrepCount2.warmConnectionManagers(ssc.sc) + moreWarmup(ssc.sc) + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray + ) + + + def add(v1: JLong, v2: JLong) = (v1 + v2) + def subtract(v1: JLong, v2: JLong) = (v1 - v2) + + def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { + val map = new java.util.HashMap[String, JLong] + var i = 0 + var j = 0 + while (iter.hasNext) { + val s = iter.next() + i = 0 + while (i < s.length) { + j = i + while (j < s.length && s.charAt(j) != ' ') { + j += 1 + } + if (j > i) { + val w = s.substring(i, j) + val c = map.get(w) + if (c == null) { + map.put(w, 1) + } else { + map.put(w, c + 1) + } + } + i = j + while (i < s.length && s.charAt(i) == ' ') { + i += 1 + } + } + } + map.toIterator + } + + + val wordCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) + wordCounts.persist(StorageLevel.MEMORY_ONLY) + val windowedCounts = wordCounts.window(Seconds(10), Seconds(1)).reduceByKey(_ + _, 10) + + def topK(data: Iterator[(String, JLong)], k: Int): Iterator[(String, JLong)] = { + val taken = new Array[(String, JLong)](k) + + var i = 0 + var len = 0 + var done = false + var value: (String, JLong) = null + var swap: (String, JLong) = null + var count = 0 + + while(data.hasNext) { + value = data.next + count += 1 + /*println("count = " + count)*/ + if (len == 0) { + taken(0) = value + len = 1 + } else if (len < k || value._2 > taken(len - 1)._2) { + if (len < k) { + len += 1 + } + taken(len - 1) = value + i = len - 1 + while(i > 0 && taken(i - 1)._2 < taken(i)._2) { + swap = taken(i) + taken(i) = taken(i-1) + taken(i - 1) = swap + i -= 1 + } + } + } + println("Took " + len + " out of " + count + " items") + return taken.toIterator + } + + val k = 10 + val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) + partialTopKWindowedCounts.foreachRDD(rdd => { + val collectedCounts = rdd.collect + println("Collected " + collectedCounts.size + " items") + topK(collectedCounts.toIterator, k).foreach(println) + }) + + /* + windowedCounts.filter(_ == null).foreachRDD(rdd => { + val count = rdd.count + println("# of nulls = " + count) + })*/ + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/streaming/DumbWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/DumbWordCount2_Special.scala new file mode 100644 index 0000000000..34e7edfda9 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/DumbWordCount2_Special.scala @@ -0,0 +1,92 @@ +package spark.streaming + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.Queue +import scala.collection.JavaConversions.mapAsScalaMap + +import java.lang.{Long => JLong} +import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} + +object DumbWordCount2_Special { + + def moreWarmup(sc: SparkContext) { + (0 until 20).foreach {i => + sc.parallelize(1 to 20000000, 500) + .map(_ % 100).map(_.toString) + .map(x => (x, 1)).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SparkStreamContext <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + GrepCount2.warmConnectionManagers(ssc.sc) + moreWarmup(ssc.sc) + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray + ) + + def add(v1: JLong, v2: JLong) = (v1 + v2) + def subtract(v1: JLong, v2: JLong) = (v1 - v2) + + def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { + val map = new java.util.HashMap[String, JLong] + var i = 0 + var j = 0 + while (iter.hasNext) { + val s = iter.next() + i = 0 + while (i < s.length) { + j = i + while (j < s.length && s.charAt(j) != ' ') { + j += 1 + } + if (j > i) { + val w = s.substring(i, j) + val c = map.get(w) + if (c == null) { + map.put(w, 1) + } else { + map.put(w, c + 1) + } + } + i = j + while (i < s.length && s.charAt(i) == ' ') { + i += 1 + } + } + } + + map.toIterator + } + + val wordCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) + wordCounts.persist(StorageLevel.MEMORY_ONLY) + val windowedCounts = wordCounts.window(Seconds(10), Seconds(1)).reduceByKey(_ + _, 10) + windowedCounts.foreachRDD(_.collect) + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/streaming/FileStreamReceiver.scala b/streaming/src/main/scala/spark/streaming/FileStreamReceiver.scala new file mode 100644 index 0000000000..92c7cfe00c --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/FileStreamReceiver.scala @@ -0,0 +1,70 @@ +package spark.streaming + +import spark.Logging + +import scala.collection.mutable.HashSet +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ +import scala.actors.remote.RemoteActor._ + +import org.apache.hadoop.fs._ +import org.apache.hadoop.conf._ +import org.apache.hadoop.io._ +import org.apache.hadoop.mapred._ +import org.apache.hadoop.util._ + +class FileStreamReceiver ( + inputName: String, + rootDirectory: String, + intervalDuration: Long) + extends Logging { + + val pollInterval = 100 + val sparkstreamScheduler = { + val host = System.getProperty("spark.master.host") + val port = System.getProperty("spark.master.port").toInt + 1 + RemoteActor.select(Node(host, port), 'SparkStreamScheduler) + } + val directory = new Path(rootDirectory) + val fs = directory.getFileSystem(new Configuration()) + val files = new HashSet[String]() + var time: Long = 0 + + def start() { + fs.mkdirs(directory) + files ++= getFiles() + + actor { + logInfo("Monitoring directory - " + rootDirectory) + while(true) { + testFiles(getFiles()) + Thread.sleep(pollInterval) + } + } + } + + def getFiles(): Iterable[String] = { + fs.listStatus(directory).map(_.getPath.toString) + } + + def testFiles(fileList: Iterable[String]) { + fileList.foreach(file => { + if (!files.contains(file)) { + if (!file.endsWith("_tmp")) { + notifyFile(file) + } + files += file + } + }) + } + + def notifyFile(file: String) { + logInfo("Notifying file " + file) + time += intervalDuration + val interval = Interval(LongTime(time), LongTime(time + intervalDuration)) + sparkstreamScheduler ! InputGenerated(inputName, interval, file) + } +} + + diff --git a/streaming/src/main/scala/spark/streaming/GrepCount.scala b/streaming/src/main/scala/spark/streaming/GrepCount.scala new file mode 100644 index 0000000000..ec3e70f258 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/GrepCount.scala @@ -0,0 +1,39 @@ +package spark.streaming + +import SparkStreamContext._ + +import scala.util.Sorting + +import spark.SparkContext +import spark.storage.StorageLevel + +object GrepCount { + var inputFile : String = null + var HDFS : String = null + var idealPartitions : Int = 0 + + def main (args: Array[String]) { + + if (args.length != 4) { + println ("Usage: GrepCount ") + System.exit(1) + } + + HDFS = args(1) + inputFile = HDFS + args(2) + idealPartitions = args(3).toInt + println ("Input file: " + inputFile) + + val ssc = new SparkStreamContext(args(0), "GrepCount") + + SparkContext.idealPartitions = idealPartitions + SparkContext.inputFile = inputFile + + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) + //sentences.print + val matching = sentences.filter(_.contains("light")) + matching.foreachRDD(rdd => println(rdd.count)) + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/streaming/GrepCount2.scala b/streaming/src/main/scala/spark/streaming/GrepCount2.scala new file mode 100644 index 0000000000..27ecced1c0 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/GrepCount2.scala @@ -0,0 +1,113 @@ +package spark.streaming + +import SparkStreamContext._ + +import scala.util.Sorting + +import spark.SparkEnv +import spark.SparkContext +import spark.storage.StorageLevel +import spark.network.Message +import spark.network.ConnectionManagerId + +import java.nio.ByteBuffer + +object GrepCount2 { + + def startSparkEnvs(sc: SparkContext) { + + val dummy = sc.parallelize(0 to 1000, 100).persist(StorageLevel.DISK_AND_MEMORY) + sc.runJob(dummy, (_: Iterator[Int]) => {}) + + println("SparkEnvs started") + Thread.sleep(1000) + /*sc.runJob(sc.parallelize(0 to 1000, 100), (_: Iterator[Int]) => {})*/ + } + + def warmConnectionManagers(sc: SparkContext) { + val slaveConnManagerIds = sc.parallelize(0 to 100, 100).map( + i => SparkEnv.get.connectionManager.id).collect().distinct + println("\nSlave ConnectionManagerIds") + slaveConnManagerIds.foreach(println) + println + + Thread.sleep(1000) + val numSlaves = slaveConnManagerIds.size + val count = 3 + val size = 5 * 1024 * 1024 + val iterations = (500 * 1024 * 1024 / (numSlaves * size)).toInt + println("count = " + count + ", size = " + size + ", iterations = " + iterations) + + (0 until count).foreach(i => { + val resultStrs = sc.parallelize(0 until numSlaves, numSlaves).map(i => { + val connManager = SparkEnv.get.connectionManager + val thisConnManagerId = connManager.id + /*connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + println("Received [" + msg + "] from [" + id + "]") + None + })*/ + + + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + + val startTime = System.currentTimeMillis + val futures = (0 until iterations).map(i => { + slaveConnManagerIds.filter(_ != thisConnManagerId).map(slaveConnManagerId => { + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + println("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") + connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) + }) + }).flatMap(x => x) + val results = futures.map(f => f()) + val finishTime = System.currentTimeMillis + + + val mb = size * results.size / 1024.0 / 1024.0 + val ms = finishTime - startTime + + val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s" + println(resultStr) + System.gc() + resultStr + }).collect() + + println("---------------------") + println("Run " + i) + resultStrs.foreach(println) + println("---------------------") + }) + } + + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: GrepCount2 <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "GrepCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + /*startSparkEnvs(ssc.sc)*/ + warmConnectionManagers(ssc.sc) + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-"+i, 500)).toArray + ) + + val matching = sentences.filter(_.contains("light")) + matching.foreachRDD(rdd => println(rdd.count)) + + ssc.run + } +} + + + + diff --git a/streaming/src/main/scala/spark/streaming/GrepCountApprox.scala b/streaming/src/main/scala/spark/streaming/GrepCountApprox.scala new file mode 100644 index 0000000000..f9674136fe --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/GrepCountApprox.scala @@ -0,0 +1,54 @@ +package spark.streaming + +import SparkStreamContext._ + +import scala.util.Sorting + +import spark.SparkContext +import spark.storage.StorageLevel + +object GrepCountApprox { + var inputFile : String = null + var hdfs : String = null + var idealPartitions : Int = 0 + + def main (args: Array[String]) { + + if (args.length != 5) { + println ("Usage: GrepCountApprox ") + System.exit(1) + } + + hdfs = args(1) + inputFile = hdfs + args(2) + idealPartitions = args(3).toInt + val timeout = args(4).toLong + println ("Input file: " + inputFile) + + val ssc = new SparkStreamContext(args(0), "GrepCount") + + SparkContext.idealPartitions = idealPartitions + SparkContext.inputFile = inputFile + ssc.setTempDir(hdfs + "/tmp") + + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) + //sentences.print + val matching = sentences.filter(_.contains("light")) + var i = 0 + val startTime = System.currentTimeMillis + matching.foreachRDD { rdd => + val myNum = i + val result = rdd.countApprox(timeout) + val initialTime = (System.currentTimeMillis - startTime) / 1000.0 + printf("APPROX\t%.2f\t%d\tinitial\t%.1f\t%.1f\n", initialTime, myNum, result.initialValue.mean, + result.initialValue.high - result.initialValue.low) + result.onComplete { r => + val finalTime = (System.currentTimeMillis - startTime) / 1000.0 + printf("APPROX\t%.2f\t%d\tfinal\t%.1f\t0.0\t%.1f\n", finalTime, myNum, r.mean, finalTime - initialTime) + } + i += 1 + } + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/streaming/IdealPerformance.scala b/streaming/src/main/scala/spark/streaming/IdealPerformance.scala new file mode 100644 index 0000000000..303d4e7ae6 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/IdealPerformance.scala @@ -0,0 +1,36 @@ +package spark.streaming + +import scala.collection.mutable.Map + +object IdealPerformance { + val base: String = "The medium researcher counts around the pinched troop The empire breaks " + + "Matei Matei announces HY with a theorem " + + def main (args: Array[String]) { + val sentences: String = base * 100000 + + for (i <- 1 to 30) { + val start = System.nanoTime + + val words = sentences.split(" ") + + val pairs = words.map(word => (word, 1)) + + val counts = Map[String, Int]() + + println("Job " + i + " position A at " + (System.nanoTime - start) / 1e9) + + pairs.foreach((pair) => { + var t = counts.getOrElse(pair._1, 0) + counts(pair._1) = t + pair._2 + }) + println("Job " + i + " position B at " + (System.nanoTime - start) / 1e9) + + for ((word, count) <- counts) { + print(word + " " + count + "; ") + } + println + println("Job " + i + " finished in " + (System.nanoTime - start) / 1e9) + } + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/Interval.scala b/streaming/src/main/scala/spark/streaming/Interval.scala new file mode 100644 index 0000000000..a985f44ba1 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/Interval.scala @@ -0,0 +1,75 @@ +package spark.streaming + +case class Interval (val beginTime: Time, val endTime: Time) { + + def this(beginMs: Long, endMs: Long) = this(new LongTime(beginMs), new LongTime(endMs)) + + def duration(): Time = endTime - beginTime + + def += (time: Time) { + beginTime += time + endTime += time + this + } + + def + (time: Time): Interval = { + new Interval(beginTime + time, endTime + time) + } + + def < (that: Interval): Boolean = { + if (this.duration != that.duration) { + throw new Exception("Comparing two intervals with different durations [" + this + ", " + that + "]") + } + this.endTime < that.endTime + } + + def <= (that: Interval) = (this < that || this == that) + + def > (that: Interval) = !(this <= that) + + def >= (that: Interval) = !(this < that) + + def next(): Interval = { + this + (endTime - beginTime) + } + + def isZero() = (beginTime.isZero && endTime.isZero) + + def toFormattedString = beginTime.toFormattedString + "-" + endTime.toFormattedString + + override def toString = "[" + beginTime + ", " + endTime + "]" +} + +object Interval { + + /* + implicit def longTupleToInterval (longTuple: (Long, Long)) = + Interval(longTuple._1, longTuple._2) + + implicit def intTupleToInterval (intTuple: (Int, Int)) = + Interval(intTuple._1, intTuple._2) + + implicit def string2Interval (str: String): Interval = { + val parts = str.split(",") + if (parts.length == 1) + return Interval.zero + return Interval (parts(0).toInt, parts(1).toInt) + } + + def getInterval (timeMs: Long, intervalDurationMs: Long): Interval = { + val intervalBeginMs = timeMs / intervalDurationMs * intervalDurationMs + Interval(intervalBeginMs, intervalBeginMs + intervalDurationMs) + } + */ + + def zero() = new Interval (Time.zero, Time.zero) + + def currentInterval(intervalDuration: LongTime): Interval = { + val time = LongTime(System.currentTimeMillis) + val intervalBegin = time.floor(intervalDuration) + Interval(intervalBegin, intervalBegin + intervalDuration) + } + +} + + diff --git a/streaming/src/main/scala/spark/streaming/Job.scala b/streaming/src/main/scala/spark/streaming/Job.scala new file mode 100644 index 0000000000..f7654dff79 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/Job.scala @@ -0,0 +1,21 @@ +package spark.streaming + +class Job(val time: Time, func: () => _) { + val id = Job.getNewId() + + def run() { + func() + } + + override def toString = "SparkStream Job " + id + ":" + time +} + +object Job { + var lastId = 1 + + def getNewId() = synchronized { + lastId += 1 + lastId + } +} + diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala new file mode 100644 index 0000000000..45a3971643 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -0,0 +1,112 @@ +package spark.streaming + +import spark.SparkEnv +import spark.Logging + +import scala.collection.mutable.PriorityQueue +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ +import scala.actors.remote.RemoteActor._ +import scala.actors.scheduler.ResizableThreadPoolScheduler +import scala.actors.scheduler.ForkJoinScheduler + +sealed trait JobManagerMessage +case class RunJob(job: Job) extends JobManagerMessage +case class JobCompleted(handlerId: Int) extends JobManagerMessage + +class JobHandler(ssc: SparkStreamContext, val id: Int) extends DaemonActor with Logging { + + var busy = false + + def act() { + loop { + receive { + case job: Job => { + SparkEnv.set(ssc.env) + try { + logInfo("Starting " + job) + job.run() + logInfo("Finished " + job) + if (job.time.isInstanceOf[LongTime]) { + val longTime = job.time.asInstanceOf[LongTime] + logInfo("Total pushing + skew + processing delay for " + longTime + " is " + + (System.currentTimeMillis - longTime.milliseconds) / 1000.0 + " s") + } + } catch { + case e: Exception => logError("SparkStream job failed", e) + } + busy = false + reply(JobCompleted(id)) + } + } + } + } +} + +class JobManager(ssc: SparkStreamContext, numThreads: Int = 2) extends DaemonActor with Logging { + + implicit private val jobOrdering = new Ordering[Job] { + override def compare(job1: Job, job2: Job): Int = { + if (job1.time < job2.time) { + return 1 + } else if (job2.time < job1.time) { + return -1 + } else { + return 0 + } + } + } + + private val jobs = new PriorityQueue[Job]() + private val handlers = (0 until numThreads).map(i => new JobHandler(ssc, i)) + + def act() { + handlers.foreach(_.start) + loop { + receive { + case RunJob(job) => { + jobs += job + logInfo("Job " + job + " submitted") + runJob() + } + case JobCompleted(handlerId) => { + runJob() + } + } + } + } + + def runJob(): Unit = { + logInfo("Attempting to allocate job ") + if (jobs.size > 0) { + handlers.find(!_.busy).foreach(handler => { + val job = jobs.dequeue + logInfo("Allocating job " + job + " to handler " + handler.id) + handler.busy = true + handler ! job + }) + } + } +} + +object JobManager { + def main(args: Array[String]) { + val ssc = new SparkStreamContext("local[4]", "JobManagerTest") + val jobManager = new JobManager(ssc) + jobManager.start() + + val t = System.currentTimeMillis + for (i <- 1 to 10) { + jobManager ! RunJob(new Job( + LongTime(i), + () => { + Thread.sleep(500) + println("Job " + i + " took " + (System.currentTimeMillis - t) + " ms") + } + )) + } + Thread.sleep(6000) + } +} + diff --git a/streaming/src/main/scala/spark/streaming/JobManager2.scala b/streaming/src/main/scala/spark/streaming/JobManager2.scala new file mode 100644 index 0000000000..ce0154e19b --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/JobManager2.scala @@ -0,0 +1,37 @@ +package spark.streaming + +import spark.{Logging, SparkEnv} +import java.util.concurrent.Executors + + +class JobManager2(ssc: SparkStreamContext, numThreads: Int = 1) extends Logging { + + class JobHandler(ssc: SparkStreamContext, job: Job) extends Runnable { + def run() { + SparkEnv.set(ssc.env) + try { + logInfo("Starting " + job) + job.run() + logInfo("Finished " + job) + if (job.time.isInstanceOf[LongTime]) { + val longTime = job.time.asInstanceOf[LongTime] + logInfo("Total notification + skew + processing delay for " + longTime + " is " + + (System.currentTimeMillis - longTime.milliseconds) / 1000.0 + " s") + if (System.getProperty("spark.stream.distributed", "false") == "true") { + TestInputBlockTracker.setEndTime(job.time) + } + } + } catch { + case e: Exception => logError("SparkStream job failed", e) + } + } + } + + initLogging() + + val jobExecutor = Executors.newFixedThreadPool(numThreads) + + def runJob(job: Job) { + jobExecutor.execute(new JobHandler(ssc, job)) + } +} diff --git a/streaming/src/main/scala/spark/streaming/NetworkStreamReceiver.scala b/streaming/src/main/scala/spark/streaming/NetworkStreamReceiver.scala new file mode 100644 index 0000000000..efd4689cf0 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/NetworkStreamReceiver.scala @@ -0,0 +1,184 @@ +package spark.streaming + +import spark.Logging +import spark.storage.StorageLevel + +import scala.math._ +import scala.collection.mutable.{Queue, HashMap, ArrayBuffer} +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ +import scala.actors.remote.RemoteActor._ + +import java.io.BufferedWriter +import java.io.OutputStreamWriter + +import org.apache.hadoop.fs._ +import org.apache.hadoop.conf._ +import org.apache.hadoop.io._ +import org.apache.hadoop.mapred._ +import org.apache.hadoop.util._ + +/*import akka.actor.Actor._*/ + +class NetworkStreamReceiver[T: ClassManifest] ( + inputName: String, + intervalDuration: Time, + splitId: Int, + ssc: SparkStreamContext, + tempDirectory: String) + extends DaemonActor + with Logging { + + /** + * Assume all data coming in has non-decreasing timestamp. + */ + final class Inbox[T: ClassManifest] (intervalDuration: Time) { + var currentBucket: (Interval, ArrayBuffer[T]) = null + val filledBuckets = new Queue[(Interval, ArrayBuffer[T])]() + + def += (tuple: (Time, T)) = addTuple(tuple) + + def addTuple(tuple: (Time, T)) { + val (time, data) = tuple + val interval = getInterval (time) + + filledBuckets.synchronized { + if (currentBucket == null) { + currentBucket = (interval, new ArrayBuffer[T]()) + } + + if (interval != currentBucket._1) { + filledBuckets += currentBucket + currentBucket = (interval, new ArrayBuffer[T]()) + } + + currentBucket._2 += data + } + } + + def getInterval(time: Time): Interval = { + val intervalBegin = time.floor(intervalDuration) + Interval (intervalBegin, intervalBegin + intervalDuration) + } + + def hasFilledBuckets(): Boolean = { + filledBuckets.synchronized { + return filledBuckets.size > 0 + } + } + + def popFilledBucket(): (Interval, ArrayBuffer[T]) = { + filledBuckets.synchronized { + if (filledBuckets.size == 0) { + return null + } + return filledBuckets.dequeue() + } + } + } + + val inbox = new Inbox[T](intervalDuration) + lazy val sparkstreamScheduler = { + val host = System.getProperty("spark.master.host") + val port = System.getProperty("spark.master.port").toInt + val url = "akka://spark@%s:%s/user/SparkStreamScheduler".format(host, port) + ssc.actorSystem.actorFor(url) + } + /*sparkstreamScheduler ! Test()*/ + + val intervalDurationMillis = intervalDuration.asInstanceOf[LongTime].milliseconds + val useBlockManager = true + + initLogging() + + override def act() { + // register the InputReceiver + val port = 7078 + RemoteActor.alive(port) + RemoteActor.register(Symbol("NetworkStreamReceiver-"+inputName), self) + logInfo("Registered actor on port " + port) + + loop { + reactWithin (getSleepTime) { + case TIMEOUT => + flushInbox() + case data => + val t = data.asInstanceOf[T] + inbox += (getTimeFromData(t), t) + } + } + } + + def getSleepTime(): Long = { + (System.currentTimeMillis / intervalDurationMillis + 1) * + intervalDurationMillis - System.currentTimeMillis + } + + def getTimeFromData(data: T): Time = { + LongTime(System.currentTimeMillis) + } + + def flushInbox() { + while (inbox.hasFilledBuckets) { + inbox.synchronized { + val (interval, data) = inbox.popFilledBucket() + val dataArray = data.toArray + logInfo("Received " + dataArray.length + " items at interval " + interval) + val reference = { + if (useBlockManager) { + writeToBlockManager(dataArray, interval) + } else { + writeToDisk(dataArray, interval) + } + } + if (reference != null) { + logInfo("Notifying scheduler") + sparkstreamScheduler ! InputGenerated(inputName, interval, reference.toString) + } + } + } + } + + def writeToDisk(data: Array[T], interval: Interval): String = { + try { + // TODO(Haoyuan): For current test, the following writing to file lines could be + // commented. + val fs = new Path(tempDirectory).getFileSystem(new Configuration()) + val inputDir = new Path( + tempDirectory, + inputName + "-" + interval.toFormattedString) + val inputFile = new Path(inputDir, "part-" + splitId) + logInfo("Writing to file " + inputFile) + if (System.getProperty("spark.fake", "false") != "true") { + val writer = new BufferedWriter(new OutputStreamWriter(fs.create(inputFile, true))) + data.foreach(x => writer.write(x.toString + "\n")) + writer.close() + } else { + logInfo("Fake file") + } + inputFile.toString + }catch { + case e: Exception => + logError("Exception writing to file at interval " + interval + ": " + e.getMessage, e) + null + } + } + + def writeToBlockManager(data: Array[T], interval: Interval): String = { + try{ + val blockId = inputName + "-" + interval.toFormattedString + "-" + splitId + if (System.getProperty("spark.fake", "false") != "true") { + logInfo("Writing as block " + blockId ) + ssc.env.blockManager.put(blockId.toString, data.toIterator, StorageLevel.DISK_AND_MEMORY) + } else { + logInfo("Fake block") + } + blockId + } catch { + case e: Exception => + logError("Exception writing to block manager at interval " + interval + ": " + e.getMessage, e) + null + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/RDS.scala b/streaming/src/main/scala/spark/streaming/RDS.scala new file mode 100644 index 0000000000..c8dd1015ed --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/RDS.scala @@ -0,0 +1,607 @@ +package spark.streaming + +import spark.streaming.SparkStreamContext._ + +import spark.RDD +import spark.BlockRDD +import spark.UnionRDD +import spark.Logging +import spark.SparkContext +import spark.SparkContext._ +import spark.storage.StorageLevel + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap + +import java.net.InetSocketAddress + +abstract class RDS[T: ClassManifest] (@transient val ssc: SparkStreamContext) +extends Logging with Serializable { + + initLogging() + + /* ---------------------------------------------- */ + /* Methods that must be implemented by subclasses */ + /* ---------------------------------------------- */ + + // Time by which the window slides in this RDS + def slideTime: Time + + // List of parent RDSs on which this RDS depends on + def dependencies: List[RDS[_]] + + // Key method that computes RDD for a valid time + def compute (validTime: Time): Option[RDD[T]] + + /* --------------------------------------- */ + /* Other general fields and methods of RDS */ + /* --------------------------------------- */ + + // Variable to store the RDDs generated earlier in time + @transient private val generatedRDDs = new HashMap[Time, RDD[T]] () + + // Variable to be set to the first time seen by the RDS (effective time zero) + private[streaming] var zeroTime: Time = null + + // Variable to specify storage level + private var storageLevel: StorageLevel = StorageLevel.NONE + + // Checkpoint level and checkpoint interval + private var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint + private var checkpointInterval: Time = null + + // Change this RDD's storage level + def persist( + storageLevel: StorageLevel, + checkpointLevel: StorageLevel, + checkpointInterval: Time): RDS[T] = { + if (this.storageLevel != StorageLevel.NONE && this.storageLevel != storageLevel) { + // TODO: not sure this is necessary for RDSes + throw new UnsupportedOperationException( + "Cannot change storage level of an RDS after it was already assigned a level") + } + this.storageLevel = storageLevel + this.checkpointLevel = checkpointLevel + this.checkpointInterval = checkpointInterval + this + } + + def persist(newLevel: StorageLevel): RDS[T] = persist(newLevel, StorageLevel.NONE, null) + + // Turn on the default caching level for this RDD + def persist(): RDS[T] = persist(StorageLevel.MEMORY_ONLY_DESER) + + // Turn on the default caching level for this RDD + def cache(): RDS[T] = persist() + + def isInitialized = (zeroTime != null) + + // This method initializes the RDS by setting the "zero" time, based on which + // the validity of future times is calculated. This method also recursively initializes + // its parent RDSs. + def initialize(firstInterval: Interval) { + if (zeroTime == null) { + zeroTime = firstInterval.beginTime + } + logInfo(this + " initialized") + dependencies.foreach(_.initialize(firstInterval)) + } + + // This method checks whether the 'time' is valid wrt slideTime for generating RDD + private def isTimeValid (time: Time): Boolean = { + if (!isInitialized) + throw new Exception (this.toString + " has not been initialized") + if ((time - zeroTime).isMultipleOf(slideTime)) { + true + } else { + false + } + } + + // This method either retrieves a precomputed RDD of this RDS, + // or computes the RDD (if the time is valid) + def getOrCompute(time: Time): Option[RDD[T]] = { + + // if RDD was already generated, then retrieve it from HashMap + generatedRDDs.get(time) match { + + // If an RDD was already generated and is being reused, then + // probably all RDDs in this RDS will be reused and hence should be cached + case Some(oldRDD) => Some(oldRDD) + + // if RDD was not generated, and if the time is valid + // (based on sliding time of this RDS), then generate the RDD + case None => + if (isTimeValid(time)) { + compute(time) match { + case Some(newRDD) => + if (System.getProperty("spark.fake", "false") != "true" || + newRDD.getStorageLevel == StorageLevel.NONE) { + if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { + newRDD.persist(checkpointLevel) + logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) + } else if (storageLevel != StorageLevel.NONE) { + newRDD.persist(storageLevel) + logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) + } + } + generatedRDDs.put(time.copy(), newRDD) + Some(newRDD) + case None => + None + } + } else { + None + } + } + } + + // This method generates a SparkStream job for the given time + // and may require to be overriden by subclasses + def generateJob(time: Time): Option[Job] = { + getOrCompute(time) match { + case Some(rdd) => { + val jobFunc = () => { + val emptyFunc = { (iterator: Iterator[T]) => {} } + ssc.sc.runJob(rdd, emptyFunc) + } + Some(new Job(time, jobFunc)) + } + case None => None + } + } + + /* -------------- */ + /* RDS operations */ + /* -------------- */ + + def map[U: ClassManifest](mapFunc: T => U) = new MappedRDS(this, ssc.sc.clean(mapFunc)) + + def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]) = + new FlatMappedRDS(this, ssc.sc.clean(flatMapFunc)) + + def filter(filterFunc: T => Boolean) = new FilteredRDS(this, filterFunc) + + def glom() = new GlommedRDS(this) + + def mapPartitions[U: ClassManifest](mapPartFunc: Iterator[T] => Iterator[U]) = + new MapPartitionedRDS(this, ssc.sc.clean(mapPartFunc)) + + def reduce(reduceFunc: (T, T) => T) = this.map(x => (1, x)).reduceByKey(reduceFunc, 1).map(_._2) + + def count() = this.map(_ => 1).reduce(_ + _) + + def collect() = this.map(x => (1, x)).groupByKey(1).map(_._2) + + def foreach(foreachFunc: T => Unit) = { + val newrds = new PerElementForEachRDS(this, ssc.sc.clean(foreachFunc)) + ssc.registerOutputStream(newrds) + newrds + } + + def foreachRDD(foreachFunc: RDD[T] => Unit) = { + val newrds = new PerRDDForEachRDS(this, ssc.sc.clean(foreachFunc)) + ssc.registerOutputStream(newrds) + newrds + } + + def print() = { + def foreachFunc = (rdd: RDD[T], time: Time) => { + val first11 = rdd.take(11) + println ("-------------------------------------------") + println ("Time: " + time) + println ("-------------------------------------------") + first11.take(10).foreach(println) + if (first11.size > 10) println("...") + println() + } + val newrds = new PerRDDForEachRDS(this, ssc.sc.clean(foreachFunc)) + ssc.registerOutputStream(newrds) + newrds + } + + def window(windowTime: Time, slideTime: Time) = new WindowedRDS(this, windowTime, slideTime) + + def batch(batchTime: Time) = window(batchTime, batchTime) + + def reduceByWindow(reduceFunc: (T, T) => T, windowTime: Time, slideTime: Time) = + this.window(windowTime, slideTime).reduce(reduceFunc) + + def reduceByWindow( + reduceFunc: (T, T) => T, + invReduceFunc: (T, T) => T, + windowTime: Time, + slideTime: Time) = { + this.map(x => (1, x)) + .reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowTime, slideTime, 1) + .map(_._2) + } + + def countByWindow(windowTime: Time, slideTime: Time) = { + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + this.map(_ => 1).reduceByWindow(add _, subtract _, windowTime, slideTime) + } + + def union(that: RDS[T]) = new UnifiedRDS(Array(this, that)) + + def register() = ssc.registerOutputStream(this) +} + + +class PairRDSFunctions[K: ClassManifest, V: ClassManifest](rds: RDS[(K,V)]) +extends Serializable { + + def ssc = rds.ssc + + /* ---------------------------------- */ + /* RDS operations for key-value pairs */ + /* ---------------------------------- */ + + def groupByKey(numPartitions: Int = 0): ShuffledRDS[K, V, ArrayBuffer[V]] = { + def createCombiner(v: V) = ArrayBuffer[V](v) + def mergeValue(c: ArrayBuffer[V], v: V) = (c += v) + def mergeCombiner(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = (c1 ++ c2) + combineByKey[ArrayBuffer[V]](createCombiner, mergeValue, mergeCombiner, numPartitions) + } + + def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int = 0): ShuffledRDS[K, V, V] = { + val cleanedReduceFunc = ssc.sc.clean(reduceFunc) + combineByKey[V]((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, numPartitions) + } + + private def combineByKey[C: ClassManifest]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiner: (C, C) => C, + numPartitions: Int) : ShuffledRDS[K, V, C] = { + new ShuffledRDS[K, V, C](rds, createCombiner, mergeValue, mergeCombiner, numPartitions) + } + + def groupByKeyAndWindow( + windowTime: Time, + slideTime: Time, + numPartitions: Int = 0): ShuffledRDS[K, V, ArrayBuffer[V]] = { + rds.window(windowTime, slideTime).groupByKey(numPartitions) + } + + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time, + numPartitions: Int = 0): ShuffledRDS[K, V, V] = { + rds.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), numPartitions) + } + + // This method is the efficient sliding window reduce operation, + // which requires the specification of an inverse reduce function, + // so that new elements introduced in the window can be "added" using + // reduceFunc to the previous window's result and old elements can be + // "subtracted using invReduceFunc. + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time, + numPartitions: Int): ReducedWindowedRDS[K, V] = { + + new ReducedWindowedRDS[K, V]( + rds, + ssc.sc.clean(reduceFunc), + ssc.sc.clean(invReduceFunc), + windowTime, + slideTime, + numPartitions) + } +} + + +abstract class InputRDS[T: ClassManifest] ( + val inputName: String, + val batchDuration: Time, + ssc: SparkStreamContext) +extends RDS[T](ssc) { + + override def dependencies = List() + + override def slideTime = batchDuration + + def setReference(time: Time, reference: AnyRef) +} + + +class FileInputRDS( + val fileInputName: String, + val directory: String, + ssc: SparkStreamContext) +extends InputRDS[String](fileInputName, LongTime(1000), ssc) { + + @transient val generatedFiles = new HashMap[Time,String] + + // TODO(Haoyuan): This is for the performance test. + @transient + val rdd = ssc.sc.textFile(SparkContext.inputFile, + SparkContext.idealPartitions).asInstanceOf[RDD[String]] + + override def compute(validTime: Time): Option[RDD[String]] = { + generatedFiles.get(validTime) match { + case Some(file) => + logInfo("Reading from file " + file + " for time " + validTime) + // Some(ssc.sc.textFile(file).asInstanceOf[RDD[String]]) + // The following line is for HDFS performance test. Sould comment out the above line. + Some(rdd) + case None => + throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") + None + } + } + + def setReference(time: Time, reference: AnyRef) { + generatedFiles += ((time, reference.toString)) + logInfo("Reference added for time " + time + " - " + reference.toString) + } +} + +class NetworkInputRDS[T: ClassManifest]( + val networkInputName: String, + val addresses: Array[InetSocketAddress], + batchDuration: Time, + ssc: SparkStreamContext) +extends InputRDS[T](networkInputName, batchDuration, ssc) { + + + // TODO(Haoyuan): This is for the performance test. + @transient var rdd: RDD[T] = null + + if (System.getProperty("spark.fake", "false") == "true") { + logInfo("Running initial count to cache fake RDD") + rdd = ssc.sc.textFile(SparkContext.inputFile, + SparkContext.idealPartitions).asInstanceOf[RDD[T]] + val fakeCacheLevel = System.getProperty("spark.fake.cache", "") + if (fakeCacheLevel == "MEMORY_ONLY_2") { + rdd.persist(StorageLevel.MEMORY_ONLY_2) + } else if (fakeCacheLevel == "MEMORY_ONLY_DESER_2") { + rdd.persist(StorageLevel.MEMORY_ONLY_2) + } else if (fakeCacheLevel != "") { + logError("Invalid fake cache level: " + fakeCacheLevel) + System.exit(1) + } + rdd.count() + } + + @transient val references = new HashMap[Time,String] + + override def compute(validTime: Time): Option[RDD[T]] = { + if (System.getProperty("spark.fake", "false") == "true") { + logInfo("Returning fake RDD at " + validTime) + return Some(rdd) + } + references.get(validTime) match { + case Some(reference) => + if (reference.startsWith("file") || reference.startsWith("hdfs")) { + logInfo("Reading from file " + reference + " for time " + validTime) + Some(ssc.sc.textFile(reference).asInstanceOf[RDD[T]]) + } else { + logInfo("Getting from BlockManager " + reference + " for time " + validTime) + Some(new BlockRDD(ssc.sc, Array(reference))) + } + case None => + throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") + None + } + } + + def setReference(time: Time, reference: AnyRef) { + references += ((time, reference.toString)) + logInfo("Reference added for time " + time + " - " + reference.toString) + } +} + + +class TestInputRDS( + val testInputName: String, + batchDuration: Time, + ssc: SparkStreamContext) +extends InputRDS[String](testInputName, batchDuration, ssc) { + + @transient val references = new HashMap[Time,Array[String]] + + override def compute(validTime: Time): Option[RDD[String]] = { + references.get(validTime) match { + case Some(reference) => + Some(new BlockRDD[String](ssc.sc, reference)) + case None => + throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") + None + } + } + + def setReference(time: Time, reference: AnyRef) { + references += ((time, reference.asInstanceOf[Array[String]])) + } +} + + +class MappedRDS[T: ClassManifest, U: ClassManifest] ( + parent: RDS[T], + mapFunc: T => U) +extends RDS[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.map[U](mapFunc)) + } +} + + +class FlatMappedRDS[T: ClassManifest, U: ClassManifest]( + parent: RDS[T], + flatMapFunc: T => Traversable[U]) +extends RDS[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.flatMap(flatMapFunc)) + } +} + + +class FilteredRDS[T: ClassManifest](parent: RDS[T], filterFunc: T => Boolean) +extends RDS[T](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[T]] = { + parent.getOrCompute(validTime).map(_.filter(filterFunc)) + } +} + +class MapPartitionedRDS[T: ClassManifest, U: ClassManifest]( + parent: RDS[T], + mapPartFunc: Iterator[T] => Iterator[U]) +extends RDS[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc)) + } +} + +class GlommedRDS[T: ClassManifest](parent: RDS[T]) extends RDS[Array[T]](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[Array[T]]] = { + parent.getOrCompute(validTime).map(_.glom()) + } +} + + +class ShuffledRDS[K: ClassManifest, V: ClassManifest, C: ClassManifest]( + parent: RDS[(K,V)], + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiner: (C, C) => C, + numPartitions: Int) + extends RDS [(K,C)] (parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[(K,C)]] = { + parent.getOrCompute(validTime) match { + case Some(rdd) => + val newrdd = { + if (numPartitions > 0) { + rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner, numPartitions) + } else { + rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner) + } + } + Some(newrdd) + case None => None + } + } +} + + +class UnifiedRDS[T: ClassManifest](parents: Array[RDS[T]]) +extends RDS[T](parents(0).ssc) { + + if (parents.length == 0) { + throw new IllegalArgumentException("Empty array of parents") + } + + if (parents.map(_.ssc).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different SparkStreamContexts") + } + + if (parents.map(_.slideTime).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different slide times") + } + + override def dependencies = parents.toList + + override def slideTime: Time = parents(0).slideTime + + override def compute(validTime: Time): Option[RDD[T]] = { + val rdds = new ArrayBuffer[RDD[T]]() + parents.map(_.getOrCompute(validTime)).foreach(_ match { + case Some(rdd) => rdds += rdd + case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime) + }) + if (rdds.size > 0) { + Some(new UnionRDD(ssc.sc, rdds)) + } else { + None + } + } +} + + +class PerElementForEachRDS[T: ClassManifest] ( + parent: RDS[T], + foreachFunc: T => Unit) +extends RDS[Unit](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[Unit]] = None + + override def generateJob(time: Time): Option[Job] = { + parent.getOrCompute(time) match { + case Some(rdd) => + val jobFunc = () => { + val sparkJobFunc = { + (iterator: Iterator[T]) => iterator.foreach(foreachFunc) + } + ssc.sc.runJob(rdd, sparkJobFunc) + } + Some(new Job(time, jobFunc)) + case None => None + } + } +} + + +class PerRDDForEachRDS[T: ClassManifest] ( + parent: RDS[T], + foreachFunc: (RDD[T], Time) => Unit) +extends RDS[Unit](parent.ssc) { + + def this(parent: RDS[T], altForeachFunc: (RDD[T]) => Unit) = + this(parent, (rdd: RDD[T], time: Time) => altForeachFunc(rdd)) + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[Unit]] = None + + override def generateJob(time: Time): Option[Job] = { + parent.getOrCompute(time) match { + case Some(rdd) => + val jobFunc = () => { + foreachFunc(rdd, time) + } + Some(new Job(time, jobFunc)) + case None => None + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedRDS.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedRDS.scala new file mode 100644 index 0000000000..dd1f474657 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedRDS.scala @@ -0,0 +1,218 @@ +package spark.streaming + +import spark.streaming.SparkStreamContext._ + +import spark.RDD +import spark.UnionRDD +import spark.CoGroupedRDD +import spark.HashPartitioner +import spark.SparkContext._ +import spark.storage.StorageLevel + +import scala.collection.mutable.ArrayBuffer + +class ReducedWindowedRDS[K: ClassManifest, V: ClassManifest]( + parent: RDS[(K, V)], + reduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + _windowTime: Time, + _slideTime: Time, + numPartitions: Int) +extends RDS[(K,V)](parent.ssc) { + + if (!_windowTime.isMultipleOf(parent.slideTime)) + throw new Exception("The window duration of ReducedWindowedRDS (" + _slideTime + ") " + + "must be multiple of the slide duration of parent RDS (" + parent.slideTime + ")") + + if (!_slideTime.isMultipleOf(parent.slideTime)) + throw new Exception("The slide duration of ReducedWindowedRDS (" + _slideTime + ") " + + "must be multiple of the slide duration of parent RDS (" + parent.slideTime + ")") + + val reducedRDS = parent.reduceByKey(reduceFunc, numPartitions) + val allowPartialWindows = true + //reducedRDS.persist(StorageLevel.MEMORY_ONLY_DESER_2) + + override def dependencies = List(reducedRDS) + + def windowTime: Time = _windowTime + + override def slideTime: Time = _slideTime + + override def persist( + storageLevel: StorageLevel, + checkpointLevel: StorageLevel, + checkpointInterval: Time): RDS[(K,V)] = { + super.persist(storageLevel, checkpointLevel, checkpointInterval) + reducedRDS.persist(storageLevel, checkpointLevel, checkpointInterval) + } + + override def compute(validTime: Time): Option[RDD[(K, V)]] = { + + + // Notation: + // _____________________________ + // | previous window _________|___________________ + // |___________________| current window | --------------> Time + // |_____________________________| + // + // |________ _________| |________ _________| + // | | + // V V + // old time steps new time steps + // + def getAdjustedWindow(endTime: Time, windowTime: Time): Interval = { + val beginTime = + if (allowPartialWindows && endTime - windowTime < parent.zeroTime) { + parent.zeroTime + } else { + endTime - windowTime + } + Interval(beginTime, endTime) + } + + val currentTime = validTime.copy + val currentWindow = getAdjustedWindow(currentTime, windowTime) + val previousWindow = getAdjustedWindow(currentTime - slideTime, windowTime) + + logInfo("Current window = " + currentWindow) + logInfo("Previous window = " + previousWindow) + logInfo("Parent.zeroTime = " + parent.zeroTime) + + if (allowPartialWindows) { + if (currentTime - slideTime == parent.zeroTime) { + reducedRDS.getOrCompute(currentTime) match { + case Some(rdd) => return Some(rdd) + case None => throw new Exception("Could not get first reduced RDD for time " + currentTime) + } + } + } else { + if (previousWindow.beginTime < parent.zeroTime) { + if (currentWindow.beginTime < parent.zeroTime) { + return None + } else { + // If this is the first feasible window, then generate reduced value in the naive manner + val reducedRDDs = new ArrayBuffer[RDD[(K, V)]]() + var t = currentWindow.endTime + while (t > currentWindow.beginTime) { + reducedRDS.getOrCompute(t) match { + case Some(rdd) => reducedRDDs += rdd + case None => throw new Exception("Could not get reduced RDD for time " + t) + } + t -= reducedRDS.slideTime + } + if (reducedRDDs.size == 0) { + throw new Exception("Could not generate the first RDD for time " + validTime) + } + return Some(new UnionRDD(ssc.sc, reducedRDDs).reduceByKey(reduceFunc, numPartitions)) + } + } + } + + // Get the RDD of the reduced value of the previous window + val previousWindowRDD = getOrCompute(previousWindow.endTime) match { + case Some(rdd) => rdd.asInstanceOf[RDD[(_, _)]] + case None => throw new Exception("Could not get previous RDD for time " + previousWindow.endTime) + } + + val oldRDDs = new ArrayBuffer[RDD[(_, _)]]() + val newRDDs = new ArrayBuffer[RDD[(_, _)]]() + + // Get the RDDs of the reduced values in "old time steps" + var t = currentWindow.beginTime + while (t > previousWindow.beginTime) { + reducedRDS.getOrCompute(t) match { + case Some(rdd) => oldRDDs += rdd.asInstanceOf[RDD[(_, _)]] + case None => throw new Exception("Could not get old reduced RDD for time " + t) + } + t -= reducedRDS.slideTime + } + + // Get the RDDs of the reduced values in "new time steps" + t = currentWindow.endTime + while (t > previousWindow.endTime) { + reducedRDS.getOrCompute(t) match { + case Some(rdd) => newRDDs += rdd.asInstanceOf[RDD[(_, _)]] + case None => throw new Exception("Could not get new reduced RDD for time " + t) + } + t -= reducedRDS.slideTime + } + + val partitioner = new HashPartitioner(numPartitions) + val allRDDs = new ArrayBuffer[RDD[(_, _)]]() + allRDDs += previousWindowRDD + allRDDs ++= oldRDDs + allRDDs ++= newRDDs + + + val numOldRDDs = oldRDDs.size + val numNewRDDs = newRDDs.size + logInfo("Generated numOldRDDs = " + numOldRDDs + ", numNewRDDs = " + numNewRDDs) + logInfo("Generating CoGroupedRDD with " + allRDDs.size + " RDDs") + val newRDD = new CoGroupedRDD[K](allRDDs.toSeq, partitioner).asInstanceOf[RDD[(K,Seq[Seq[V]])]].map(x => { + val (key, value) = x + logDebug("value.size = " + value.size + ", numOldRDDs = " + numOldRDDs + ", numNewRDDs = " + numNewRDDs) + if (value.size != 1 + numOldRDDs + numNewRDDs) { + throw new Exception("Number of groups not odd!") + } + + // old values = reduced values of the "old time steps" that are eliminated from current window + // new values = reduced values of the "new time steps" that are introduced to the current window + // previous value = reduced value of the previous window + + /*val numOldValues = (value.size - 1) / 2*/ + // Getting reduced values "old time steps" + val oldValues = + (0 until numOldRDDs).map(i => value(1 + i)).filter(_.size > 0).map(x => x(0)) + // Getting reduced values "new time steps" + val newValues = + (0 until numNewRDDs).map(i => value(1 + numOldRDDs + i)).filter(_.size > 0).map(x => x(0)) + + // If reduced value for the key does not exist in previous window, it should not exist in "old time steps" + if (value(0).size == 0 && oldValues.size != 0) { + throw new Exception("Unexpected: Key exists in old reduced values but not in previous reduced values") + } + + // For the key, at least one of "old time steps", "new time steps" and previous window should have reduced values + if (value(0).size == 0 && oldValues.size == 0 && newValues.size == 0) { + throw new Exception("Unexpected: Key does not exist in any of old, new, or previour reduced values") + } + + // Logic to generate the final reduced value for current window: + // + // If previous window did not have reduced value for the key + // Then, return reduced value of "new time steps" as the final value + // Else, reduced value exists in previous window + // If "old" time steps did not have reduced value for the key + // Then, reduce previous window's reduced value with that of "new time steps" for final value + // Else, reduced values exists in "old time steps" + // If "new values" did not have reduced value for the key + // Then, inverse-reduce "old values" from previous window's reduced value for final value + // Else, all 3 values exist, combine all of them together + // + logDebug("# old values = " + oldValues.size + ", # new values = " + newValues) + val finalValue = { + if (value(0).size == 0) { + newValues.reduce(reduceFunc) + } else { + val prevValue = value(0)(0) + logDebug("prev value = " + prevValue) + if (oldValues.size == 0) { + // assuming newValue.size > 0 (all 3 cannot be zero, as checked earlier) + val temp = newValues.reduce(reduceFunc) + reduceFunc(prevValue, temp) + } else if (newValues.size == 0) { + invReduceFunc(prevValue, oldValues.reduce(reduceFunc)) + } else { + val tempValue = invReduceFunc(prevValue, oldValues.reduce(reduceFunc)) + reduceFunc(tempValue, newValues.reduce(reduceFunc)) + } + } + } + (key, finalValue) + }) + //newRDD.persist(StorageLevel.MEMORY_ONLY_DESER_2) + Some(newRDD) + } +} + + diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala new file mode 100644 index 0000000000..4137d8f27d --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -0,0 +1,181 @@ +package spark.streaming + +import spark.SparkEnv +import spark.Logging + +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet +import scala.collection.mutable.ArrayBuffer + +import akka.actor._ +import akka.actor.Actor +import akka.actor.Actor._ +import akka.util.duration._ + +sealed trait SchedulerMessage +case class InputGenerated(inputName: String, interval: Interval, reference: AnyRef = null) extends SchedulerMessage +case class Test extends SchedulerMessage + +class Scheduler( + ssc: SparkStreamContext, + inputRDSs: Array[InputRDS[_]], + outputRDSs: Array[RDS[_]]) +extends Actor with Logging { + + class InputState (inputNames: Array[String]) { + val inputsLeft = new HashSet[String]() + inputsLeft ++= inputNames + + val startTime = System.currentTimeMillis + + def delay() = System.currentTimeMillis - startTime + + def addGeneratedInput(inputName: String) = inputsLeft -= inputName + + def areAllInputsGenerated() = (inputsLeft.size == 0) + + override def toString(): String = { + val left = if (inputsLeft.size == 0) "" else inputsLeft.reduceLeft(_ + ", " + _) + return "Inputs left = [ " + left + " ]" + } + } + + + initLogging() + + val inputNames = inputRDSs.map(_.inputName).toArray + val inputStates = new HashMap[Interval, InputState]() + val currentJobs = System.getProperty("spark.stream.currentJobs", "1").toInt + val jobManager = new JobManager2(ssc, currentJobs) + + // TODO(Haoyuan): The following line is for performance test only. + var cnt: Int = System.getProperty("spark.stream.fake.cnt", "60").toInt + var lastInterval: Interval = null + + + /*remote.register("SparkStreamScheduler", actorOf[Scheduler])*/ + logInfo("Registered actor on port ") + + /*jobManager.start()*/ + startStreamReceivers() + + def receive = { + case InputGenerated(inputName, interval, reference) => { + addGeneratedInput(inputName, interval, reference) + } + case Test() => logInfo("TEST PASSED") + } + + def addGeneratedInput(inputName: String, interval: Interval, reference: AnyRef = null) { + logInfo("Input " + inputName + " generated for interval " + interval) + inputStates.get(interval) match { + case None => inputStates.put(interval, new InputState(inputNames)) + case _ => + } + inputStates(interval).addGeneratedInput(inputName) + + inputRDSs.filter(_.inputName == inputName).foreach(inputRDS => { + inputRDS.setReference(interval.endTime, reference) + if (inputRDS.isInstanceOf[TestInputRDS]) { + TestInputBlockTracker.addBlocks(interval.endTime, reference) + } + } + ) + + def getNextInterval(): Option[Interval] = { + logDebug("Last interval is " + lastInterval) + val readyIntervals = inputStates.filter(_._2.areAllInputsGenerated).keys + /*inputState.foreach(println) */ + logDebug("InputState has " + inputStates.size + " intervals, " + readyIntervals.size + " ready intervals") + return readyIntervals.find(lastInterval == null || _.beginTime == lastInterval.endTime) + } + + var nextInterval = getNextInterval() + var count = 0 + while(nextInterval.isDefined) { + val inputState = inputStates.get(nextInterval.get).get + generateRDDsForInterval(nextInterval.get) + logInfo("Skew delay for " + nextInterval.get.endTime + " is " + (inputState.delay / 1000.0) + " s") + inputStates.remove(nextInterval.get) + lastInterval = nextInterval.get + nextInterval = getNextInterval() + count += 1 + /*if (nextInterval.size == 0 && inputState.size > 0) { + logDebug("Next interval not ready, pending intervals " + inputState.size) + }*/ + } + logDebug("RDDs generated for " + count + " intervals") + + /* + if (inputState(interval).areAllInputsGenerated) { + generateRDDsForInterval(interval) + lastInterval = interval + inputState.remove(interval) + } else { + logInfo("All inputs not generated for interval " + interval) + } + */ + } + + def generateRDDsForInterval (interval: Interval) { + logInfo("Generating RDDs for interval " + interval) + outputRDSs.foreach(outputRDS => { + if (!outputRDS.isInitialized) outputRDS.initialize(interval) + outputRDS.generateJob(interval.endTime) match { + case Some(job) => submitJob(job) + case None => + } + } + ) + // TODO(Haoyuan): This comment is for performance test only. + if (System.getProperty("spark.fake", "false") == "true" || System.getProperty("spark.stream.fake", "false") == "true") { + cnt -= 1 + if (cnt <= 0) { + logInfo("My time is up! " + cnt) + System.exit(1) + } + } + } + + def submitJob(job: Job) { + logInfo("Submitting " + job + " to JobManager") + /*jobManager ! RunJob(job)*/ + jobManager.runJob(job) + } + + def startStreamReceivers() { + val testStreamReceiverNames = new ArrayBuffer[(String, Long)]() + inputRDSs.foreach (inputRDS => { + inputRDS match { + case fileInputRDS: FileInputRDS => { + val fileStreamReceiver = new FileStreamReceiver( + fileInputRDS.inputName, + fileInputRDS.directory, + fileInputRDS.batchDuration.asInstanceOf[LongTime].milliseconds) + fileStreamReceiver.start() + } + case networkInputRDS: NetworkInputRDS[_] => { + val networkStreamReceiver = new NetworkStreamReceiver( + networkInputRDS.inputName, + networkInputRDS.batchDuration, + 0, + ssc, + if (ssc.tempDir == null) null else ssc.tempDir.toString) + networkStreamReceiver.start() + } + case testInputRDS: TestInputRDS => { + testStreamReceiverNames += + ((testInputRDS.inputName, testInputRDS.batchDuration.asInstanceOf[LongTime].milliseconds)) + } + } + }) + if (testStreamReceiverNames.size > 0) { + /*val testStreamCoordinator = new TestStreamCoordinator(testStreamReceiverNames.toArray)*/ + /*testStreamCoordinator.start()*/ + val actor = ssc.actorSystem.actorOf( + Props(new TestStreamCoordinator(testStreamReceiverNames.toArray)), + name = "TestStreamCoordinator") + } + } +} + diff --git a/streaming/src/main/scala/spark/streaming/SenGeneratorForPerformanceTest.scala b/streaming/src/main/scala/spark/streaming/SenGeneratorForPerformanceTest.scala new file mode 100644 index 0000000000..bb32089ae2 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/SenGeneratorForPerformanceTest.scala @@ -0,0 +1,78 @@ +package spark.streaming + +import scala.util.Random +import scala.io.Source +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ +import scala.actors.remote.RemoteActor._ + +import java.net.InetSocketAddress + +/*import akka.actor.Actor._*/ +/*import akka.actor.ActorRef*/ + + +object SenGeneratorForPerformanceTest { + + def printUsage () { + println ("Usage: SentenceGenerator []") + System.exit(0) + } + + def main (args: Array[String]) { + if (args.length < 3) { + printUsage + } + + val inputManagerIP = args(0) + val inputManagerPort = args(1).toInt + val sentenceFile = args(2) + val sentencesPerSecond = { + if (args.length > 3) args(3).toInt + else 10 + } + + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n") + source.close () + + try { + /*val inputManager = remote.actorFor("InputReceiver-Sentences",*/ + /* inputManagerIP, inputManagerPort)*/ + val inputManager = select(Node(inputManagerIP, inputManagerPort), Symbol("InputReceiver-Sentences")) + val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 + val random = new Random () + println ("Sending " + sentencesPerSecond + " sentences per second to " + inputManagerIP + ":" + inputManagerPort) + var lastPrintTime = System.currentTimeMillis() + var count = 0 + + while (true) { + /*if (!inputManager.tryTell (lines (random.nextInt (lines.length))))*/ + /*throw new Exception ("disconnected")*/ +// inputManager ! lines (random.nextInt (lines.length)) + for (i <- 0 to sentencesPerSecond) inputManager ! lines (0) + println(System.currentTimeMillis / 1000 + " s") +/* count += 1 + + if (System.currentTimeMillis - lastPrintTime >= 1000) { + println (count + " sentences sent last second") + count = 0 + lastPrintTime = System.currentTimeMillis + } + + Thread.sleep (sleepBetweenSentences.toLong) +*/ + val currentMs = System.currentTimeMillis / 1000; + Thread.sleep ((currentMs * 1000 + 1000) - System.currentTimeMillis) + } + } catch { + case e: Exception => + /*Thread.sleep (1000)*/ + } + } +} + + + + diff --git a/streaming/src/main/scala/spark/streaming/SenderReceiverTest.scala b/streaming/src/main/scala/spark/streaming/SenderReceiverTest.scala new file mode 100644 index 0000000000..6af270298a --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/SenderReceiverTest.scala @@ -0,0 +1,63 @@ +package spark.streaming +import java.net.{Socket, ServerSocket} +import java.io.{ByteArrayOutputStream, DataOutputStream, DataInputStream, BufferedInputStream} + +object Receiver { + def main(args: Array[String]) { + val port = args(0).toInt + val lsocket = new ServerSocket(port) + println("Listening on port " + port ) + while(true) { + val socket = lsocket.accept() + (new Thread() { + override def run() { + val buffer = new Array[Byte](100000) + var count = 0 + val time = System.currentTimeMillis + try { + val is = new DataInputStream(new BufferedInputStream(socket.getInputStream)) + var loop = true + var string: String = null + while((string = is.readUTF) != null) { + count += 28 + } + } catch { + case e: Exception => e.printStackTrace + } + val timeTaken = System.currentTimeMillis - time + val tput = (count / 1024.0) / (timeTaken / 1000.0) + println("Data = " + count + " bytes\nTime = " + timeTaken + " ms\nTput = " + tput + " KB/s") + } + }).start() + } + } + +} + +object Sender { + + def main(args: Array[String]) { + try { + val host = args(0) + val port = args(1).toInt + val size = args(2).toInt + + val byteStream = new ByteArrayOutputStream() + val stringDataStream = new DataOutputStream(byteStream) + (0 until size).foreach(_ => stringDataStream.writeUTF("abcdedfghijklmnopqrstuvwxy")) + val bytes = byteStream.toByteArray() + println("Generated array of " + bytes.length + " bytes") + + /*val bytes = new Array[Byte](size)*/ + val socket = new Socket(host, port) + val os = socket.getOutputStream + os.write(bytes) + os.flush + socket.close() + + } catch { + case e: Exception => e.printStackTrace + } + } +} + diff --git a/streaming/src/main/scala/spark/streaming/SentenceFileGenerator.scala b/streaming/src/main/scala/spark/streaming/SentenceFileGenerator.scala new file mode 100644 index 0000000000..15858f59e3 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/SentenceFileGenerator.scala @@ -0,0 +1,92 @@ +package spark.streaming + +import spark._ + +import scala.collection.mutable.ArrayBuffer +import scala.util.Random +import scala.io.Source + +import java.net.InetSocketAddress + +import org.apache.hadoop.fs._ +import org.apache.hadoop.conf._ +import org.apache.hadoop.io._ +import org.apache.hadoop.mapred._ +import org.apache.hadoop.util._ + +object SentenceFileGenerator { + + def printUsage () { + println ("Usage: SentenceFileGenerator <# partitions> []") + System.exit(0) + } + + def main (args: Array[String]) { + if (args.length < 4) { + printUsage + } + + val master = args(0) + val fs = new Path(args(1)).getFileSystem(new Configuration()) + val targetDirectory = new Path(args(1)).makeQualified(fs) + val numPartitions = args(2).toInt + val sentenceFile = args(3) + val sentencesPerSecond = { + if (args.length > 4) args(4).toInt + else 10 + } + + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n").toArray + source.close () + println("Read " + lines.length + " lines from file " + sentenceFile) + + val sentences = { + val buffer = ArrayBuffer[String]() + val random = new Random() + var i = 0 + while (i < sentencesPerSecond) { + buffer += lines(random.nextInt(lines.length)) + i += 1 + } + buffer.toArray + } + println("Generated " + sentences.length + " sentences") + + val sc = new SparkContext(master, "SentenceFileGenerator") + val sentencesRDD = sc.parallelize(sentences, numPartitions) + + val tempDirectory = new Path(targetDirectory, "_tmp") + + fs.mkdirs(targetDirectory) + fs.mkdirs(tempDirectory) + + var saveTimeMillis = System.currentTimeMillis + try { + while (true) { + val newDir = new Path(targetDirectory, "Sentences-" + saveTimeMillis) + val tmpNewDir = new Path(tempDirectory, "Sentences-" + saveTimeMillis) + println("Writing to file " + newDir) + sentencesRDD.saveAsTextFile(tmpNewDir.toString) + fs.rename(tmpNewDir, newDir) + saveTimeMillis += 1000 + val sleepTimeMillis = { + val currentTimeMillis = System.currentTimeMillis + if (saveTimeMillis < currentTimeMillis) { + 0 + } else { + saveTimeMillis - currentTimeMillis + } + } + println("Sleeping for " + sleepTimeMillis + " ms") + Thread.sleep(sleepTimeMillis) + } + } catch { + case e: Exception => + } + } +} + + + + diff --git a/streaming/src/main/scala/spark/streaming/SentenceGenerator.scala b/streaming/src/main/scala/spark/streaming/SentenceGenerator.scala new file mode 100644 index 0000000000..a9f124d2d7 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/SentenceGenerator.scala @@ -0,0 +1,103 @@ +package spark.streaming + +import scala.util.Random +import scala.io.Source +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ +import scala.actors.remote.RemoteActor._ + +import java.net.InetSocketAddress + + +object SentenceGenerator { + + def printUsage { + println ("Usage: SentenceGenerator []") + System.exit(0) + } + + def generateRandomSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { + val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 + val random = new Random () + + try { + var lastPrintTime = System.currentTimeMillis() + var count = 0 + while(true) { + streamReceiver ! lines(random.nextInt(lines.length)) + count += 1 + if (System.currentTimeMillis - lastPrintTime >= 1000) { + println (count + " sentences sent last second") + count = 0 + lastPrintTime = System.currentTimeMillis + } + Thread.sleep(sleepBetweenSentences.toLong) + } + } catch { + case e: Exception => + } + } + + def generateSameSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { + try { + val numSentences = if (sentencesPerSecond <= 0) { + lines.length + } else { + sentencesPerSecond + } + var nextSendingTime = System.currentTimeMillis() + val pingInterval = if (System.getenv("INTERVAL") != null) { + System.getenv("INTERVAL").toInt + } else { + 2000 + } + while(true) { + (0 until numSentences).foreach(i => { + streamReceiver ! lines(i % lines.length) + }) + println ("Sent " + numSentences + " sentences") + nextSendingTime += pingInterval + val sleepTime = nextSendingTime - System.currentTimeMillis + if (sleepTime > 0) { + println ("Sleeping for " + sleepTime + " ms") + Thread.sleep(sleepTime) + } + } + } catch { + case e: Exception => + } + } + + def main(args: Array[String]) { + if (args.length < 3) { + printUsage + } + + val generateRandomly = false + + val streamReceiverIP = args(0) + val streamReceiverPort = args(1).toInt + val sentenceFile = args(2) + val sentencesPerSecond = if (args.length > 3) args(3).toInt else 10 + val sentenceInputName = if (args.length > 4) args(4) else "Sentences" + + println("Sending " + sentencesPerSecond + " sentences per second to " + + streamReceiverIP + ":" + streamReceiverPort + "/NetworkStreamReceiver-" + sentenceInputName) + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n") + source.close () + + val streamReceiver = select( + Node(streamReceiverIP, streamReceiverPort), + Symbol("NetworkStreamReceiver-" + sentenceInputName)) + if (generateRandomly) { + generateRandomSentences(lines, sentencesPerSecond, streamReceiver) + } else { + generateSameSentences(lines, sentencesPerSecond, streamReceiver) + } + } +} + + + diff --git a/streaming/src/main/scala/spark/streaming/ShuffleTest.scala b/streaming/src/main/scala/spark/streaming/ShuffleTest.scala new file mode 100644 index 0000000000..32aa4144a0 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/ShuffleTest.scala @@ -0,0 +1,22 @@ +package spark.streaming +import spark.SparkContext +import SparkContext._ + +object ShuffleTest { + def main(args: Array[String]) { + + if (args.length < 1) { + println ("Usage: ShuffleTest ") + System.exit(1) + } + + val sc = new spark.SparkContext(args(0), "ShuffleTest") + val rdd = sc.parallelize(1 to 1000, 500).cache + + def time(f: => Unit) { val start = System.nanoTime; f; println((System.nanoTime - start) * 1.0e-6) } + + time { for (i <- 0 until 50) time { rdd.map(x => (x % 100, x)).reduceByKey(_ + _, 10).count } } + System.exit(0) + } +} + diff --git a/streaming/src/main/scala/spark/streaming/SimpleWordCount.scala b/streaming/src/main/scala/spark/streaming/SimpleWordCount.scala new file mode 100644 index 0000000000..a75ccd3a56 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/SimpleWordCount.scala @@ -0,0 +1,30 @@ +package spark.streaming + +import SparkStreamContext._ + +import scala.util.Sorting + +object SimpleWordCount { + + def main (args: Array[String]) { + + if (args.length < 1) { + println ("Usage: SparkStreamContext []") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCount") + if (args.length > 1) { + ssc.setTempDir(args(1)) + } + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) + /*sentences.print*/ + + val words = sentences.flatMap(_.split(" ")) + + val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 1) + counts.print + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/streaming/SimpleWordCount2.scala b/streaming/src/main/scala/spark/streaming/SimpleWordCount2.scala new file mode 100644 index 0000000000..9672e64b13 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/SimpleWordCount2.scala @@ -0,0 +1,51 @@ +package spark.streaming + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import scala.util.Sorting + +object SimpleWordCount2 { + + def moreWarmup(sc: SparkContext) { + (0 until 20).foreach {i => + sc.parallelize(1 to 20000000, 500) + .map(_ % 100).map(_.toString) + .map(x => (x, 1)).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SimpleWordCount2 <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "SimpleWordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + GrepCount2.warmConnectionManagers(ssc.sc) + moreWarmup(ssc.sc) + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray + ) + + + val words = sentences.flatMap(_.split(" ")) + + val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 10) + counts.foreachRDD(_.collect()) + /*words.foreachRDD(_.countByValue())*/ + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/streaming/SimpleWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/SimpleWordCount2_Special.scala new file mode 100644 index 0000000000..503033a8e5 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/SimpleWordCount2_Special.scala @@ -0,0 +1,83 @@ +package spark.streaming + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import scala.collection.JavaConversions.mapAsScalaMap +import scala.util.Sorting +import java.lang.{Long => JLong} + +object SimpleWordCount2_Special { + + def moreWarmup(sc: SparkContext) { + (0 until 20).foreach {i => + sc.parallelize(1 to 20000000, 500) + .map(_ % 100).map(_.toString) + .map(x => (x, 1)).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SimpleWordCount2 <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "SimpleWordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + GrepCount2.warmConnectionManagers(ssc.sc) + moreWarmup(ssc.sc) + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 400)).toArray + ) + + + def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { + val map = new java.util.HashMap[String, JLong] + var i = 0 + var j = 0 + while (iter.hasNext) { + val s = iter.next() + i = 0 + while (i < s.length) { + j = i + while (j < s.length && s.charAt(j) != ' ') { + j += 1 + } + if (j > i) { + val w = s.substring(i, j) + val c = map.get(w) + if (c == null) { + map.put(w, 1) + } else { + map.put(w, c + 1) + } + } + i = j + while (i < s.length && s.charAt(i) == ' ') { + i += 1 + } + } + } + map.toIterator + } + + + /*val words = sentences.flatMap(_.split(" "))*/ + /*val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 10)*/ + val counts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) + counts.foreachRDD(_.collect()) + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/streaming/SparkStreamContext.scala b/streaming/src/main/scala/spark/streaming/SparkStreamContext.scala new file mode 100644 index 0000000000..51f8193740 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/SparkStreamContext.scala @@ -0,0 +1,105 @@ +package spark.streaming + +import spark.SparkContext +import spark.SparkEnv +import spark.Utils +import spark.Logging + +import scala.collection.mutable.ArrayBuffer + +import java.net.InetSocketAddress +import java.io.IOException +import java.util.UUID + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration + +import akka.actor._ +import akka.actor.Actor +import akka.util.duration._ + +class SparkStreamContext ( + master: String, + frameworkName: String, + val sparkHome: String = null, + val jars: Seq[String] = Nil) + extends Logging { + + initLogging() + + val sc = new SparkContext(master, frameworkName, sparkHome, jars) + val env = SparkEnv.get + val actorSystem = env.actorSystem + + @transient val inputRDSs = new ArrayBuffer[InputRDS[_]]() + @transient val outputRDSs = new ArrayBuffer[RDS[_]]() + + var tempDirRoot: String = null + var tempDir: Path = null + + def readNetworkStream[T: ClassManifest]( + name: String, + addresses: Array[InetSocketAddress], + batchDuration: Time): RDS[T] = { + + val inputRDS = new NetworkInputRDS[T](name, addresses, batchDuration, this) + inputRDSs += inputRDS + inputRDS + } + + def readNetworkStream[T: ClassManifest]( + name: String, + addresses: Array[String], + batchDuration: Long): RDS[T] = { + + def stringToInetSocketAddress (str: String): InetSocketAddress = { + val parts = str.split(":") + if (parts.length != 2) { + throw new IllegalArgumentException ("Address format error") + } + new InetSocketAddress(parts(0), parts(1).toInt) + } + + readNetworkStream( + name, + addresses.map(stringToInetSocketAddress).toArray, + LongTime(batchDuration)) + } + + def readFileStream(name: String, directory: String): RDS[String] = { + val path = new Path(directory) + val fs = path.getFileSystem(new Configuration()) + val qualPath = path.makeQualified(fs) + val inputRDS = new FileInputRDS(name, qualPath.toString, this) + inputRDSs += inputRDS + inputRDS + } + + def readTestStream(name: String, batchDuration: Long): RDS[String] = { + val inputRDS = new TestInputRDS(name, LongTime(batchDuration), this) + inputRDSs += inputRDS + inputRDS + } + + def registerOutputStream (outputRDS: RDS[_]) { + outputRDSs += outputRDS + } + + def setTempDir(dir: String) { + tempDirRoot = dir + } + + def run () { + val ctxt = this + val actor = actorSystem.actorOf( + Props(new Scheduler(ctxt, inputRDSs.toArray, outputRDSs.toArray)), + name = "SparkStreamScheduler") + logInfo("Registered actor") + actorSystem.awaitTermination() + } +} + +object SparkStreamContext { + implicit def rdsToPairRdsFunctions [K: ClassManifest, V: ClassManifest] (rds: RDS[(K,V)]) = + new PairRDSFunctions (rds) +} diff --git a/streaming/src/main/scala/spark/streaming/TestGenerator.scala b/streaming/src/main/scala/spark/streaming/TestGenerator.scala new file mode 100644 index 0000000000..0ff6af61f2 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/TestGenerator.scala @@ -0,0 +1,107 @@ +package spark.streaming + +import scala.util.Random +import scala.io.Source +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ +import scala.actors.remote.RemoteActor._ + +import java.net.InetSocketAddress + + +object TestGenerator { + + def printUsage { + println ("Usage: SentenceGenerator []") + System.exit(0) + } + /* + def generateRandomSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { + val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 + val random = new Random () + + try { + var lastPrintTime = System.currentTimeMillis() + var count = 0 + while(true) { + streamReceiver ! lines(random.nextInt(lines.length)) + count += 1 + if (System.currentTimeMillis - lastPrintTime >= 1000) { + println (count + " sentences sent last second") + count = 0 + lastPrintTime = System.currentTimeMillis + } + Thread.sleep(sleepBetweenSentences.toLong) + } + } catch { + case e: Exception => + } + }*/ + + def generateSameSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { + try { + val numSentences = if (sentencesPerSecond <= 0) { + lines.length + } else { + sentencesPerSecond + } + val sentences = lines.take(numSentences).toArray + + var nextSendingTime = System.currentTimeMillis() + val sendAsArray = true + while(true) { + if (sendAsArray) { + println("Sending as array") + streamReceiver !? sentences + } else { + println("Sending individually") + sentences.foreach(sentence => { + streamReceiver !? sentence + }) + } + println ("Sent " + numSentences + " sentences in " + (System.currentTimeMillis - nextSendingTime) + " ms") + nextSendingTime += 1000 + val sleepTime = nextSendingTime - System.currentTimeMillis + if (sleepTime > 0) { + println ("Sleeping for " + sleepTime + " ms") + Thread.sleep(sleepTime) + } + } + } catch { + case e: Exception => + } + } + + def main(args: Array[String]) { + if (args.length < 3) { + printUsage + } + + val generateRandomly = false + + val streamReceiverIP = args(0) + val streamReceiverPort = args(1).toInt + val sentenceFile = args(2) + val sentencesPerSecond = if (args.length > 3) args(3).toInt else 10 + val sentenceInputName = if (args.length > 4) args(4) else "Sentences" + + println("Sending " + sentencesPerSecond + " sentences per second to " + + streamReceiverIP + ":" + streamReceiverPort + "/NetworkStreamReceiver-" + sentenceInputName) + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n") + source.close () + + val streamReceiver = select( + Node(streamReceiverIP, streamReceiverPort), + Symbol("NetworkStreamReceiver-" + sentenceInputName)) + if (generateRandomly) { + /*generateRandomSentences(lines, sentencesPerSecond, streamReceiver)*/ + } else { + generateSameSentences(lines, sentencesPerSecond, streamReceiver) + } + } +} + + + diff --git a/streaming/src/main/scala/spark/streaming/TestGenerator2.scala b/streaming/src/main/scala/spark/streaming/TestGenerator2.scala new file mode 100644 index 0000000000..00d43604d0 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/TestGenerator2.scala @@ -0,0 +1,119 @@ +package spark.streaming + +import scala.util.Random +import scala.io.Source +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ +import scala.actors.remote.RemoteActor._ + +import java.io.{DataOutputStream, ByteArrayOutputStream, DataInputStream} +import java.net.Socket + +object TestGenerator2 { + + def printUsage { + println ("Usage: SentenceGenerator []") + System.exit(0) + } + + def sendSentences(streamReceiverHost: String, streamReceiverPort: Int, numSentences: Int, bytes: Array[Byte], intervalTime: Long){ + try { + println("Connecting to " + streamReceiverHost + ":" + streamReceiverPort) + val socket = new Socket(streamReceiverHost, streamReceiverPort) + + println("Sending " + numSentences+ " sentences / " + (bytes.length / 1024.0 / 1024.0) + " MB per " + intervalTime + " ms to " + streamReceiverHost + ":" + streamReceiverPort ) + val currentTime = System.currentTimeMillis + var targetTime = (currentTime / intervalTime + 1).toLong * intervalTime + Thread.sleep(targetTime - currentTime) + + while(true) { + val startTime = System.currentTimeMillis() + println("Sending at " + startTime + " ms with delay of " + (startTime - targetTime) + " ms") + val socketOutputStream = socket.getOutputStream + val parts = 10 + (0 until parts).foreach(i => { + val partStartTime = System.currentTimeMillis + + val offset = (i * bytes.length / parts).toInt + val len = math.min(((i + 1) * bytes.length / parts).toInt - offset, bytes.length) + socketOutputStream.write(bytes, offset, len) + socketOutputStream.flush() + val partFinishTime = System.currentTimeMillis + println("Sending part " + i + " of " + len + " bytes took " + (partFinishTime - partStartTime) + " ms") + val sleepTime = math.max(0, 1000 / parts - (partFinishTime - partStartTime) - 1) + Thread.sleep(sleepTime) + }) + + socketOutputStream.flush() + /*val socketInputStream = new DataInputStream(socket.getInputStream)*/ + /*val reply = socketInputStream.readUTF()*/ + val finishTime = System.currentTimeMillis() + println ("Sent " + bytes.length + " bytes in " + (finishTime - startTime) + " ms for interval [" + targetTime + ", " + (targetTime + intervalTime) + "]") + /*println("Received = " + reply)*/ + targetTime = targetTime + intervalTime + val sleepTime = (targetTime - finishTime) + 10 + if (sleepTime > 0) { + println("Sleeping for " + sleepTime + " ms") + Thread.sleep(sleepTime) + } else { + println("############################") + println("###### Skipping sleep ######") + println("############################") + } + } + } catch { + case e: Exception => println(e) + } + println("Stopped sending") + } + + def main(args: Array[String]) { + if (args.length < 4) { + printUsage + } + + val streamReceiverHost = args(0) + val streamReceiverPort = args(1).toInt + val sentenceFile = args(2) + val intervalTime = args(3).toLong + val sentencesPerInterval = if (args.length > 4) args(4).toInt else 0 + + println("Reading the file " + sentenceFile) + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n") + source.close() + + val numSentences = if (sentencesPerInterval <= 0) { + lines.length + } else { + sentencesPerInterval + } + + println("Generating sentences") + val sentences: Array[String] = if (numSentences <= lines.length) { + lines.take(numSentences).toArray + } else { + (0 until numSentences).map(i => lines(i % lines.length)).toArray + } + + println("Converting to byte array") + val byteStream = new ByteArrayOutputStream() + val stringDataStream = new DataOutputStream(byteStream) + /*stringDataStream.writeInt(sentences.size)*/ + sentences.foreach(stringDataStream.writeUTF) + val bytes = byteStream.toByteArray() + stringDataStream.close() + println("Generated array of " + bytes.length + " bytes") + + /*while(true) { */ + sendSentences(streamReceiverHost, streamReceiverPort, numSentences, bytes, intervalTime) + /*println("Sleeping for 5 seconds")*/ + /*Thread.sleep(5000)*/ + /*System.gc()*/ + /*}*/ + } +} + + + diff --git a/streaming/src/main/scala/spark/streaming/TestGenerator4.scala b/streaming/src/main/scala/spark/streaming/TestGenerator4.scala new file mode 100644 index 0000000000..93c7f2f440 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/TestGenerator4.scala @@ -0,0 +1,244 @@ +package spark.streaming + +import spark.Logging + +import scala.util.Random +import scala.io.Source +import scala.collection.mutable.{ArrayBuffer, Queue} + +import java.net._ +import java.io._ +import java.nio._ +import java.nio.charset._ +import java.nio.channels._ + +import it.unimi.dsi.fastutil.io._ + +class TestGenerator4(targetHost: String, targetPort: Int, sentenceFile: String, intervalDuration: Long, sentencesPerInterval: Int) +extends Logging { + + class SendingConnectionHandler(host: String, port: Int, generator: TestGenerator4) + extends ConnectionHandler(host, port, true) { + + val buffers = new ArrayBuffer[ByteBuffer] + val newBuffers = new Queue[ByteBuffer] + var activeKey: SelectionKey = null + + def send(buffer: ByteBuffer) { + logDebug("Sending: " + buffer) + newBuffers.synchronized { + newBuffers.enqueue(buffer) + } + selector.wakeup() + buffer.synchronized { + buffer.wait() + } + } + + override def ready(key: SelectionKey) { + logDebug("Ready") + activeKey = key + val channel = key.channel.asInstanceOf[SocketChannel] + channel.register(selector, SelectionKey.OP_WRITE) + generator.startSending() + } + + override def preSelect() { + newBuffers.synchronized { + while(!newBuffers.isEmpty) { + val buffer = newBuffers.dequeue + buffers += buffer + logDebug("Added: " + buffer) + changeInterest(activeKey, SelectionKey.OP_WRITE) + } + } + } + + override def write(key: SelectionKey) { + try { + /*while(true) {*/ + val channel = key.channel.asInstanceOf[SocketChannel] + if (buffers.size > 0) { + val buffer = buffers(0) + val newBuffer = buffer.slice() + newBuffer.limit(math.min(newBuffer.remaining, 32768)) + val bytesWritten = channel.write(newBuffer) + buffer.position(buffer.position + bytesWritten) + if (bytesWritten == 0) return + if (buffer.remaining == 0) { + buffers -= buffer + buffer.synchronized { + buffer.notify() + } + } + /*changeInterest(key, SelectionKey.OP_WRITE)*/ + } else { + changeInterest(key, 0) + } + /*}*/ + } catch { + case e: IOException => { + if (e.toString.contains("pipe") || e.toString.contains("reset")) { + logError("Connection broken") + } else { + logError("Connection error", e) + } + close(key) + } + } + } + + override def close(key: SelectionKey) { + buffers.clear() + super.close(key) + } + } + + initLogging() + + val connectionHandler = new SendingConnectionHandler(targetHost, targetPort, this) + var sendingThread: Thread = null + var sendCount = 0 + val sendBatches = 5 + + def run() { + logInfo("Connection handler started") + connectionHandler.start() + connectionHandler.join() + if (sendingThread != null && !sendingThread.isInterrupted) { + sendingThread.interrupt + } + logInfo("Connection handler stopped") + } + + def startSending() { + sendingThread = new Thread() { + override def run() { + logInfo("STARTING TO SEND") + sendSentences() + logInfo("SENDING STOPPED AFTER " + sendCount) + connectionHandler.interrupt() + } + } + sendingThread.start() + } + + def stopSending() { + sendingThread.interrupt() + } + + def sendSentences() { + logInfo("Reading the file " + sentenceFile) + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n") + source.close() + + val numSentences = if (sentencesPerInterval <= 0) { + lines.length + } else { + sentencesPerInterval + } + + logInfo("Generating sentence buffer") + val sentences: Array[String] = if (numSentences <= lines.length) { + lines.take(numSentences).toArray + } else { + (0 until numSentences).map(i => lines(i % lines.length)).toArray + } + + /* + val sentences: Array[String] = if (numSentences <= lines.length) { + lines.take((numSentences / sendBatches).toInt).toArray + } else { + (0 until (numSentences/sendBatches)).map(i => lines(i % lines.length)).toArray + }*/ + + + val serializer = new spark.KryoSerializer().newInstance() + val byteStream = new FastByteArrayOutputStream(100 * 1024 * 1024) + serializer.serializeStream(byteStream).writeAll(sentences.toIterator.asInstanceOf[Iterator[Any]]).close() + byteStream.trim() + val sentenceBuffer = ByteBuffer.wrap(byteStream.array) + + logInfo("Sending " + numSentences+ " sentences / " + sentenceBuffer.limit + " bytes per " + intervalDuration + " ms to " + targetHost + ":" + targetPort ) + val currentTime = System.currentTimeMillis + var targetTime = (currentTime / intervalDuration + 1).toLong * intervalDuration + Thread.sleep(targetTime - currentTime) + + val totalBytes = sentenceBuffer.limit + + while(true) { + val batchesInCurrentInterval = sendBatches // if (sendCount < 10) 1 else sendBatches + + val startTime = System.currentTimeMillis() + logDebug("Sending # " + sendCount + " at " + startTime + " ms with delay of " + (startTime - targetTime) + " ms") + + (0 until batchesInCurrentInterval).foreach(i => { + try { + val position = (i * totalBytes / sendBatches).toInt + val limit = if (i == sendBatches - 1) { + totalBytes + } else { + ((i + 1) * totalBytes / sendBatches).toInt - 1 + } + + val partStartTime = System.currentTimeMillis + sentenceBuffer.limit(limit) + connectionHandler.send(sentenceBuffer) + val partFinishTime = System.currentTimeMillis + val sleepTime = math.max(0, intervalDuration / sendBatches - (partFinishTime - partStartTime) - 1) + Thread.sleep(sleepTime) + + } catch { + case ie: InterruptedException => return + case e: Exception => e.printStackTrace() + } + }) + sentenceBuffer.rewind() + + val finishTime = System.currentTimeMillis() + /*logInfo ("Sent " + sentenceBuffer.limit + " bytes in " + (finishTime - startTime) + " ms")*/ + targetTime = targetTime + intervalDuration //+ (if (sendCount < 3) 1000 else 0) + + val sleepTime = (targetTime - finishTime) + 20 + if (sleepTime > 0) { + logInfo("Sleeping for " + sleepTime + " ms") + Thread.sleep(sleepTime) + } else { + logInfo("###### Skipping sleep ######") + } + if (Thread.currentThread.isInterrupted) { + return + } + sendCount += 1 + } + } +} + +object TestGenerator4 { + def printUsage { + println("Usage: TestGenerator4 []") + System.exit(0) + } + + def main(args: Array[String]) { + println("GENERATOR STARTED") + if (args.length < 4) { + printUsage + } + + + val streamReceiverHost = args(0) + val streamReceiverPort = args(1).toInt + val sentenceFile = args(2) + val intervalDuration = args(3).toLong + val sentencesPerInterval = if (args.length > 4) args(4).toInt else 0 + + while(true) { + val generator = new TestGenerator4(streamReceiverHost, streamReceiverPort, sentenceFile, intervalDuration, sentencesPerInterval) + generator.run() + Thread.sleep(2000) + } + println("GENERATOR STOPPED") + } +} diff --git a/streaming/src/main/scala/spark/streaming/TestInputBlockTracker.scala b/streaming/src/main/scala/spark/streaming/TestInputBlockTracker.scala new file mode 100644 index 0000000000..7e23b7bb82 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/TestInputBlockTracker.scala @@ -0,0 +1,42 @@ +package spark.streaming +import spark.Logging +import scala.collection.mutable.{ArrayBuffer, HashMap} + +object TestInputBlockTracker extends Logging { + initLogging() + val allBlockIds = new HashMap[Time, ArrayBuffer[String]]() + + def addBlocks(intervalEndTime: Time, reference: AnyRef) { + allBlockIds.getOrElseUpdate(intervalEndTime, new ArrayBuffer[String]()) ++= reference.asInstanceOf[Array[String]] + } + + def setEndTime(intervalEndTime: Time) { + try { + val endTime = System.currentTimeMillis + allBlockIds.get(intervalEndTime) match { + case Some(blockIds) => { + val numBlocks = blockIds.size + var totalDelay = 0d + blockIds.foreach(blockId => { + val inputTime = getInputTime(blockId) + val delay = (endTime - inputTime) / 1000.0 + totalDelay += delay + logInfo("End-to-end delay for block " + blockId + " is " + delay + " s") + }) + logInfo("Average end-to-end delay for time " + intervalEndTime + " is " + (totalDelay / numBlocks) + " s") + allBlockIds -= intervalEndTime + } + case None => throw new Exception("Unexpected") + } + } catch { + case e: Exception => logError(e.toString) + } + } + + def getInputTime(blockId: String): Long = { + val parts = blockId.split("-") + /*logInfo(blockId + " -> " + parts(4)) */ + parts(4).toLong + } +} + diff --git a/streaming/src/main/scala/spark/streaming/TestStreamCoordinator.scala b/streaming/src/main/scala/spark/streaming/TestStreamCoordinator.scala new file mode 100644 index 0000000000..c658a036f9 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/TestStreamCoordinator.scala @@ -0,0 +1,38 @@ +package spark.streaming + +import spark.Logging + +import akka.actor._ +import akka.actor.Actor +import akka.actor.Actor._ + +sealed trait TestStreamCoordinatorMessage +case class GetStreamDetails extends TestStreamCoordinatorMessage +case class GotStreamDetails(name: String, duration: Long) extends TestStreamCoordinatorMessage +case class TestStarted extends TestStreamCoordinatorMessage + +class TestStreamCoordinator(streamDetails: Array[(String, Long)]) extends Actor with Logging { + + var index = 0 + + initLogging() + + logInfo("Created") + + def receive = { + case TestStarted => { + sender ! "OK" + } + + case GetStreamDetails => { + val streamDetail = if (index >= streamDetails.length) null else streamDetails(index) + sender ! GotStreamDetails(streamDetail._1, streamDetail._2) + index += 1 + if (streamDetail != null) { + logInfo("Allocated " + streamDetail._1 + " (" + index + "/" + streamDetails.length + ")" ) + } + } + } + +} + diff --git a/streaming/src/main/scala/spark/streaming/TestStreamReceiver3.scala b/streaming/src/main/scala/spark/streaming/TestStreamReceiver3.scala new file mode 100644 index 0000000000..a7a5635aa5 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/TestStreamReceiver3.scala @@ -0,0 +1,420 @@ +package spark.streaming + +import spark._ +import spark.storage._ +import spark.util.AkkaUtils + +import scala.math._ +import scala.collection.mutable.{Queue, HashMap, ArrayBuffer, SynchronizedMap} + +import akka.actor._ +import akka.actor.Actor +import akka.dispatch._ +import akka.pattern.ask +import akka.util.duration._ + +import java.io.DataInputStream +import java.io.BufferedInputStream +import java.net.Socket +import java.net.ServerSocket +import java.util.LinkedHashMap + +import org.apache.hadoop.fs._ +import org.apache.hadoop.conf._ +import org.apache.hadoop.io._ +import org.apache.hadoop.mapred._ +import org.apache.hadoop.util._ + +import spark.Utils + + +class TestStreamReceiver3(actorSystem: ActorSystem, blockManager: BlockManager) +extends Thread with Logging { + + + class DataHandler( + inputName: String, + longIntervalDuration: LongTime, + shortIntervalDuration: LongTime, + blockManager: BlockManager + ) + extends Logging { + + class Block(var id: String, var shortInterval: Interval) { + val data = ArrayBuffer[String]() + var pushed = false + def longInterval = getLongInterval(shortInterval) + def empty() = (data.size == 0) + def += (str: String) = (data += str) + override def toString() = "Block " + id + } + + class Bucket(val longInterval: Interval) { + val blocks = new ArrayBuffer[Block]() + var filled = false + def += (block: Block) = blocks += block + def empty() = (blocks.size == 0) + def ready() = (filled && !blocks.exists(! _.pushed)) + def blockIds() = blocks.map(_.id).toArray + override def toString() = "Bucket [" + longInterval + ", " + blocks.size + " blocks]" + } + + initLogging() + + val shortIntervalDurationMillis = shortIntervalDuration.asInstanceOf[LongTime].milliseconds + val longIntervalDurationMillis = longIntervalDuration.asInstanceOf[LongTime].milliseconds + + var currentBlock: Block = null + var currentBucket: Bucket = null + + val blocksForPushing = new Queue[Block]() + val buckets = new HashMap[Interval, Bucket]() with SynchronizedMap[Interval, Bucket] + + val blockUpdatingThread = new Thread() { override def run() { keepUpdatingCurrentBlock() } } + val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } + + def start() { + blockUpdatingThread.start() + blockPushingThread.start() + } + + def += (data: String) = addData(data) + + def addData(data: String) { + if (currentBlock == null) { + updateCurrentBlock() + } + currentBlock.synchronized { + currentBlock += data + } + } + + def getShortInterval(time: Time): Interval = { + val intervalBegin = time.floor(shortIntervalDuration) + Interval(intervalBegin, intervalBegin + shortIntervalDuration) + } + + def getLongInterval(shortInterval: Interval): Interval = { + val intervalBegin = shortInterval.beginTime.floor(longIntervalDuration) + Interval(intervalBegin, intervalBegin + longIntervalDuration) + } + + def updateCurrentBlock() { + /*logInfo("Updating current block")*/ + val currentTime: LongTime = LongTime(System.currentTimeMillis) + val shortInterval = getShortInterval(currentTime) + val longInterval = getLongInterval(shortInterval) + + def createBlock(reuseCurrentBlock: Boolean = false) { + val newBlockId = inputName + "-" + longInterval.toFormattedString + "-" + currentBucket.blocks.size + if (!reuseCurrentBlock) { + val newBlock = new Block(newBlockId, shortInterval) + /*logInfo("Created " + currentBlock)*/ + currentBlock = newBlock + } else { + currentBlock.shortInterval = shortInterval + currentBlock.id = newBlockId + } + } + + def createBucket() { + val newBucket = new Bucket(longInterval) + buckets += ((longInterval, newBucket)) + currentBucket = newBucket + /*logInfo("Created " + currentBucket + ", " + buckets.size + " buckets")*/ + } + + if (currentBlock == null || currentBucket == null) { + createBucket() + currentBucket.synchronized { + createBlock() + } + return + } + + currentBlock.synchronized { + var reuseCurrentBlock = false + + if (shortInterval != currentBlock.shortInterval) { + if (!currentBlock.empty) { + blocksForPushing.synchronized { + blocksForPushing += currentBlock + blocksForPushing.notifyAll() + } + } + + currentBucket.synchronized { + if (currentBlock.empty) { + reuseCurrentBlock = true + } else { + currentBucket += currentBlock + } + + if (longInterval != currentBucket.longInterval) { + currentBucket.filled = true + if (currentBucket.ready) { + currentBucket.notifyAll() + } + createBucket() + } + } + + createBlock(reuseCurrentBlock) + } + } + } + + def pushBlock(block: Block) { + try{ + if (blockManager != null) { + logInfo("Pushing block") + val startTime = System.currentTimeMillis + + val bytes = blockManager.dataSerialize(block.data.toIterator) + val finishTime = System.currentTimeMillis + logInfo(block + " serialization delay is " + (finishTime - startTime) / 1000.0 + " s") + + blockManager.putBytes(block.id.toString, bytes, StorageLevel.DISK_AND_MEMORY_2) + /*blockManager.putBytes(block.id.toString, bytes, StorageLevel.DISK_AND_MEMORY_DESER_2)*/ + /*blockManager.put(block.id.toString, block.data.toIterator, StorageLevel.DISK_AND_MEMORY_DESER)*/ + /*blockManager.put(block.id.toString, block.data.toIterator, StorageLevel.DISK_AND_MEMORY)*/ + val finishTime1 = System.currentTimeMillis + logInfo(block + " put delay is " + (finishTime1 - startTime) / 1000.0 + " s") + } else { + logWarning(block + " not put as block manager is null") + } + } catch { + case e: Exception => logError("Exception writing " + block + " to blockmanager" , e) + } + } + + def getBucket(longInterval: Interval): Option[Bucket] = { + buckets.get(longInterval) + } + + def clearBucket(longInterval: Interval) { + buckets.remove(longInterval) + } + + def keepUpdatingCurrentBlock() { + logInfo("Thread to update current block started") + while(true) { + updateCurrentBlock() + val currentTimeMillis = System.currentTimeMillis + val sleepTimeMillis = (currentTimeMillis / shortIntervalDurationMillis + 1) * + shortIntervalDurationMillis - currentTimeMillis + 1 + Thread.sleep(sleepTimeMillis) + } + } + + def keepPushingBlocks() { + var loop = true + logInfo("Thread to push blocks started") + while(loop) { + val block = blocksForPushing.synchronized { + if (blocksForPushing.size == 0) { + blocksForPushing.wait() + } + blocksForPushing.dequeue + } + pushBlock(block) + block.pushed = true + block.data.clear() + + val bucket = buckets(block.longInterval) + bucket.synchronized { + if (bucket.ready) { + bucket.notifyAll() + } + } + } + } + } + + + class ConnectionListener(port: Int, dataHandler: DataHandler) + extends Thread with Logging { + initLogging() + override def run { + try { + val listener = new ServerSocket(port) + logInfo("Listening on port " + port) + while (true) { + new ConnectionHandler(listener.accept(), dataHandler).start(); + } + listener.close() + } catch { + case e: Exception => logError("", e); + } + } + } + + class ConnectionHandler(socket: Socket, dataHandler: DataHandler) extends Thread with Logging { + initLogging() + override def run { + logInfo("New connection from " + socket.getInetAddress() + ":" + socket.getPort) + val bytes = new Array[Byte](100 * 1024 * 1024) + try { + + val inputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream, 1024 * 1024)) + /*val inputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream))*/ + var str: String = null + str = inputStream.readUTF + while(str != null) { + dataHandler += str + str = inputStream.readUTF() + } + + /* + var loop = true + while(loop) { + val numRead = inputStream.read(bytes) + if (numRead < 0) { + loop = false + } + inbox += ((LongTime(SystemTime.currentTimeMillis), "test")) + }*/ + + inputStream.close() + } catch { + case e => logError("Error receiving data", e) + } + socket.close() + } + } + + initLogging() + + val masterHost = System.getProperty("spark.master.host") + val masterPort = System.getProperty("spark.master.port").toInt + + val akkaPath = "akka://spark@%s:%s/user/".format(masterHost, masterPort) + val sparkstreamScheduler = actorSystem.actorFor(akkaPath + "/SparkStreamScheduler") + val testStreamCoordinator = actorSystem.actorFor(akkaPath + "/TestStreamCoordinator") + + logInfo("Getting stream details from master " + masterHost + ":" + masterPort) + + val timeout = 50 millis + + var started = false + while (!started) { + askActor[String](testStreamCoordinator, TestStarted) match { + case Some(str) => { + started = true + logInfo("TestStreamCoordinator started") + } + case None => { + logInfo("TestStreamCoordinator not started yet") + Thread.sleep(200) + } + } + } + + val streamDetails = askActor[GotStreamDetails](testStreamCoordinator, GetStreamDetails) match { + case Some(details) => details + case None => throw new Exception("Could not get stream details") + } + logInfo("Stream details received: " + streamDetails) + + val inputName = streamDetails.name + val intervalDurationMillis = streamDetails.duration + val intervalDuration = LongTime(intervalDurationMillis) + + val dataHandler = new DataHandler( + inputName, + intervalDuration, + LongTime(TestStreamReceiver3.SHORT_INTERVAL_MILLIS), + blockManager) + + val connListener = new ConnectionListener(TestStreamReceiver3.PORT, dataHandler) + + // Send a message to an actor and return an option with its reply, or None if this times out + def askActor[T](actor: ActorRef, message: Any): Option[T] = { + try { + val future = actor.ask(message)(timeout) + return Some(Await.result(future, timeout).asInstanceOf[T]) + } catch { + case e: Exception => + logInfo("Error communicating with " + actor, e) + return None + } + } + + override def run() { + connListener.start() + dataHandler.start() + + var interval = Interval.currentInterval(intervalDuration) + var dataStarted = false + + while(true) { + waitFor(interval.endTime) + logInfo("Woken up at " + System.currentTimeMillis + " for " + interval) + dataHandler.getBucket(interval) match { + case Some(bucket) => { + logInfo("Found " + bucket + " for " + interval) + bucket.synchronized { + if (!bucket.ready) { + logInfo("Waiting for " + bucket) + bucket.wait() + logInfo("Wait over for " + bucket) + } + if (dataStarted || !bucket.empty) { + logInfo("Notifying " + bucket) + notifyScheduler(interval, bucket.blockIds) + dataStarted = true + } + bucket.blocks.clear() + dataHandler.clearBucket(interval) + } + } + case None => { + logInfo("Found none for " + interval) + if (dataStarted) { + logInfo("Notifying none") + notifyScheduler(interval, Array[String]()) + } + } + } + interval = interval.next + } + } + + def waitFor(time: Time) { + val currentTimeMillis = System.currentTimeMillis + val targetTimeMillis = time.asInstanceOf[LongTime].milliseconds + if (currentTimeMillis < targetTimeMillis) { + val sleepTime = (targetTimeMillis - currentTimeMillis) + Thread.sleep(sleepTime + 1) + } + } + + def notifyScheduler(interval: Interval, blockIds: Array[String]) { + try { + sparkstreamScheduler ! InputGenerated(inputName, interval, blockIds.toArray) + val time = interval.endTime.asInstanceOf[LongTime] + val delay = (System.currentTimeMillis - time.milliseconds) / 1000.0 + logInfo("Pushing delay for " + time + " is " + delay + " s") + } catch { + case _ => logError("Exception notifying scheduler at interval " + interval) + } + } +} + +object TestStreamReceiver3 { + + val PORT = 9999 + val SHORT_INTERVAL_MILLIS = 100 + + def main(args: Array[String]) { + System.setProperty("spark.master.host", Utils.localHostName) + System.setProperty("spark.master.port", "7078") + val details = Array(("Sentences", 2000L)) + val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localHostName, 7078) + actorSystem.actorOf(Props(new TestStreamCoordinator(details)), name = "TestStreamCoordinator") + new TestStreamReceiver3(actorSystem, null).start() + } +} + + + diff --git a/streaming/src/main/scala/spark/streaming/TestStreamReceiver4.scala b/streaming/src/main/scala/spark/streaming/TestStreamReceiver4.scala new file mode 100644 index 0000000000..2c3f5d1b9d --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/TestStreamReceiver4.scala @@ -0,0 +1,373 @@ +package spark.streaming + +import spark._ +import spark.storage._ +import spark.util.AkkaUtils + +import scala.math._ +import scala.collection.mutable.{Queue, HashMap, ArrayBuffer, SynchronizedMap} + +import java.io._ +import java.nio._ +import java.nio.charset._ +import java.nio.channels._ +import java.util.concurrent.Executors + +import akka.actor._ +import akka.actor.Actor +import akka.dispatch._ +import akka.pattern.ask +import akka.util.duration._ + +class TestStreamReceiver4(actorSystem: ActorSystem, blockManager: BlockManager) +extends Thread with Logging { + + class DataHandler( + inputName: String, + longIntervalDuration: LongTime, + shortIntervalDuration: LongTime, + blockManager: BlockManager + ) + extends Logging { + + class Block(val id: String, val shortInterval: Interval, val buffer: ByteBuffer) { + var pushed = false + def longInterval = getLongInterval(shortInterval) + override def toString() = "Block " + id + } + + class Bucket(val longInterval: Interval) { + val blocks = new ArrayBuffer[Block]() + var filled = false + def += (block: Block) = blocks += block + def empty() = (blocks.size == 0) + def ready() = (filled && !blocks.exists(! _.pushed)) + def blockIds() = blocks.map(_.id).toArray + override def toString() = "Bucket [" + longInterval + ", " + blocks.size + " blocks]" + } + + initLogging() + + val syncOnLastShortInterval = true + + val shortIntervalDurationMillis = shortIntervalDuration.asInstanceOf[LongTime].milliseconds + val longIntervalDurationMillis = longIntervalDuration.asInstanceOf[LongTime].milliseconds + + val buffer = ByteBuffer.allocateDirect(100 * 1024 * 1024) + var currentShortInterval = Interval.currentInterval(shortIntervalDuration) + + val blocksForPushing = new Queue[Block]() + val buckets = new HashMap[Interval, Bucket]() with SynchronizedMap[Interval, Bucket] + + val bufferProcessingThread = new Thread() { override def run() { keepProcessingBuffers() } } + val blockPushingExecutor = Executors.newFixedThreadPool(5) + + + def start() { + buffer.clear() + if (buffer.remaining == 0) { + throw new Exception("Buffer initialization error") + } + bufferProcessingThread.start() + } + + def readDataToBuffer(func: ByteBuffer => Int): Int = { + buffer.synchronized { + if (buffer.remaining == 0) { + logInfo("Received first data for interval " + currentShortInterval) + } + func(buffer) + } + } + + def getLongInterval(shortInterval: Interval): Interval = { + val intervalBegin = shortInterval.beginTime.floor(longIntervalDuration) + Interval(intervalBegin, intervalBegin + longIntervalDuration) + } + + def processBuffer() { + + def readInt(buffer: ByteBuffer): Int = { + var offset = 0 + var result = 0 + while (offset < 32) { + val b = buffer.get() + result |= ((b & 0x7F) << offset) + if ((b & 0x80) == 0) { + return result + } + offset += 7 + } + throw new Exception("Malformed zigzag-encoded integer") + } + + val currentLongInterval = getLongInterval(currentShortInterval) + val startTime = System.currentTimeMillis + val newBuffer: ByteBuffer = buffer.synchronized { + buffer.flip() + if (buffer.remaining == 0) { + buffer.clear() + null + } else { + logDebug("Processing interval " + currentShortInterval + " with delay of " + (System.currentTimeMillis - startTime) + " ms") + val startTime1 = System.currentTimeMillis + var loop = true + var count = 0 + while(loop) { + buffer.mark() + try { + val len = readInt(buffer) + buffer.position(buffer.position + len) + count += 1 + } catch { + case e: Exception => { + buffer.reset() + loop = false + } + } + } + val bytesToCopy = buffer.position + val newBuf = ByteBuffer.allocate(bytesToCopy) + buffer.position(0) + newBuf.put(buffer.slice().limit(bytesToCopy).asInstanceOf[ByteBuffer]) + newBuf.flip() + buffer.position(bytesToCopy) + buffer.compact() + newBuf + } + } + + if (newBuffer != null) { + val bucket = buckets.getOrElseUpdate(currentLongInterval, new Bucket(currentLongInterval)) + bucket.synchronized { + val newBlockId = inputName + "-" + currentLongInterval.toFormattedString + "-" + currentShortInterval.toFormattedString + val newBlock = new Block(newBlockId, currentShortInterval, newBuffer) + if (syncOnLastShortInterval) { + bucket += newBlock + } + logDebug("Created " + newBlock + " with " + newBuffer.remaining + " bytes, creation delay is " + (System.currentTimeMillis - currentShortInterval.endTime.asInstanceOf[LongTime].milliseconds) / 1000.0 + " s" ) + blockPushingExecutor.execute(new Runnable() { def run() { pushAndNotifyBlock(newBlock) } }) + } + } + + val newShortInterval = Interval.currentInterval(shortIntervalDuration) + val newLongInterval = getLongInterval(newShortInterval) + + if (newLongInterval != currentLongInterval) { + buckets.get(currentLongInterval) match { + case Some(bucket) => { + bucket.synchronized { + bucket.filled = true + if (bucket.ready) { + bucket.notifyAll() + } + } + } + case None => + } + buckets += ((newLongInterval, new Bucket(newLongInterval))) + } + + currentShortInterval = newShortInterval + } + + def pushBlock(block: Block) { + try{ + if (blockManager != null) { + val startTime = System.currentTimeMillis + logInfo(block + " put start delay is " + (startTime - block.shortInterval.endTime.asInstanceOf[LongTime].milliseconds) + " ms") + /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY)*/ + /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_2)*/ + blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.MEMORY_ONLY_2) + /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.MEMORY_ONLY)*/ + /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_DESER)*/ + /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_DESER_2)*/ + val finishTime = System.currentTimeMillis + logInfo(block + " put delay is " + (finishTime - startTime) + " ms") + } else { + logWarning(block + " not put as block manager is null") + } + } catch { + case e: Exception => logError("Exception writing " + block + " to blockmanager" , e) + } + } + + def getBucket(longInterval: Interval): Option[Bucket] = { + buckets.get(longInterval) + } + + def clearBucket(longInterval: Interval) { + buckets.remove(longInterval) + } + + def keepProcessingBuffers() { + logInfo("Thread to process buffers started") + while(true) { + processBuffer() + val currentTimeMillis = System.currentTimeMillis + val sleepTimeMillis = (currentTimeMillis / shortIntervalDurationMillis + 1) * + shortIntervalDurationMillis - currentTimeMillis + 1 + Thread.sleep(sleepTimeMillis) + } + } + + def pushAndNotifyBlock(block: Block) { + pushBlock(block) + block.pushed = true + val bucket = if (syncOnLastShortInterval) { + buckets(block.longInterval) + } else { + var longInterval = block.longInterval + while(!buckets.contains(longInterval)) { + logWarning("Skipping bucket of " + longInterval + " for " + block) + longInterval = longInterval.next + } + val chosenBucket = buckets(longInterval) + logDebug("Choosing bucket of " + longInterval + " for " + block) + chosenBucket += block + chosenBucket + } + + bucket.synchronized { + if (bucket.ready) { + bucket.notifyAll() + } + } + + } + } + + + class ReceivingConnectionHandler(host: String, port: Int, dataHandler: DataHandler) + extends ConnectionHandler(host, port, false) { + + override def ready(key: SelectionKey) { + changeInterest(key, SelectionKey.OP_READ) + } + + override def read(key: SelectionKey) { + try { + val channel = key.channel.asInstanceOf[SocketChannel] + val bytesRead = dataHandler.readDataToBuffer(channel.read) + if (bytesRead < 0) { + close(key) + } + } catch { + case e: IOException => { + logError("Error reading", e) + close(key) + } + } + } + } + + initLogging() + + val masterHost = System.getProperty("spark.master.host", "localhost") + val masterPort = System.getProperty("spark.master.port", "7078").toInt + + val akkaPath = "akka://spark@%s:%s/user/".format(masterHost, masterPort) + val sparkstreamScheduler = actorSystem.actorFor(akkaPath + "/SparkStreamScheduler") + val testStreamCoordinator = actorSystem.actorFor(akkaPath + "/TestStreamCoordinator") + + logInfo("Getting stream details from master " + masterHost + ":" + masterPort) + + val streamDetails = askActor[GotStreamDetails](testStreamCoordinator, GetStreamDetails) match { + case Some(details) => details + case None => throw new Exception("Could not get stream details") + } + logInfo("Stream details received: " + streamDetails) + + val inputName = streamDetails.name + val intervalDurationMillis = streamDetails.duration + val intervalDuration = Milliseconds(intervalDurationMillis) + val shortIntervalDuration = Milliseconds(System.getProperty("spark.stream.shortinterval", "500").toInt) + + val dataHandler = new DataHandler(inputName, intervalDuration, shortIntervalDuration, blockManager) + val connectionHandler = new ReceivingConnectionHandler("localhost", 9999, dataHandler) + + val timeout = 100 millis + + // Send a message to an actor and return an option with its reply, or None if this times out + def askActor[T](actor: ActorRef, message: Any): Option[T] = { + try { + val future = actor.ask(message)(timeout) + return Some(Await.result(future, timeout).asInstanceOf[T]) + } catch { + case e: Exception => + logInfo("Error communicating with " + actor, e) + return None + } + } + + override def run() { + connectionHandler.start() + dataHandler.start() + + var interval = Interval.currentInterval(intervalDuration) + var dataStarted = false + + + while(true) { + waitFor(interval.endTime) + /*logInfo("Woken up at " + System.currentTimeMillis + " for " + interval)*/ + dataHandler.getBucket(interval) match { + case Some(bucket) => { + logDebug("Found " + bucket + " for " + interval) + bucket.synchronized { + if (!bucket.ready) { + logDebug("Waiting for " + bucket) + bucket.wait() + logDebug("Wait over for " + bucket) + } + if (dataStarted || !bucket.empty) { + logDebug("Notifying " + bucket) + notifyScheduler(interval, bucket.blockIds) + dataStarted = true + } + bucket.blocks.clear() + dataHandler.clearBucket(interval) + } + } + case None => { + logDebug("Found none for " + interval) + if (dataStarted) { + logDebug("Notifying none") + notifyScheduler(interval, Array[String]()) + } + } + } + interval = interval.next + } + } + + def waitFor(time: Time) { + val currentTimeMillis = System.currentTimeMillis + val targetTimeMillis = time.asInstanceOf[LongTime].milliseconds + if (currentTimeMillis < targetTimeMillis) { + val sleepTime = (targetTimeMillis - currentTimeMillis) + Thread.sleep(sleepTime + 1) + } + } + + def notifyScheduler(interval: Interval, blockIds: Array[String]) { + try { + sparkstreamScheduler ! InputGenerated(inputName, interval, blockIds.toArray) + val time = interval.endTime.asInstanceOf[LongTime] + val delay = (System.currentTimeMillis - time.milliseconds) + logInfo("Notification delay for " + time + " is " + delay + " ms") + } catch { + case e: Exception => logError("Exception notifying scheduler at interval " + interval + ": " + e) + } + } +} + + +object TestStreamReceiver4 { + def main(args: Array[String]) { + val details = Array(("Sentences", 2000L)) + val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localHostName, 7078) + actorSystem.actorOf(Props(new TestStreamCoordinator(details)), name = "TestStreamCoordinator") + new TestStreamReceiver4(actorSystem, null).start() + } +} diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala new file mode 100644 index 0000000000..b932fe9258 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -0,0 +1,85 @@ +package spark.streaming + +abstract case class Time { + + // basic operations that must be overridden + def copy(): Time + def zero: Time + def < (that: Time): Boolean + def += (that: Time): Time + def -= (that: Time): Time + def floor(that: Time): Time + def isMultipleOf(that: Time): Boolean + + // derived operations composed of basic operations + def + (that: Time) = this.copy() += that + def - (that: Time) = this.copy() -= that + def * (times: Int) = { + var count = 0 + var result = this.copy() + while (count < times) { + result += this + count += 1 + } + result + } + def <= (that: Time) = (this < that || this == that) + def > (that: Time) = !(this <= that) + def >= (that: Time) = !(this < that) + def isZero = (this == zero) + def toFormattedString = toString +} + +object Time { + def Milliseconds(milliseconds: Long) = LongTime(milliseconds) + + def zero = LongTime(0) +} + +case class LongTime(var milliseconds: Long) extends Time { + + override def copy() = LongTime(this.milliseconds) + + override def zero = LongTime(0) + + override def < (that: Time): Boolean = + (this.milliseconds < that.asInstanceOf[LongTime].milliseconds) + + override def += (that: Time): Time = { + this.milliseconds += that.asInstanceOf[LongTime].milliseconds + this + } + + override def -= (that: Time): Time = { + this.milliseconds -= that.asInstanceOf[LongTime].milliseconds + this + } + + override def floor(that: Time): Time = { + val t = that.asInstanceOf[LongTime].milliseconds + val m = this.milliseconds / t + LongTime(m.toLong * t) + } + + override def isMultipleOf(that: Time): Boolean = + (this.milliseconds % that.asInstanceOf[LongTime].milliseconds == 0) + + override def isZero = (this.milliseconds == 0) + + override def toString = (milliseconds.toString + "ms") + + override def toFormattedString = milliseconds.toString +} + +object Milliseconds { + def apply(milliseconds: Long) = LongTime(milliseconds) +} + +object Seconds { + def apply(seconds: Long) = LongTime(seconds * 1000) +} + +object Minutes { + def apply(minutes: Long) = LongTime(minutes * 60000) +} + diff --git a/streaming/src/main/scala/spark/streaming/TopContentCount.scala b/streaming/src/main/scala/spark/streaming/TopContentCount.scala new file mode 100644 index 0000000000..031e989c87 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/TopContentCount.scala @@ -0,0 +1,97 @@ +package spark.streaming + +import SparkStreamContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting + +object TopContentCount { + + case class Event(val country: String, val content: String) + + object Event { + def create(string: String): Event = { + val parts = string.split(":") + new Event(parts(0), parts(1)) + } + } + + def main(args: Array[String]) { + + if (args.length < 2) { + println ("Usage: GrepCount2 <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "TopContentCount") + val sc = ssc.sc + val dummy = sc.parallelize(0 to 1000, 100).persist(StorageLevel.DISK_AND_MEMORY) + sc.runJob(dummy, (_: Iterator[Int]) => {}) + + + val numEventStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + val eventStrings = new UnifiedRDS( + (1 to numEventStreams).map(i => ssc.readTestStream("Events-" + i, 1000)).toArray + ) + + def parse(string: String) = { + val parts = string.split(":") + (parts(0), parts(1)) + } + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + + val events = eventStrings.map(x => parse(x)) + /*events.print*/ + + val parallelism = 8 + val counts_per_content_per_country = events + .map(x => (x, 1)) + .reduceByKey(_ + _) + /*.reduceByKeyAndWindow(add _, subtract _, Seconds(5), Seconds(1), parallelism)*/ + /*counts_per_content_per_country.print*/ + + /* + counts_per_content_per_country.persist( + StorageLevel.MEMORY_ONLY_DESER, + StorageLevel.MEMORY_ONLY_DESER_2, + Seconds(1) + )*/ + + val counts_per_country = counts_per_content_per_country + .map(x => (x._1._1, (x._1._2, x._2))) + .groupByKey() + counts_per_country.print + + + def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { + implicit val countOrdering = new Ordering[(String, Int)] { + override def compare(count1: (String, Int), count2: (String, Int)): Int = { + count2._2 - count1._2 + } + } + val array = data.toArray + Sorting.quickSort(array) + val taken = array.take(k) + taken + } + + val k = 10 + val topKContents_per_country = counts_per_country + .map(x => (x._1, topK(x._2, k))) + .map(x => (x._1, x._2.map(_.toString).reduceLeft(_ + ", " + _))) + + topKContents_per_country.print + + ssc.run + } +} + + + diff --git a/streaming/src/main/scala/spark/streaming/TopKWordCount2.scala b/streaming/src/main/scala/spark/streaming/TopKWordCount2.scala new file mode 100644 index 0000000000..679ed0a7ef --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/TopKWordCount2.scala @@ -0,0 +1,103 @@ +package spark.streaming + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting + +object TopKWordCount2 { + + def moreWarmup(sc: SparkContext) { + (0 until 20).foreach {i => + sc.parallelize(1 to 20000000, 500) + .map(_ % 100).map(_.toString) + .map(x => (x, 1)).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SparkStreamContext <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + GrepCount2.warmConnectionManagers(ssc.sc) + moreWarmup(ssc.sc) + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray + ) + + val words = sentences.flatMap(_.split(" ")) + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + + val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 10) + windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) + + def topK(data: Iterator[(String, Int)], k: Int): Iterator[(String, Int)] = { + val taken = new Array[(String, Int)](k) + + var i = 0 + var len = 0 + var done = false + var value: (String, Int) = null + var swap: (String, Int) = null + var count = 0 + + while(data.hasNext) { + value = data.next + count += 1 + println("count = " + count) + if (len == 0) { + taken(0) = value + len = 1 + } else if (len < k || value._2 > taken(len - 1)._2) { + if (len < k) { + len += 1 + } + taken(len - 1) = value + i = len - 1 + while(i > 0 && taken(i - 1)._2 < taken(i)._2) { + swap = taken(i) + taken(i) = taken(i-1) + taken(i - 1) = swap + i -= 1 + } + } + } + println("Took " + len + " out of " + count + " items") + return taken.toIterator + } + + val k = 10 + val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) + partialTopKWindowedCounts.foreachRDD(rdd => { + val collectedCounts = rdd.collect + println("Collected " + collectedCounts.size + " items") + topK(collectedCounts.toIterator, k).foreach(println) + }) + + /* + windowedCounts.filter(_ == null).foreachRDD(rdd => { + val count = rdd.count + println("# of nulls = " + count) + })*/ + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/streaming/TopKWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/TopKWordCount2_Special.scala new file mode 100644 index 0000000000..c873fbd0f0 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/TopKWordCount2_Special.scala @@ -0,0 +1,142 @@ +package spark.streaming + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting +import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.mutable.Queue + +import java.lang.{Long => JLong} + +object TopKWordCount2_Special { + + def moreWarmup(sc: SparkContext) { + (0 until 20).foreach {i => + sc.parallelize(1 to 20000000, 500) + .map(_ % 100).map(_.toString) + .map(x => (x, 1)).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SparkStreamContext <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "TopKWordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + GrepCount2.warmConnectionManagers(ssc.sc) + /*moreWarmup(ssc.sc)*/ + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 500)).toArray + ) + + /*val words = sentences.flatMap(_.split(" "))*/ + + /*def add(v1: Int, v2: Int) = (v1 + v2) */ + /*def subtract(v1: Int, v2: Int) = (v1 - v2) */ + + def add(v1: JLong, v2: JLong) = (v1 + v2) + def subtract(v1: JLong, v2: JLong) = (v1 - v2) + + def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { + val map = new java.util.HashMap[String, JLong] + var i = 0 + var j = 0 + while (iter.hasNext) { + val s = iter.next() + i = 0 + while (i < s.length) { + j = i + while (j < s.length && s.charAt(j) != ' ') { + j += 1 + } + if (j > i) { + val w = s.substring(i, j) + val c = map.get(w) + if (c == null) { + map.put(w, 1) + } else { + map.put(w, c + 1) + } + } + i = j + while (i < s.length && s.charAt(i) == ' ') { + i += 1 + } + } + } + map.toIterator + } + + + val windowedCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Milliseconds(500), 10) + /*windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1))*/ + windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY, Milliseconds(500)) + + def topK(data: Iterator[(String, JLong)], k: Int): Iterator[(String, JLong)] = { + val taken = new Array[(String, JLong)](k) + + var i = 0 + var len = 0 + var done = false + var value: (String, JLong) = null + var swap: (String, JLong) = null + var count = 0 + + while(data.hasNext) { + value = data.next + count += 1 + println("count = " + count) + if (len == 0) { + taken(0) = value + len = 1 + } else if (len < k || value._2 > taken(len - 1)._2) { + if (len < k) { + len += 1 + } + taken(len - 1) = value + i = len - 1 + while(i > 0 && taken(i - 1)._2 < taken(i)._2) { + swap = taken(i) + taken(i) = taken(i-1) + taken(i - 1) = swap + i -= 1 + } + } + } + println("Took " + len + " out of " + count + " items") + return taken.toIterator + } + + val k = 50 + val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) + partialTopKWindowedCounts.foreachRDD(rdd => { + val collectedCounts = rdd.collect + println("Collected " + collectedCounts.size + " items") + topK(collectedCounts.toIterator, k).foreach(println) + }) + + /* + windowedCounts.filter(_ == null).foreachRDD(rdd => { + val count = rdd.count + println("# of nulls = " + count) + })*/ + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/streaming/WindowedRDS.scala b/streaming/src/main/scala/spark/streaming/WindowedRDS.scala new file mode 100644 index 0000000000..812a982301 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/WindowedRDS.scala @@ -0,0 +1,68 @@ +package spark.streaming + +import spark.streaming.SparkStreamContext._ + +import spark.RDD +import spark.UnionRDD +import spark.SparkContext._ + +import scala.collection.mutable.ArrayBuffer + +class WindowedRDS[T: ClassManifest]( + parent: RDS[T], + _windowTime: Time, + _slideTime: Time) + extends RDS[T](parent.ssc) { + + if (!_windowTime.isMultipleOf(parent.slideTime)) + throw new Exception("The window duration of WindowedRDS (" + _slideTime + ") " + + "must be multiple of the slide duration of parent RDS (" + parent.slideTime + ")") + + if (!_slideTime.isMultipleOf(parent.slideTime)) + throw new Exception("The slide duration of WindowedRDS (" + _slideTime + ") " + + "must be multiple of the slide duration of parent RDS (" + parent.slideTime + ")") + + val allowPartialWindows = true + + override def dependencies = List(parent) + + def windowTime: Time = _windowTime + + override def slideTime: Time = _slideTime + + override def compute(validTime: Time): Option[RDD[T]] = { + val parentRDDs = new ArrayBuffer[RDD[T]]() + val windowEndTime = validTime.copy() + val windowStartTime = if (allowPartialWindows && windowEndTime - windowTime < parent.zeroTime) { + parent.zeroTime + } else { + windowEndTime - windowTime + } + + logInfo("Window = " + windowStartTime + " - " + windowEndTime) + logInfo("Parent.zeroTime = " + parent.zeroTime) + + if (windowStartTime >= parent.zeroTime) { + // Walk back through time, from the 'windowEndTime' to 'windowStartTime' + // and get all parent RDDs from the parent RDS + var t = windowEndTime + while (t > windowStartTime) { + parent.getOrCompute(t) match { + case Some(rdd) => parentRDDs += rdd + case None => throw new Exception("Could not generate parent RDD for time " + t) + } + t -= parent.slideTime + } + } + + // Do a union of all parent RDDs to generate the new RDD + if (parentRDDs.size > 0) { + Some(new UnionRDD(ssc.sc, parentRDDs)) + } else { + None + } + } +} + + + diff --git a/streaming/src/main/scala/spark/streaming/WordCount.scala b/streaming/src/main/scala/spark/streaming/WordCount.scala new file mode 100644 index 0000000000..fb5508ffcc --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/WordCount.scala @@ -0,0 +1,62 @@ +package spark.streaming + +import SparkStreamContext._ + +import scala.util.Sorting + +import spark.SparkContext +import spark.storage.StorageLevel + +object WordCount { + var inputFile : String = null + var HDFS : String = null + var idealPartitions : Int = 0 + + def main (args: Array[String]) { + + if (args.length != 4) { + println ("Usage: WordCount ") + System.exit(1) + } + + HDFS = args(1) + inputFile = HDFS + args(2) + idealPartitions = args(3).toInt + println ("Input file: " + inputFile) + + val ssc = new SparkStreamContext(args(0), "WordCountWindow") + + SparkContext.idealPartitions = idealPartitions + SparkContext.inputFile = inputFile + + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) + //sentences.print + + val words = sentences.flatMap(_.split(" ")) + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + + //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(2000), + // System.getProperty("spark.default.parallelism", "1").toInt) + //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.DISK_AND_MEMORY_DESER_2, Seconds(5)) + //windowedCounts.print + + val parallelism = System.getProperty("spark.default.parallelism", "1").toInt + + //val localCounts = words.map(x => (x, 1)).reduceByKey(add _, parallelism) + //localCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(6)) + //val windowedCounts = localCounts.window(Seconds(30), Seconds(2)).reduceByKey(_ + _) + + val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(2), + parallelism) + windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(10)) + + //windowedCounts.print + windowedCounts.register + //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => print(x+ " "))) + //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => x)) + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/streaming/WordCount1.scala b/streaming/src/main/scala/spark/streaming/WordCount1.scala new file mode 100644 index 0000000000..42d985920a --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/WordCount1.scala @@ -0,0 +1,46 @@ +package spark.streaming + +import SparkStreamContext._ + +import scala.util.Sorting + +import spark.SparkContext +import spark.storage.StorageLevel + +object WordCount1 { + var inputFile : String = null + var HDFS : String = null + var idealPartitions : Int = 0 + + def main (args: Array[String]) { + + if (args.length != 4) { + println ("Usage: WordCount ") + System.exit(1) + } + + HDFS = args(1) + inputFile = HDFS + args(2) + idealPartitions = args(3).toInt + println ("Input file: " + inputFile) + + val ssc = new SparkStreamContext(args(0), "WordCountWindow") + + SparkContext.idealPartitions = idealPartitions + SparkContext.inputFile = inputFile + + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) + //sentences.print + + val words = sentences.flatMap(_.split(" ")) + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + + val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 10) + windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) + windowedCounts.foreachRDD(_.collect) + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/streaming/WordCount2.scala b/streaming/src/main/scala/spark/streaming/WordCount2.scala new file mode 100644 index 0000000000..9168a2fe2f --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/WordCount2.scala @@ -0,0 +1,55 @@ +package spark.streaming + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting + +object WordCount2 { + + def moreWarmup(sc: SparkContext) { + (0 until 20).foreach {i => + sc.parallelize(1 to 20000000, 500) + .map(_ % 100).map(_.toString) + .map(x => (x, 1)).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SparkStreamContext <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + GrepCount2.warmConnectionManagers(ssc.sc) + /*moreWarmup(ssc.sc)*/ + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray + ) + + val words = sentences.flatMap(_.split(" ")) + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + + val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 6) + windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) + windowedCounts.foreachRDD(_.collect) + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/streaming/WordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/WordCount2_Special.scala new file mode 100644 index 0000000000..1920915af7 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/WordCount2_Special.scala @@ -0,0 +1,94 @@ +package spark.streaming + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.Queue +import scala.collection.JavaConversions.mapAsScalaMap + +import java.lang.{Long => JLong} +import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} + + +object WordCount2_ExtraFunctions { + + def add(v1: JLong, v2: JLong) = (v1 + v2) + + def subtract(v1: JLong, v2: JLong) = (v1 - v2) + + def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { + val map = new java.util.HashMap[String, JLong] + var i = 0 + var j = 0 + while (iter.hasNext) { + val s = iter.next() + i = 0 + while (i < s.length) { + j = i + while (j < s.length && s.charAt(j) != ' ') { + j += 1 + } + if (j > i) { + val w = s.substring(i, j) + val c = map.get(w) + if (c == null) { + map.put(w, 1) + } else { + map.put(w, c + 1) + } + } + i = j + while (i < s.length && s.charAt(i) == ' ') { + i += 1 + } + } + } + map.toIterator + } +} + +object WordCount2_Special { + + def moreWarmup(sc: SparkContext) { + (0 until 40).foreach {i => + sc.parallelize(1 to 20000000, 1000) + .map(_ % 1331).map(_.toString) + .mapPartitions(WordCount2_ExtraFunctions.splitAndCountPartitions).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SparkStreamContext <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + + GrepCount2.warmConnectionManagers(ssc.sc) + /*moreWarmup(ssc.sc)*/ + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 500)).toArray + ) + + val windowedCounts = sentences + .mapPartitions(WordCount2_ExtraFunctions.splitAndCountPartitions) + .reduceByKeyAndWindow(WordCount2_ExtraFunctions.add _, WordCount2_ExtraFunctions.subtract _, Seconds(10), Milliseconds(500), 10) + windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY, Milliseconds(500)) + windowedCounts.foreachRDD(_.collect) + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/streaming/WordCount3.scala b/streaming/src/main/scala/spark/streaming/WordCount3.scala new file mode 100644 index 0000000000..018c19a509 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/WordCount3.scala @@ -0,0 +1,49 @@ +package spark.streaming + +import SparkStreamContext._ + +import scala.util.Sorting + +object WordCount3 { + + def main (args: Array[String]) { + + if (args.length < 1) { + println ("Usage: SparkStreamContext []") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCount") + if (args.length > 1) { + ssc.setTempDir(args(1)) + } + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) + /*sentences.print*/ + + val words = sentences.flatMap(_.split(" ")) + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + + /*val windowedCounts = words.map(x => (x, 1)).window(Seconds(5), Seconds(1)).reduceByKey(add _, 1)*/ + val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(5), Seconds(1), 1) + /*windowedCounts.print */ + + def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { + implicit val countOrdering = new Ordering[(String, Int)] { + override def compare(count1: (String, Int), count2: (String, Int)): Int = { + count2._2 - count1._2 + } + } + val array = data.toArray + Sorting.quickSort(array) + array.take(k) + } + + val k = 10 + val topKWindowedCounts = windowedCounts.glom.flatMap(topK(_, k)).collect.flatMap(topK(_, k)) + topKWindowedCounts.print + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/streaming/WordCountEc2.scala b/streaming/src/main/scala/spark/streaming/WordCountEc2.scala new file mode 100644 index 0000000000..82b9fa781d --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/WordCountEc2.scala @@ -0,0 +1,41 @@ +package spark.streaming + +import SparkStreamContext._ +import spark.SparkContext + +object WordCountEc2 { + var inputFile : String = null + var HDFS : String = null + var idealPartitions : Int = 0 + + def main (args: Array[String]) { + + if (args.length != 4) { + println ("Usage: SparkStreamContext ") + System.exit(1) + } + + HDFS = args(1) + inputFile = HDFS + args(2) + idealPartitions = args(3).toInt + println ("Input file: " + inputFile) + + SparkContext.idealPartitions = idealPartitions + SparkContext.inputFile = inputFile + + val ssc = new SparkStreamContext(args(0), "Test") + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) + /*sentences.foreach(println)*/ + + val words = sentences.flatMap(_.split(" ")) + /*words.foreach(println)*/ + + val counts = words.map(x => (x, 1)).reduceByKey(_ + _) + /*counts.foreach(println)*/ + + counts.foreachRDD(rdd => rdd.collect.foreach(x => x)) + /*counts.register*/ + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/streaming/WordCountTrivialWindow.scala b/streaming/src/main/scala/spark/streaming/WordCountTrivialWindow.scala new file mode 100644 index 0000000000..114dd144f1 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/WordCountTrivialWindow.scala @@ -0,0 +1,51 @@ +package spark.streaming + +import SparkStreamContext._ + +import scala.util.Sorting + +object WordCountTrivialWindow { + + def main (args: Array[String]) { + + if (args.length < 1) { + println ("Usage: SparkStreamContext []") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCountTrivialWindow") + if (args.length > 1) { + ssc.setTempDir(args(1)) + } + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) + /*sentences.print*/ + + val words = sentences.flatMap(_.split(" ")) + + /*val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 1)*/ + /*counts.print*/ + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + + val windowedCounts = words.map(x => (x, 1)).window(Seconds(5), Seconds(1)).reduceByKey(add _, 1) + /*windowedCounts.print */ + + def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { + implicit val countOrdering = new Ordering[(String, Int)] { + override def compare(count1: (String, Int), count2: (String, Int)): Int = { + count2._2 - count1._2 + } + } + val array = data.toArray + Sorting.quickSort(array) + array.take(k) + } + + val k = 10 + val topKWindowedCounts = windowedCounts.glom.flatMap(topK(_, k)).collect.flatMap(topK(_, k)) + topKWindowedCounts.print + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/streaming/WordMax.scala b/streaming/src/main/scala/spark/streaming/WordMax.scala new file mode 100644 index 0000000000..fbfc48030f --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/WordMax.scala @@ -0,0 +1,64 @@ +package spark.streaming + +import SparkStreamContext._ + +import scala.util.Sorting + +import spark.SparkContext +import spark.storage.StorageLevel + +object WordMax { + var inputFile : String = null + var HDFS : String = null + var idealPartitions : Int = 0 + + def main (args: Array[String]) { + + if (args.length != 4) { + println ("Usage: WordCount ") + System.exit(1) + } + + HDFS = args(1) + inputFile = HDFS + args(2) + idealPartitions = args(3).toInt + println ("Input file: " + inputFile) + + val ssc = new SparkStreamContext(args(0), "WordCountWindow") + + SparkContext.idealPartitions = idealPartitions + SparkContext.inputFile = inputFile + + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) + //sentences.print + + val words = sentences.flatMap(_.split(" ")) + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + def max(v1: Int, v2: Int) = (if (v1 > v2) v1 else v2) + + //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(2000), + // System.getProperty("spark.default.parallelism", "1").toInt) + //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.DISK_AND_MEMORY_DESER_2, Seconds(5)) + //windowedCounts.print + + val parallelism = System.getProperty("spark.default.parallelism", "1").toInt + + val localCounts = words.map(x => (x, 1)).reduceByKey(add _, parallelism) + //localCounts.persist(StorageLevel.MEMORY_ONLY_DESER) + localCounts.persist(StorageLevel.MEMORY_ONLY_DESER_2) + val windowedCounts = localCounts.window(Seconds(30), Seconds(2)).reduceByKey(max _) + + //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(2), + // parallelism) + //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(6)) + + //windowedCounts.print + windowedCounts.register + //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => print(x+ " "))) + //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => x)) + + ssc.run + } +} -- cgit v1.2.3 From 5a26ca4a80a428eb4d7e3407ca496e39ad38c757 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 30 Jul 2012 13:29:13 -0700 Subject: Restructured file locations to separate examples and other programs from core programs. --- .../streaming/DumbTopKWordCount2_Special.scala | 138 -------------------- .../spark/streaming/DumbWordCount2_Special.scala | 92 ------------- .../src/main/scala/spark/streaming/GrepCount.scala | 39 ------ .../main/scala/spark/streaming/GrepCount2.scala | 113 ---------------- .../scala/spark/streaming/GrepCountApprox.scala | 54 -------- .../src/main/scala/spark/streaming/Interval.scala | 20 --- .../main/scala/spark/streaming/JobManager.scala | 127 ++++-------------- .../main/scala/spark/streaming/JobManager2.scala | 37 ------ .../src/main/scala/spark/streaming/Scheduler.scala | 2 +- .../streaming/SenGeneratorForPerformanceTest.scala | 78 ----------- .../scala/spark/streaming/SenderReceiverTest.scala | 63 --------- .../spark/streaming/SentenceFileGenerator.scala | 92 ------------- .../scala/spark/streaming/SentenceGenerator.scala | 103 --------------- .../main/scala/spark/streaming/ShuffleTest.scala | 22 ---- .../scala/spark/streaming/SimpleWordCount.scala | 30 ----- .../scala/spark/streaming/SimpleWordCount2.scala | 51 -------- .../spark/streaming/SimpleWordCount2_Special.scala | 83 ------------ .../scala/spark/streaming/TopContentCount.scala | 97 -------------- .../scala/spark/streaming/TopKWordCount2.scala | 103 --------------- .../spark/streaming/TopKWordCount2_Special.scala | 142 --------------------- .../src/main/scala/spark/streaming/WordCount.scala | 62 --------- .../main/scala/spark/streaming/WordCount1.scala | 46 ------- .../main/scala/spark/streaming/WordCount2.scala | 55 -------- .../scala/spark/streaming/WordCount2_Special.scala | 94 -------------- .../main/scala/spark/streaming/WordCount3.scala | 49 ------- .../main/scala/spark/streaming/WordCountEc2.scala | 41 ------ .../spark/streaming/WordCountTrivialWindow.scala | 51 -------- .../src/main/scala/spark/streaming/WordMax.scala | 64 ---------- .../examples/DumbTopKWordCount2_Special.scala | 138 ++++++++++++++++++++ .../examples/DumbWordCount2_Special.scala | 92 +++++++++++++ .../scala/spark/streaming/examples/GrepCount.scala | 39 ++++++ .../spark/streaming/examples/GrepCount2.scala | 113 ++++++++++++++++ .../spark/streaming/examples/GrepCountApprox.scala | 54 ++++++++ .../spark/streaming/examples/SimpleWordCount.scala | 30 +++++ .../streaming/examples/SimpleWordCount2.scala | 51 ++++++++ .../examples/SimpleWordCount2_Special.scala | 83 ++++++++++++ .../spark/streaming/examples/TopContentCount.scala | 97 ++++++++++++++ .../spark/streaming/examples/TopKWordCount2.scala | 103 +++++++++++++++ .../examples/TopKWordCount2_Special.scala | 142 +++++++++++++++++++++ .../scala/spark/streaming/examples/WordCount.scala | 62 +++++++++ .../spark/streaming/examples/WordCount1.scala | 46 +++++++ .../spark/streaming/examples/WordCount2.scala | 55 ++++++++ .../streaming/examples/WordCount2_Special.scala | 94 ++++++++++++++ .../spark/streaming/examples/WordCount3.scala | 49 +++++++ .../spark/streaming/examples/WordCountEc2.scala | 41 ++++++ .../examples/WordCountTrivialWindow.scala | 51 ++++++++ .../scala/spark/streaming/examples/WordMax.scala | 64 ++++++++++ .../utils/SenGeneratorForPerformanceTest.scala | 78 +++++++++++ .../spark/streaming/utils/SenderReceiverTest.scala | 63 +++++++++ .../streaming/utils/SentenceFileGenerator.scala | 92 +++++++++++++ .../spark/streaming/utils/SentenceGenerator.scala | 103 +++++++++++++++ .../scala/spark/streaming/utils/ShuffleTest.scala | 22 ++++ 52 files changed, 1789 insertions(+), 1921 deletions(-) delete mode 100644 streaming/src/main/scala/spark/streaming/DumbTopKWordCount2_Special.scala delete mode 100644 streaming/src/main/scala/spark/streaming/DumbWordCount2_Special.scala delete mode 100644 streaming/src/main/scala/spark/streaming/GrepCount.scala delete mode 100644 streaming/src/main/scala/spark/streaming/GrepCount2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/GrepCountApprox.scala delete mode 100644 streaming/src/main/scala/spark/streaming/JobManager2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/SenGeneratorForPerformanceTest.scala delete mode 100644 streaming/src/main/scala/spark/streaming/SenderReceiverTest.scala delete mode 100644 streaming/src/main/scala/spark/streaming/SentenceFileGenerator.scala delete mode 100644 streaming/src/main/scala/spark/streaming/SentenceGenerator.scala delete mode 100644 streaming/src/main/scala/spark/streaming/ShuffleTest.scala delete mode 100644 streaming/src/main/scala/spark/streaming/SimpleWordCount.scala delete mode 100644 streaming/src/main/scala/spark/streaming/SimpleWordCount2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/SimpleWordCount2_Special.scala delete mode 100644 streaming/src/main/scala/spark/streaming/TopContentCount.scala delete mode 100644 streaming/src/main/scala/spark/streaming/TopKWordCount2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/TopKWordCount2_Special.scala delete mode 100644 streaming/src/main/scala/spark/streaming/WordCount.scala delete mode 100644 streaming/src/main/scala/spark/streaming/WordCount1.scala delete mode 100644 streaming/src/main/scala/spark/streaming/WordCount2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/WordCount2_Special.scala delete mode 100644 streaming/src/main/scala/spark/streaming/WordCount3.scala delete mode 100644 streaming/src/main/scala/spark/streaming/WordCountEc2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/WordCountTrivialWindow.scala delete mode 100644 streaming/src/main/scala/spark/streaming/WordMax.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/DumbTopKWordCount2_Special.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/DumbWordCount2_Special.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/GrepCount.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/GrepCount2.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/GrepCountApprox.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/SimpleWordCount.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2_Special.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/TopContentCount.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/TopKWordCount2.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/TopKWordCount2_Special.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount1.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount2.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount2_Special.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount3.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCountEc2.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCountTrivialWindow.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/WordMax.scala create mode 100644 streaming/src/main/scala/spark/streaming/utils/SenGeneratorForPerformanceTest.scala create mode 100644 streaming/src/main/scala/spark/streaming/utils/SenderReceiverTest.scala create mode 100644 streaming/src/main/scala/spark/streaming/utils/SentenceFileGenerator.scala create mode 100644 streaming/src/main/scala/spark/streaming/utils/SentenceGenerator.scala create mode 100644 streaming/src/main/scala/spark/streaming/utils/ShuffleTest.scala diff --git a/streaming/src/main/scala/spark/streaming/DumbTopKWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/DumbTopKWordCount2_Special.scala deleted file mode 100644 index 2ca72da79f..0000000000 --- a/streaming/src/main/scala/spark/streaming/DumbTopKWordCount2_Special.scala +++ /dev/null @@ -1,138 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.JavaConversions.mapAsScalaMap -import scala.collection.mutable.Queue - -import java.lang.{Long => JLong} - -object DumbTopKWordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - - def add(v1: JLong, v2: JLong) = (v1 + v2) - def subtract(v1: JLong, v2: JLong) = (v1 - v2) - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - map.toIterator - } - - - val wordCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) - wordCounts.persist(StorageLevel.MEMORY_ONLY) - val windowedCounts = wordCounts.window(Seconds(10), Seconds(1)).reduceByKey(_ + _, 10) - - def topK(data: Iterator[(String, JLong)], k: Int): Iterator[(String, JLong)] = { - val taken = new Array[(String, JLong)](k) - - var i = 0 - var len = 0 - var done = false - var value: (String, JLong) = null - var swap: (String, JLong) = null - var count = 0 - - while(data.hasNext) { - value = data.next - count += 1 - /*println("count = " + count)*/ - if (len == 0) { - taken(0) = value - len = 1 - } else if (len < k || value._2 > taken(len - 1)._2) { - if (len < k) { - len += 1 - } - taken(len - 1) = value - i = len - 1 - while(i > 0 && taken(i - 1)._2 < taken(i)._2) { - swap = taken(i) - taken(i) = taken(i-1) - taken(i - 1) = swap - i -= 1 - } - } - } - println("Took " + len + " out of " + count + " items") - return taken.toIterator - } - - val k = 10 - val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) - partialTopKWindowedCounts.foreachRDD(rdd => { - val collectedCounts = rdd.collect - println("Collected " + collectedCounts.size + " items") - topK(collectedCounts.toIterator, k).foreach(println) - }) - - /* - windowedCounts.filter(_ == null).foreachRDD(rdd => { - val count = rdd.count - println("# of nulls = " + count) - })*/ - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/DumbWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/DumbWordCount2_Special.scala deleted file mode 100644 index 34e7edfda9..0000000000 --- a/streaming/src/main/scala/spark/streaming/DumbWordCount2_Special.scala +++ /dev/null @@ -1,92 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.Queue -import scala.collection.JavaConversions.mapAsScalaMap - -import java.lang.{Long => JLong} -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} - -object DumbWordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - def add(v1: JLong, v2: JLong) = (v1 + v2) - def subtract(v1: JLong, v2: JLong) = (v1 - v2) - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - - map.toIterator - } - - val wordCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) - wordCounts.persist(StorageLevel.MEMORY_ONLY) - val windowedCounts = wordCounts.window(Seconds(10), Seconds(1)).reduceByKey(_ + _, 10) - windowedCounts.foreachRDD(_.collect) - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/GrepCount.scala b/streaming/src/main/scala/spark/streaming/GrepCount.scala deleted file mode 100644 index ec3e70f258..0000000000 --- a/streaming/src/main/scala/spark/streaming/GrepCount.scala +++ /dev/null @@ -1,39 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object GrepCount { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: GrepCount ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "GrepCount") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - //sentences.print - val matching = sentences.filter(_.contains("light")) - matching.foreachRDD(rdd => println(rdd.count)) - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/GrepCount2.scala b/streaming/src/main/scala/spark/streaming/GrepCount2.scala deleted file mode 100644 index 27ecced1c0..0000000000 --- a/streaming/src/main/scala/spark/streaming/GrepCount2.scala +++ /dev/null @@ -1,113 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkEnv -import spark.SparkContext -import spark.storage.StorageLevel -import spark.network.Message -import spark.network.ConnectionManagerId - -import java.nio.ByteBuffer - -object GrepCount2 { - - def startSparkEnvs(sc: SparkContext) { - - val dummy = sc.parallelize(0 to 1000, 100).persist(StorageLevel.DISK_AND_MEMORY) - sc.runJob(dummy, (_: Iterator[Int]) => {}) - - println("SparkEnvs started") - Thread.sleep(1000) - /*sc.runJob(sc.parallelize(0 to 1000, 100), (_: Iterator[Int]) => {})*/ - } - - def warmConnectionManagers(sc: SparkContext) { - val slaveConnManagerIds = sc.parallelize(0 to 100, 100).map( - i => SparkEnv.get.connectionManager.id).collect().distinct - println("\nSlave ConnectionManagerIds") - slaveConnManagerIds.foreach(println) - println - - Thread.sleep(1000) - val numSlaves = slaveConnManagerIds.size - val count = 3 - val size = 5 * 1024 * 1024 - val iterations = (500 * 1024 * 1024 / (numSlaves * size)).toInt - println("count = " + count + ", size = " + size + ", iterations = " + iterations) - - (0 until count).foreach(i => { - val resultStrs = sc.parallelize(0 until numSlaves, numSlaves).map(i => { - val connManager = SparkEnv.get.connectionManager - val thisConnManagerId = connManager.id - /*connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - println("Received [" + msg + "] from [" + id + "]") - None - })*/ - - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - val futures = (0 until iterations).map(i => { - slaveConnManagerIds.filter(_ != thisConnManagerId).map(slaveConnManagerId => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - println("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") - connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) - }) - }).flatMap(x => x) - val results = futures.map(f => f()) - val finishTime = System.currentTimeMillis - - - val mb = size * results.size / 1024.0 / 1024.0 - val ms = finishTime - startTime - - val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s" - println(resultStr) - System.gc() - resultStr - }).collect() - - println("---------------------") - println("Run " + i) - resultStrs.foreach(println) - println("---------------------") - }) - } - - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: GrepCount2 <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "GrepCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - /*startSparkEnvs(ssc.sc)*/ - warmConnectionManagers(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-"+i, 500)).toArray - ) - - val matching = sentences.filter(_.contains("light")) - matching.foreachRDD(rdd => println(rdd.count)) - - ssc.run - } -} - - - - diff --git a/streaming/src/main/scala/spark/streaming/GrepCountApprox.scala b/streaming/src/main/scala/spark/streaming/GrepCountApprox.scala deleted file mode 100644 index f9674136fe..0000000000 --- a/streaming/src/main/scala/spark/streaming/GrepCountApprox.scala +++ /dev/null @@ -1,54 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object GrepCountApprox { - var inputFile : String = null - var hdfs : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 5) { - println ("Usage: GrepCountApprox ") - System.exit(1) - } - - hdfs = args(1) - inputFile = hdfs + args(2) - idealPartitions = args(3).toInt - val timeout = args(4).toLong - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "GrepCount") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - ssc.setTempDir(hdfs + "/tmp") - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - //sentences.print - val matching = sentences.filter(_.contains("light")) - var i = 0 - val startTime = System.currentTimeMillis - matching.foreachRDD { rdd => - val myNum = i - val result = rdd.countApprox(timeout) - val initialTime = (System.currentTimeMillis - startTime) / 1000.0 - printf("APPROX\t%.2f\t%d\tinitial\t%.1f\t%.1f\n", initialTime, myNum, result.initialValue.mean, - result.initialValue.high - result.initialValue.low) - result.onComplete { r => - val finalTime = (System.currentTimeMillis - startTime) / 1000.0 - printf("APPROX\t%.2f\t%d\tfinal\t%.1f\t0.0\t%.1f\n", finalTime, myNum, r.mean, finalTime - initialTime) - } - i += 1 - } - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/Interval.scala b/streaming/src/main/scala/spark/streaming/Interval.scala index a985f44ba1..9a61d85274 100644 --- a/streaming/src/main/scala/spark/streaming/Interval.scala +++ b/streaming/src/main/scala/spark/streaming/Interval.scala @@ -42,26 +42,6 @@ case class Interval (val beginTime: Time, val endTime: Time) { object Interval { - /* - implicit def longTupleToInterval (longTuple: (Long, Long)) = - Interval(longTuple._1, longTuple._2) - - implicit def intTupleToInterval (intTuple: (Int, Int)) = - Interval(intTuple._1, intTuple._2) - - implicit def string2Interval (str: String): Interval = { - val parts = str.split(",") - if (parts.length == 1) - return Interval.zero - return Interval (parts(0).toInt, parts(1).toInt) - } - - def getInterval (timeMs: Long, intervalDurationMs: Long): Interval = { - val intervalBeginMs = timeMs / intervalDurationMs * intervalDurationMs - Interval(intervalBeginMs, intervalBeginMs + intervalDurationMs) - } - */ - def zero() = new Interval (Time.zero, Time.zero) def currentInterval(intervalDuration: LongTime): Interval = { diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index 45a3971643..d7d88a7000 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -1,112 +1,37 @@ package spark.streaming -import spark.SparkEnv -import spark.Logging - -import scala.collection.mutable.PriorityQueue -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ -import scala.actors.scheduler.ResizableThreadPoolScheduler -import scala.actors.scheduler.ForkJoinScheduler - -sealed trait JobManagerMessage -case class RunJob(job: Job) extends JobManagerMessage -case class JobCompleted(handlerId: Int) extends JobManagerMessage - -class JobHandler(ssc: SparkStreamContext, val id: Int) extends DaemonActor with Logging { - - var busy = false - - def act() { - loop { - receive { - case job: Job => { - SparkEnv.set(ssc.env) - try { - logInfo("Starting " + job) - job.run() - logInfo("Finished " + job) - if (job.time.isInstanceOf[LongTime]) { - val longTime = job.time.asInstanceOf[LongTime] - logInfo("Total pushing + skew + processing delay for " + longTime + " is " + - (System.currentTimeMillis - longTime.milliseconds) / 1000.0 + " s") - } - } catch { - case e: Exception => logError("SparkStream job failed", e) +import spark.{Logging, SparkEnv} +import java.util.concurrent.Executors + + +class JobManager(ssc: SparkStreamContext, numThreads: Int = 1) extends Logging { + + class JobHandler(ssc: SparkStreamContext, job: Job) extends Runnable { + def run() { + SparkEnv.set(ssc.env) + try { + logInfo("Starting " + job) + job.run() + logInfo("Finished " + job) + if (job.time.isInstanceOf[LongTime]) { + val longTime = job.time.asInstanceOf[LongTime] + logInfo("Total notification + skew + processing delay for " + longTime + " is " + + (System.currentTimeMillis - longTime.milliseconds) / 1000.0 + " s") + if (System.getProperty("spark.stream.distributed", "false") == "true") { + TestInputBlockTracker.setEndTime(job.time) } - busy = false - reply(JobCompleted(id)) } + } catch { + case e: Exception => logError("SparkStream job failed", e) } } } -} -class JobManager(ssc: SparkStreamContext, numThreads: Int = 2) extends DaemonActor with Logging { + initLogging() - implicit private val jobOrdering = new Ordering[Job] { - override def compare(job1: Job, job2: Job): Int = { - if (job1.time < job2.time) { - return 1 - } else if (job2.time < job1.time) { - return -1 - } else { - return 0 - } - } - } - - private val jobs = new PriorityQueue[Job]() - private val handlers = (0 until numThreads).map(i => new JobHandler(ssc, i)) - - def act() { - handlers.foreach(_.start) - loop { - receive { - case RunJob(job) => { - jobs += job - logInfo("Job " + job + " submitted") - runJob() - } - case JobCompleted(handlerId) => { - runJob() - } - } - } - } - - def runJob(): Unit = { - logInfo("Attempting to allocate job ") - if (jobs.size > 0) { - handlers.find(!_.busy).foreach(handler => { - val job = jobs.dequeue - logInfo("Allocating job " + job + " to handler " + handler.id) - handler.busy = true - handler ! job - }) - } + val jobExecutor = Executors.newFixedThreadPool(numThreads) + + def runJob(job: Job) { + jobExecutor.execute(new JobHandler(ssc, job)) } } - -object JobManager { - def main(args: Array[String]) { - val ssc = new SparkStreamContext("local[4]", "JobManagerTest") - val jobManager = new JobManager(ssc) - jobManager.start() - - val t = System.currentTimeMillis - for (i <- 1 to 10) { - jobManager ! RunJob(new Job( - LongTime(i), - () => { - Thread.sleep(500) - println("Job " + i + " took " + (System.currentTimeMillis - t) + " ms") - } - )) - } - Thread.sleep(6000) - } -} - diff --git a/streaming/src/main/scala/spark/streaming/JobManager2.scala b/streaming/src/main/scala/spark/streaming/JobManager2.scala deleted file mode 100644 index ce0154e19b..0000000000 --- a/streaming/src/main/scala/spark/streaming/JobManager2.scala +++ /dev/null @@ -1,37 +0,0 @@ -package spark.streaming - -import spark.{Logging, SparkEnv} -import java.util.concurrent.Executors - - -class JobManager2(ssc: SparkStreamContext, numThreads: Int = 1) extends Logging { - - class JobHandler(ssc: SparkStreamContext, job: Job) extends Runnable { - def run() { - SparkEnv.set(ssc.env) - try { - logInfo("Starting " + job) - job.run() - logInfo("Finished " + job) - if (job.time.isInstanceOf[LongTime]) { - val longTime = job.time.asInstanceOf[LongTime] - logInfo("Total notification + skew + processing delay for " + longTime + " is " + - (System.currentTimeMillis - longTime.milliseconds) / 1000.0 + " s") - if (System.getProperty("spark.stream.distributed", "false") == "true") { - TestInputBlockTracker.setEndTime(job.time) - } - } - } catch { - case e: Exception => logError("SparkStream job failed", e) - } - } - } - - initLogging() - - val jobExecutor = Executors.newFixedThreadPool(numThreads) - - def runJob(job: Job) { - jobExecutor.execute(new JobHandler(ssc, job)) - } -} diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index 4137d8f27d..8df346559c 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -46,7 +46,7 @@ extends Actor with Logging { val inputNames = inputRDSs.map(_.inputName).toArray val inputStates = new HashMap[Interval, InputState]() val currentJobs = System.getProperty("spark.stream.currentJobs", "1").toInt - val jobManager = new JobManager2(ssc, currentJobs) + val jobManager = new JobManager(ssc, currentJobs) // TODO(Haoyuan): The following line is for performance test only. var cnt: Int = System.getProperty("spark.stream.fake.cnt", "60").toInt diff --git a/streaming/src/main/scala/spark/streaming/SenGeneratorForPerformanceTest.scala b/streaming/src/main/scala/spark/streaming/SenGeneratorForPerformanceTest.scala deleted file mode 100644 index bb32089ae2..0000000000 --- a/streaming/src/main/scala/spark/streaming/SenGeneratorForPerformanceTest.scala +++ /dev/null @@ -1,78 +0,0 @@ -package spark.streaming - -import scala.util.Random -import scala.io.Source -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import java.net.InetSocketAddress - -/*import akka.actor.Actor._*/ -/*import akka.actor.ActorRef*/ - - -object SenGeneratorForPerformanceTest { - - def printUsage () { - println ("Usage: SentenceGenerator []") - System.exit(0) - } - - def main (args: Array[String]) { - if (args.length < 3) { - printUsage - } - - val inputManagerIP = args(0) - val inputManagerPort = args(1).toInt - val sentenceFile = args(2) - val sentencesPerSecond = { - if (args.length > 3) args(3).toInt - else 10 - } - - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n") - source.close () - - try { - /*val inputManager = remote.actorFor("InputReceiver-Sentences",*/ - /* inputManagerIP, inputManagerPort)*/ - val inputManager = select(Node(inputManagerIP, inputManagerPort), Symbol("InputReceiver-Sentences")) - val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 - val random = new Random () - println ("Sending " + sentencesPerSecond + " sentences per second to " + inputManagerIP + ":" + inputManagerPort) - var lastPrintTime = System.currentTimeMillis() - var count = 0 - - while (true) { - /*if (!inputManager.tryTell (lines (random.nextInt (lines.length))))*/ - /*throw new Exception ("disconnected")*/ -// inputManager ! lines (random.nextInt (lines.length)) - for (i <- 0 to sentencesPerSecond) inputManager ! lines (0) - println(System.currentTimeMillis / 1000 + " s") -/* count += 1 - - if (System.currentTimeMillis - lastPrintTime >= 1000) { - println (count + " sentences sent last second") - count = 0 - lastPrintTime = System.currentTimeMillis - } - - Thread.sleep (sleepBetweenSentences.toLong) -*/ - val currentMs = System.currentTimeMillis / 1000; - Thread.sleep ((currentMs * 1000 + 1000) - System.currentTimeMillis) - } - } catch { - case e: Exception => - /*Thread.sleep (1000)*/ - } - } -} - - - - diff --git a/streaming/src/main/scala/spark/streaming/SenderReceiverTest.scala b/streaming/src/main/scala/spark/streaming/SenderReceiverTest.scala deleted file mode 100644 index 6af270298a..0000000000 --- a/streaming/src/main/scala/spark/streaming/SenderReceiverTest.scala +++ /dev/null @@ -1,63 +0,0 @@ -package spark.streaming -import java.net.{Socket, ServerSocket} -import java.io.{ByteArrayOutputStream, DataOutputStream, DataInputStream, BufferedInputStream} - -object Receiver { - def main(args: Array[String]) { - val port = args(0).toInt - val lsocket = new ServerSocket(port) - println("Listening on port " + port ) - while(true) { - val socket = lsocket.accept() - (new Thread() { - override def run() { - val buffer = new Array[Byte](100000) - var count = 0 - val time = System.currentTimeMillis - try { - val is = new DataInputStream(new BufferedInputStream(socket.getInputStream)) - var loop = true - var string: String = null - while((string = is.readUTF) != null) { - count += 28 - } - } catch { - case e: Exception => e.printStackTrace - } - val timeTaken = System.currentTimeMillis - time - val tput = (count / 1024.0) / (timeTaken / 1000.0) - println("Data = " + count + " bytes\nTime = " + timeTaken + " ms\nTput = " + tput + " KB/s") - } - }).start() - } - } - -} - -object Sender { - - def main(args: Array[String]) { - try { - val host = args(0) - val port = args(1).toInt - val size = args(2).toInt - - val byteStream = new ByteArrayOutputStream() - val stringDataStream = new DataOutputStream(byteStream) - (0 until size).foreach(_ => stringDataStream.writeUTF("abcdedfghijklmnopqrstuvwxy")) - val bytes = byteStream.toByteArray() - println("Generated array of " + bytes.length + " bytes") - - /*val bytes = new Array[Byte](size)*/ - val socket = new Socket(host, port) - val os = socket.getOutputStream - os.write(bytes) - os.flush - socket.close() - - } catch { - case e: Exception => e.printStackTrace - } - } -} - diff --git a/streaming/src/main/scala/spark/streaming/SentenceFileGenerator.scala b/streaming/src/main/scala/spark/streaming/SentenceFileGenerator.scala deleted file mode 100644 index 15858f59e3..0000000000 --- a/streaming/src/main/scala/spark/streaming/SentenceFileGenerator.scala +++ /dev/null @@ -1,92 +0,0 @@ -package spark.streaming - -import spark._ - -import scala.collection.mutable.ArrayBuffer -import scala.util.Random -import scala.io.Source - -import java.net.InetSocketAddress - -import org.apache.hadoop.fs._ -import org.apache.hadoop.conf._ -import org.apache.hadoop.io._ -import org.apache.hadoop.mapred._ -import org.apache.hadoop.util._ - -object SentenceFileGenerator { - - def printUsage () { - println ("Usage: SentenceFileGenerator <# partitions> []") - System.exit(0) - } - - def main (args: Array[String]) { - if (args.length < 4) { - printUsage - } - - val master = args(0) - val fs = new Path(args(1)).getFileSystem(new Configuration()) - val targetDirectory = new Path(args(1)).makeQualified(fs) - val numPartitions = args(2).toInt - val sentenceFile = args(3) - val sentencesPerSecond = { - if (args.length > 4) args(4).toInt - else 10 - } - - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n").toArray - source.close () - println("Read " + lines.length + " lines from file " + sentenceFile) - - val sentences = { - val buffer = ArrayBuffer[String]() - val random = new Random() - var i = 0 - while (i < sentencesPerSecond) { - buffer += lines(random.nextInt(lines.length)) - i += 1 - } - buffer.toArray - } - println("Generated " + sentences.length + " sentences") - - val sc = new SparkContext(master, "SentenceFileGenerator") - val sentencesRDD = sc.parallelize(sentences, numPartitions) - - val tempDirectory = new Path(targetDirectory, "_tmp") - - fs.mkdirs(targetDirectory) - fs.mkdirs(tempDirectory) - - var saveTimeMillis = System.currentTimeMillis - try { - while (true) { - val newDir = new Path(targetDirectory, "Sentences-" + saveTimeMillis) - val tmpNewDir = new Path(tempDirectory, "Sentences-" + saveTimeMillis) - println("Writing to file " + newDir) - sentencesRDD.saveAsTextFile(tmpNewDir.toString) - fs.rename(tmpNewDir, newDir) - saveTimeMillis += 1000 - val sleepTimeMillis = { - val currentTimeMillis = System.currentTimeMillis - if (saveTimeMillis < currentTimeMillis) { - 0 - } else { - saveTimeMillis - currentTimeMillis - } - } - println("Sleeping for " + sleepTimeMillis + " ms") - Thread.sleep(sleepTimeMillis) - } - } catch { - case e: Exception => - } - } -} - - - - diff --git a/streaming/src/main/scala/spark/streaming/SentenceGenerator.scala b/streaming/src/main/scala/spark/streaming/SentenceGenerator.scala deleted file mode 100644 index a9f124d2d7..0000000000 --- a/streaming/src/main/scala/spark/streaming/SentenceGenerator.scala +++ /dev/null @@ -1,103 +0,0 @@ -package spark.streaming - -import scala.util.Random -import scala.io.Source -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import java.net.InetSocketAddress - - -object SentenceGenerator { - - def printUsage { - println ("Usage: SentenceGenerator []") - System.exit(0) - } - - def generateRandomSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { - val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 - val random = new Random () - - try { - var lastPrintTime = System.currentTimeMillis() - var count = 0 - while(true) { - streamReceiver ! lines(random.nextInt(lines.length)) - count += 1 - if (System.currentTimeMillis - lastPrintTime >= 1000) { - println (count + " sentences sent last second") - count = 0 - lastPrintTime = System.currentTimeMillis - } - Thread.sleep(sleepBetweenSentences.toLong) - } - } catch { - case e: Exception => - } - } - - def generateSameSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { - try { - val numSentences = if (sentencesPerSecond <= 0) { - lines.length - } else { - sentencesPerSecond - } - var nextSendingTime = System.currentTimeMillis() - val pingInterval = if (System.getenv("INTERVAL") != null) { - System.getenv("INTERVAL").toInt - } else { - 2000 - } - while(true) { - (0 until numSentences).foreach(i => { - streamReceiver ! lines(i % lines.length) - }) - println ("Sent " + numSentences + " sentences") - nextSendingTime += pingInterval - val sleepTime = nextSendingTime - System.currentTimeMillis - if (sleepTime > 0) { - println ("Sleeping for " + sleepTime + " ms") - Thread.sleep(sleepTime) - } - } - } catch { - case e: Exception => - } - } - - def main(args: Array[String]) { - if (args.length < 3) { - printUsage - } - - val generateRandomly = false - - val streamReceiverIP = args(0) - val streamReceiverPort = args(1).toInt - val sentenceFile = args(2) - val sentencesPerSecond = if (args.length > 3) args(3).toInt else 10 - val sentenceInputName = if (args.length > 4) args(4) else "Sentences" - - println("Sending " + sentencesPerSecond + " sentences per second to " + - streamReceiverIP + ":" + streamReceiverPort + "/NetworkStreamReceiver-" + sentenceInputName) - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n") - source.close () - - val streamReceiver = select( - Node(streamReceiverIP, streamReceiverPort), - Symbol("NetworkStreamReceiver-" + sentenceInputName)) - if (generateRandomly) { - generateRandomSentences(lines, sentencesPerSecond, streamReceiver) - } else { - generateSameSentences(lines, sentencesPerSecond, streamReceiver) - } - } -} - - - diff --git a/streaming/src/main/scala/spark/streaming/ShuffleTest.scala b/streaming/src/main/scala/spark/streaming/ShuffleTest.scala deleted file mode 100644 index 32aa4144a0..0000000000 --- a/streaming/src/main/scala/spark/streaming/ShuffleTest.scala +++ /dev/null @@ -1,22 +0,0 @@ -package spark.streaming -import spark.SparkContext -import SparkContext._ - -object ShuffleTest { - def main(args: Array[String]) { - - if (args.length < 1) { - println ("Usage: ShuffleTest ") - System.exit(1) - } - - val sc = new spark.SparkContext(args(0), "ShuffleTest") - val rdd = sc.parallelize(1 to 1000, 500).cache - - def time(f: => Unit) { val start = System.nanoTime; f; println((System.nanoTime - start) * 1.0e-6) } - - time { for (i <- 0 until 50) time { rdd.map(x => (x % 100, x)).reduceByKey(_ + _, 10).count } } - System.exit(0) - } -} - diff --git a/streaming/src/main/scala/spark/streaming/SimpleWordCount.scala b/streaming/src/main/scala/spark/streaming/SimpleWordCount.scala deleted file mode 100644 index a75ccd3a56..0000000000 --- a/streaming/src/main/scala/spark/streaming/SimpleWordCount.scala +++ /dev/null @@ -1,30 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -object SimpleWordCount { - - def main (args: Array[String]) { - - if (args.length < 1) { - println ("Usage: SparkStreamContext []") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount") - if (args.length > 1) { - ssc.setTempDir(args(1)) - } - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) - /*sentences.print*/ - - val words = sentences.flatMap(_.split(" ")) - - val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 1) - counts.print - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/SimpleWordCount2.scala b/streaming/src/main/scala/spark/streaming/SimpleWordCount2.scala deleted file mode 100644 index 9672e64b13..0000000000 --- a/streaming/src/main/scala/spark/streaming/SimpleWordCount2.scala +++ /dev/null @@ -1,51 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import scala.util.Sorting - -object SimpleWordCount2 { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SimpleWordCount2 <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "SimpleWordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - - val words = sentences.flatMap(_.split(" ")) - - val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 10) - counts.foreachRDD(_.collect()) - /*words.foreachRDD(_.countByValue())*/ - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/SimpleWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/SimpleWordCount2_Special.scala deleted file mode 100644 index 503033a8e5..0000000000 --- a/streaming/src/main/scala/spark/streaming/SimpleWordCount2_Special.scala +++ /dev/null @@ -1,83 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import scala.collection.JavaConversions.mapAsScalaMap -import scala.util.Sorting -import java.lang.{Long => JLong} - -object SimpleWordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SimpleWordCount2 <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "SimpleWordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 400)).toArray - ) - - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - map.toIterator - } - - - /*val words = sentences.flatMap(_.split(" "))*/ - /*val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 10)*/ - val counts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) - counts.foreachRDD(_.collect()) - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/TopContentCount.scala b/streaming/src/main/scala/spark/streaming/TopContentCount.scala deleted file mode 100644 index 031e989c87..0000000000 --- a/streaming/src/main/scala/spark/streaming/TopContentCount.scala +++ /dev/null @@ -1,97 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting - -object TopContentCount { - - case class Event(val country: String, val content: String) - - object Event { - def create(string: String): Event = { - val parts = string.split(":") - new Event(parts(0), parts(1)) - } - } - - def main(args: Array[String]) { - - if (args.length < 2) { - println ("Usage: GrepCount2 <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "TopContentCount") - val sc = ssc.sc - val dummy = sc.parallelize(0 to 1000, 100).persist(StorageLevel.DISK_AND_MEMORY) - sc.runJob(dummy, (_: Iterator[Int]) => {}) - - - val numEventStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - val eventStrings = new UnifiedRDS( - (1 to numEventStreams).map(i => ssc.readTestStream("Events-" + i, 1000)).toArray - ) - - def parse(string: String) = { - val parts = string.split(":") - (parts(0), parts(1)) - } - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val events = eventStrings.map(x => parse(x)) - /*events.print*/ - - val parallelism = 8 - val counts_per_content_per_country = events - .map(x => (x, 1)) - .reduceByKey(_ + _) - /*.reduceByKeyAndWindow(add _, subtract _, Seconds(5), Seconds(1), parallelism)*/ - /*counts_per_content_per_country.print*/ - - /* - counts_per_content_per_country.persist( - StorageLevel.MEMORY_ONLY_DESER, - StorageLevel.MEMORY_ONLY_DESER_2, - Seconds(1) - )*/ - - val counts_per_country = counts_per_content_per_country - .map(x => (x._1._1, (x._1._2, x._2))) - .groupByKey() - counts_per_country.print - - - def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { - implicit val countOrdering = new Ordering[(String, Int)] { - override def compare(count1: (String, Int), count2: (String, Int)): Int = { - count2._2 - count1._2 - } - } - val array = data.toArray - Sorting.quickSort(array) - val taken = array.take(k) - taken - } - - val k = 10 - val topKContents_per_country = counts_per_country - .map(x => (x._1, topK(x._2, k))) - .map(x => (x._1, x._2.map(_.toString).reduceLeft(_ + ", " + _))) - - topKContents_per_country.print - - ssc.run - } -} - - - diff --git a/streaming/src/main/scala/spark/streaming/TopKWordCount2.scala b/streaming/src/main/scala/spark/streaming/TopKWordCount2.scala deleted file mode 100644 index 679ed0a7ef..0000000000 --- a/streaming/src/main/scala/spark/streaming/TopKWordCount2.scala +++ /dev/null @@ -1,103 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting - -object TopKWordCount2 { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 10) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) - - def topK(data: Iterator[(String, Int)], k: Int): Iterator[(String, Int)] = { - val taken = new Array[(String, Int)](k) - - var i = 0 - var len = 0 - var done = false - var value: (String, Int) = null - var swap: (String, Int) = null - var count = 0 - - while(data.hasNext) { - value = data.next - count += 1 - println("count = " + count) - if (len == 0) { - taken(0) = value - len = 1 - } else if (len < k || value._2 > taken(len - 1)._2) { - if (len < k) { - len += 1 - } - taken(len - 1) = value - i = len - 1 - while(i > 0 && taken(i - 1)._2 < taken(i)._2) { - swap = taken(i) - taken(i) = taken(i-1) - taken(i - 1) = swap - i -= 1 - } - } - } - println("Took " + len + " out of " + count + " items") - return taken.toIterator - } - - val k = 10 - val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) - partialTopKWindowedCounts.foreachRDD(rdd => { - val collectedCounts = rdd.collect - println("Collected " + collectedCounts.size + " items") - topK(collectedCounts.toIterator, k).foreach(println) - }) - - /* - windowedCounts.filter(_ == null).foreachRDD(rdd => { - val count = rdd.count - println("# of nulls = " + count) - })*/ - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/TopKWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/TopKWordCount2_Special.scala deleted file mode 100644 index c873fbd0f0..0000000000 --- a/streaming/src/main/scala/spark/streaming/TopKWordCount2_Special.scala +++ /dev/null @@ -1,142 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.JavaConversions.mapAsScalaMap -import scala.collection.mutable.Queue - -import java.lang.{Long => JLong} - -object TopKWordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "TopKWordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - /*moreWarmup(ssc.sc)*/ - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 500)).toArray - ) - - /*val words = sentences.flatMap(_.split(" "))*/ - - /*def add(v1: Int, v2: Int) = (v1 + v2) */ - /*def subtract(v1: Int, v2: Int) = (v1 - v2) */ - - def add(v1: JLong, v2: JLong) = (v1 + v2) - def subtract(v1: JLong, v2: JLong) = (v1 - v2) - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - map.toIterator - } - - - val windowedCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Milliseconds(500), 10) - /*windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1))*/ - windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY, Milliseconds(500)) - - def topK(data: Iterator[(String, JLong)], k: Int): Iterator[(String, JLong)] = { - val taken = new Array[(String, JLong)](k) - - var i = 0 - var len = 0 - var done = false - var value: (String, JLong) = null - var swap: (String, JLong) = null - var count = 0 - - while(data.hasNext) { - value = data.next - count += 1 - println("count = " + count) - if (len == 0) { - taken(0) = value - len = 1 - } else if (len < k || value._2 > taken(len - 1)._2) { - if (len < k) { - len += 1 - } - taken(len - 1) = value - i = len - 1 - while(i > 0 && taken(i - 1)._2 < taken(i)._2) { - swap = taken(i) - taken(i) = taken(i-1) - taken(i - 1) = swap - i -= 1 - } - } - } - println("Took " + len + " out of " + count + " items") - return taken.toIterator - } - - val k = 50 - val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) - partialTopKWindowedCounts.foreachRDD(rdd => { - val collectedCounts = rdd.collect - println("Collected " + collectedCounts.size + " items") - topK(collectedCounts.toIterator, k).foreach(println) - }) - - /* - windowedCounts.filter(_ == null).foreachRDD(rdd => { - val count = rdd.count - println("# of nulls = " + count) - })*/ - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/WordCount.scala b/streaming/src/main/scala/spark/streaming/WordCount.scala deleted file mode 100644 index fb5508ffcc..0000000000 --- a/streaming/src/main/scala/spark/streaming/WordCount.scala +++ /dev/null @@ -1,62 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object WordCount { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: WordCount ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "WordCountWindow") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) - //sentences.print - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(2000), - // System.getProperty("spark.default.parallelism", "1").toInt) - //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.DISK_AND_MEMORY_DESER_2, Seconds(5)) - //windowedCounts.print - - val parallelism = System.getProperty("spark.default.parallelism", "1").toInt - - //val localCounts = words.map(x => (x, 1)).reduceByKey(add _, parallelism) - //localCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(6)) - //val windowedCounts = localCounts.window(Seconds(30), Seconds(2)).reduceByKey(_ + _) - - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(2), - parallelism) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(10)) - - //windowedCounts.print - windowedCounts.register - //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => print(x+ " "))) - //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => x)) - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/WordCount1.scala b/streaming/src/main/scala/spark/streaming/WordCount1.scala deleted file mode 100644 index 42d985920a..0000000000 --- a/streaming/src/main/scala/spark/streaming/WordCount1.scala +++ /dev/null @@ -1,46 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object WordCount1 { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: WordCount ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "WordCountWindow") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - //sentences.print - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 10) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) - windowedCounts.foreachRDD(_.collect) - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/WordCount2.scala b/streaming/src/main/scala/spark/streaming/WordCount2.scala deleted file mode 100644 index 9168a2fe2f..0000000000 --- a/streaming/src/main/scala/spark/streaming/WordCount2.scala +++ /dev/null @@ -1,55 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting - -object WordCount2 { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - /*moreWarmup(ssc.sc)*/ - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 6) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) - windowedCounts.foreachRDD(_.collect) - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/WordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/WordCount2_Special.scala deleted file mode 100644 index 1920915af7..0000000000 --- a/streaming/src/main/scala/spark/streaming/WordCount2_Special.scala +++ /dev/null @@ -1,94 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.Queue -import scala.collection.JavaConversions.mapAsScalaMap - -import java.lang.{Long => JLong} -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} - - -object WordCount2_ExtraFunctions { - - def add(v1: JLong, v2: JLong) = (v1 + v2) - - def subtract(v1: JLong, v2: JLong) = (v1 - v2) - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - map.toIterator - } -} - -object WordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 40).foreach {i => - sc.parallelize(1 to 20000000, 1000) - .map(_ % 1331).map(_.toString) - .mapPartitions(WordCount2_ExtraFunctions.splitAndCountPartitions).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - - GrepCount2.warmConnectionManagers(ssc.sc) - /*moreWarmup(ssc.sc)*/ - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 500)).toArray - ) - - val windowedCounts = sentences - .mapPartitions(WordCount2_ExtraFunctions.splitAndCountPartitions) - .reduceByKeyAndWindow(WordCount2_ExtraFunctions.add _, WordCount2_ExtraFunctions.subtract _, Seconds(10), Milliseconds(500), 10) - windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY, Milliseconds(500)) - windowedCounts.foreachRDD(_.collect) - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/WordCount3.scala b/streaming/src/main/scala/spark/streaming/WordCount3.scala deleted file mode 100644 index 018c19a509..0000000000 --- a/streaming/src/main/scala/spark/streaming/WordCount3.scala +++ /dev/null @@ -1,49 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -object WordCount3 { - - def main (args: Array[String]) { - - if (args.length < 1) { - println ("Usage: SparkStreamContext []") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount") - if (args.length > 1) { - ssc.setTempDir(args(1)) - } - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - /*sentences.print*/ - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - /*val windowedCounts = words.map(x => (x, 1)).window(Seconds(5), Seconds(1)).reduceByKey(add _, 1)*/ - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(5), Seconds(1), 1) - /*windowedCounts.print */ - - def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { - implicit val countOrdering = new Ordering[(String, Int)] { - override def compare(count1: (String, Int), count2: (String, Int)): Int = { - count2._2 - count1._2 - } - } - val array = data.toArray - Sorting.quickSort(array) - array.take(k) - } - - val k = 10 - val topKWindowedCounts = windowedCounts.glom.flatMap(topK(_, k)).collect.flatMap(topK(_, k)) - topKWindowedCounts.print - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/WordCountEc2.scala b/streaming/src/main/scala/spark/streaming/WordCountEc2.scala deleted file mode 100644 index 82b9fa781d..0000000000 --- a/streaming/src/main/scala/spark/streaming/WordCountEc2.scala +++ /dev/null @@ -1,41 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ -import spark.SparkContext - -object WordCountEc2 { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: SparkStreamContext ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val ssc = new SparkStreamContext(args(0), "Test") - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - /*sentences.foreach(println)*/ - - val words = sentences.flatMap(_.split(" ")) - /*words.foreach(println)*/ - - val counts = words.map(x => (x, 1)).reduceByKey(_ + _) - /*counts.foreach(println)*/ - - counts.foreachRDD(rdd => rdd.collect.foreach(x => x)) - /*counts.register*/ - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/WordCountTrivialWindow.scala b/streaming/src/main/scala/spark/streaming/WordCountTrivialWindow.scala deleted file mode 100644 index 114dd144f1..0000000000 --- a/streaming/src/main/scala/spark/streaming/WordCountTrivialWindow.scala +++ /dev/null @@ -1,51 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -object WordCountTrivialWindow { - - def main (args: Array[String]) { - - if (args.length < 1) { - println ("Usage: SparkStreamContext []") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCountTrivialWindow") - if (args.length > 1) { - ssc.setTempDir(args(1)) - } - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - /*sentences.print*/ - - val words = sentences.flatMap(_.split(" ")) - - /*val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 1)*/ - /*counts.print*/ - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val windowedCounts = words.map(x => (x, 1)).window(Seconds(5), Seconds(1)).reduceByKey(add _, 1) - /*windowedCounts.print */ - - def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { - implicit val countOrdering = new Ordering[(String, Int)] { - override def compare(count1: (String, Int), count2: (String, Int)): Int = { - count2._2 - count1._2 - } - } - val array = data.toArray - Sorting.quickSort(array) - array.take(k) - } - - val k = 10 - val topKWindowedCounts = windowedCounts.glom.flatMap(topK(_, k)).collect.flatMap(topK(_, k)) - topKWindowedCounts.print - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/WordMax.scala b/streaming/src/main/scala/spark/streaming/WordMax.scala deleted file mode 100644 index fbfc48030f..0000000000 --- a/streaming/src/main/scala/spark/streaming/WordMax.scala +++ /dev/null @@ -1,64 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object WordMax { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: WordCount ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "WordCountWindow") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) - //sentences.print - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - def max(v1: Int, v2: Int) = (if (v1 > v2) v1 else v2) - - //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(2000), - // System.getProperty("spark.default.parallelism", "1").toInt) - //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.DISK_AND_MEMORY_DESER_2, Seconds(5)) - //windowedCounts.print - - val parallelism = System.getProperty("spark.default.parallelism", "1").toInt - - val localCounts = words.map(x => (x, 1)).reduceByKey(add _, parallelism) - //localCounts.persist(StorageLevel.MEMORY_ONLY_DESER) - localCounts.persist(StorageLevel.MEMORY_ONLY_DESER_2) - val windowedCounts = localCounts.window(Seconds(30), Seconds(2)).reduceByKey(max _) - - //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(2), - // parallelism) - //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(6)) - - //windowedCounts.print - windowedCounts.register - //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => print(x+ " "))) - //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => x)) - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/DumbTopKWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/examples/DumbTopKWordCount2_Special.scala new file mode 100644 index 0000000000..2ca72da79f --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/DumbTopKWordCount2_Special.scala @@ -0,0 +1,138 @@ +package spark.streaming + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting +import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.mutable.Queue + +import java.lang.{Long => JLong} + +object DumbTopKWordCount2_Special { + + def moreWarmup(sc: SparkContext) { + (0 until 20).foreach {i => + sc.parallelize(1 to 20000000, 500) + .map(_ % 100).map(_.toString) + .map(x => (x, 1)).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SparkStreamContext <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + GrepCount2.warmConnectionManagers(ssc.sc) + moreWarmup(ssc.sc) + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray + ) + + + def add(v1: JLong, v2: JLong) = (v1 + v2) + def subtract(v1: JLong, v2: JLong) = (v1 - v2) + + def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { + val map = new java.util.HashMap[String, JLong] + var i = 0 + var j = 0 + while (iter.hasNext) { + val s = iter.next() + i = 0 + while (i < s.length) { + j = i + while (j < s.length && s.charAt(j) != ' ') { + j += 1 + } + if (j > i) { + val w = s.substring(i, j) + val c = map.get(w) + if (c == null) { + map.put(w, 1) + } else { + map.put(w, c + 1) + } + } + i = j + while (i < s.length && s.charAt(i) == ' ') { + i += 1 + } + } + } + map.toIterator + } + + + val wordCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) + wordCounts.persist(StorageLevel.MEMORY_ONLY) + val windowedCounts = wordCounts.window(Seconds(10), Seconds(1)).reduceByKey(_ + _, 10) + + def topK(data: Iterator[(String, JLong)], k: Int): Iterator[(String, JLong)] = { + val taken = new Array[(String, JLong)](k) + + var i = 0 + var len = 0 + var done = false + var value: (String, JLong) = null + var swap: (String, JLong) = null + var count = 0 + + while(data.hasNext) { + value = data.next + count += 1 + /*println("count = " + count)*/ + if (len == 0) { + taken(0) = value + len = 1 + } else if (len < k || value._2 > taken(len - 1)._2) { + if (len < k) { + len += 1 + } + taken(len - 1) = value + i = len - 1 + while(i > 0 && taken(i - 1)._2 < taken(i)._2) { + swap = taken(i) + taken(i) = taken(i-1) + taken(i - 1) = swap + i -= 1 + } + } + } + println("Took " + len + " out of " + count + " items") + return taken.toIterator + } + + val k = 10 + val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) + partialTopKWindowedCounts.foreachRDD(rdd => { + val collectedCounts = rdd.collect + println("Collected " + collectedCounts.size + " items") + topK(collectedCounts.toIterator, k).foreach(println) + }) + + /* + windowedCounts.filter(_ == null).foreachRDD(rdd => { + val count = rdd.count + println("# of nulls = " + count) + })*/ + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/streaming/examples/DumbWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/examples/DumbWordCount2_Special.scala new file mode 100644 index 0000000000..34e7edfda9 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/DumbWordCount2_Special.scala @@ -0,0 +1,92 @@ +package spark.streaming + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.Queue +import scala.collection.JavaConversions.mapAsScalaMap + +import java.lang.{Long => JLong} +import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} + +object DumbWordCount2_Special { + + def moreWarmup(sc: SparkContext) { + (0 until 20).foreach {i => + sc.parallelize(1 to 20000000, 500) + .map(_ % 100).map(_.toString) + .map(x => (x, 1)).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SparkStreamContext <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + GrepCount2.warmConnectionManagers(ssc.sc) + moreWarmup(ssc.sc) + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray + ) + + def add(v1: JLong, v2: JLong) = (v1 + v2) + def subtract(v1: JLong, v2: JLong) = (v1 - v2) + + def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { + val map = new java.util.HashMap[String, JLong] + var i = 0 + var j = 0 + while (iter.hasNext) { + val s = iter.next() + i = 0 + while (i < s.length) { + j = i + while (j < s.length && s.charAt(j) != ' ') { + j += 1 + } + if (j > i) { + val w = s.substring(i, j) + val c = map.get(w) + if (c == null) { + map.put(w, 1) + } else { + map.put(w, c + 1) + } + } + i = j + while (i < s.length && s.charAt(i) == ' ') { + i += 1 + } + } + } + + map.toIterator + } + + val wordCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) + wordCounts.persist(StorageLevel.MEMORY_ONLY) + val windowedCounts = wordCounts.window(Seconds(10), Seconds(1)).reduceByKey(_ + _, 10) + windowedCounts.foreachRDD(_.collect) + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepCount.scala b/streaming/src/main/scala/spark/streaming/examples/GrepCount.scala new file mode 100644 index 0000000000..ec3e70f258 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/GrepCount.scala @@ -0,0 +1,39 @@ +package spark.streaming + +import SparkStreamContext._ + +import scala.util.Sorting + +import spark.SparkContext +import spark.storage.StorageLevel + +object GrepCount { + var inputFile : String = null + var HDFS : String = null + var idealPartitions : Int = 0 + + def main (args: Array[String]) { + + if (args.length != 4) { + println ("Usage: GrepCount ") + System.exit(1) + } + + HDFS = args(1) + inputFile = HDFS + args(2) + idealPartitions = args(3).toInt + println ("Input file: " + inputFile) + + val ssc = new SparkStreamContext(args(0), "GrepCount") + + SparkContext.idealPartitions = idealPartitions + SparkContext.inputFile = inputFile + + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) + //sentences.print + val matching = sentences.filter(_.contains("light")) + matching.foreachRDD(rdd => println(rdd.count)) + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepCount2.scala b/streaming/src/main/scala/spark/streaming/examples/GrepCount2.scala new file mode 100644 index 0000000000..27ecced1c0 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/GrepCount2.scala @@ -0,0 +1,113 @@ +package spark.streaming + +import SparkStreamContext._ + +import scala.util.Sorting + +import spark.SparkEnv +import spark.SparkContext +import spark.storage.StorageLevel +import spark.network.Message +import spark.network.ConnectionManagerId + +import java.nio.ByteBuffer + +object GrepCount2 { + + def startSparkEnvs(sc: SparkContext) { + + val dummy = sc.parallelize(0 to 1000, 100).persist(StorageLevel.DISK_AND_MEMORY) + sc.runJob(dummy, (_: Iterator[Int]) => {}) + + println("SparkEnvs started") + Thread.sleep(1000) + /*sc.runJob(sc.parallelize(0 to 1000, 100), (_: Iterator[Int]) => {})*/ + } + + def warmConnectionManagers(sc: SparkContext) { + val slaveConnManagerIds = sc.parallelize(0 to 100, 100).map( + i => SparkEnv.get.connectionManager.id).collect().distinct + println("\nSlave ConnectionManagerIds") + slaveConnManagerIds.foreach(println) + println + + Thread.sleep(1000) + val numSlaves = slaveConnManagerIds.size + val count = 3 + val size = 5 * 1024 * 1024 + val iterations = (500 * 1024 * 1024 / (numSlaves * size)).toInt + println("count = " + count + ", size = " + size + ", iterations = " + iterations) + + (0 until count).foreach(i => { + val resultStrs = sc.parallelize(0 until numSlaves, numSlaves).map(i => { + val connManager = SparkEnv.get.connectionManager + val thisConnManagerId = connManager.id + /*connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + println("Received [" + msg + "] from [" + id + "]") + None + })*/ + + + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + + val startTime = System.currentTimeMillis + val futures = (0 until iterations).map(i => { + slaveConnManagerIds.filter(_ != thisConnManagerId).map(slaveConnManagerId => { + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + println("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") + connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) + }) + }).flatMap(x => x) + val results = futures.map(f => f()) + val finishTime = System.currentTimeMillis + + + val mb = size * results.size / 1024.0 / 1024.0 + val ms = finishTime - startTime + + val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s" + println(resultStr) + System.gc() + resultStr + }).collect() + + println("---------------------") + println("Run " + i) + resultStrs.foreach(println) + println("---------------------") + }) + } + + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: GrepCount2 <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "GrepCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + /*startSparkEnvs(ssc.sc)*/ + warmConnectionManagers(ssc.sc) + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-"+i, 500)).toArray + ) + + val matching = sentences.filter(_.contains("light")) + matching.foreachRDD(rdd => println(rdd.count)) + + ssc.run + } +} + + + + diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepCountApprox.scala b/streaming/src/main/scala/spark/streaming/examples/GrepCountApprox.scala new file mode 100644 index 0000000000..f9674136fe --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/GrepCountApprox.scala @@ -0,0 +1,54 @@ +package spark.streaming + +import SparkStreamContext._ + +import scala.util.Sorting + +import spark.SparkContext +import spark.storage.StorageLevel + +object GrepCountApprox { + var inputFile : String = null + var hdfs : String = null + var idealPartitions : Int = 0 + + def main (args: Array[String]) { + + if (args.length != 5) { + println ("Usage: GrepCountApprox ") + System.exit(1) + } + + hdfs = args(1) + inputFile = hdfs + args(2) + idealPartitions = args(3).toInt + val timeout = args(4).toLong + println ("Input file: " + inputFile) + + val ssc = new SparkStreamContext(args(0), "GrepCount") + + SparkContext.idealPartitions = idealPartitions + SparkContext.inputFile = inputFile + ssc.setTempDir(hdfs + "/tmp") + + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) + //sentences.print + val matching = sentences.filter(_.contains("light")) + var i = 0 + val startTime = System.currentTimeMillis + matching.foreachRDD { rdd => + val myNum = i + val result = rdd.countApprox(timeout) + val initialTime = (System.currentTimeMillis - startTime) / 1000.0 + printf("APPROX\t%.2f\t%d\tinitial\t%.1f\t%.1f\n", initialTime, myNum, result.initialValue.mean, + result.initialValue.high - result.initialValue.low) + result.onComplete { r => + val finalTime = (System.currentTimeMillis - startTime) / 1000.0 + printf("APPROX\t%.2f\t%d\tfinal\t%.1f\t0.0\t%.1f\n", finalTime, myNum, r.mean, finalTime - initialTime) + } + i += 1 + } + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount.scala b/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount.scala new file mode 100644 index 0000000000..a75ccd3a56 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount.scala @@ -0,0 +1,30 @@ +package spark.streaming + +import SparkStreamContext._ + +import scala.util.Sorting + +object SimpleWordCount { + + def main (args: Array[String]) { + + if (args.length < 1) { + println ("Usage: SparkStreamContext []") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCount") + if (args.length > 1) { + ssc.setTempDir(args(1)) + } + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) + /*sentences.print*/ + + val words = sentences.flatMap(_.split(" ")) + + val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 1) + counts.print + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2.scala new file mode 100644 index 0000000000..9672e64b13 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2.scala @@ -0,0 +1,51 @@ +package spark.streaming + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import scala.util.Sorting + +object SimpleWordCount2 { + + def moreWarmup(sc: SparkContext) { + (0 until 20).foreach {i => + sc.parallelize(1 to 20000000, 500) + .map(_ % 100).map(_.toString) + .map(x => (x, 1)).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SimpleWordCount2 <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "SimpleWordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + GrepCount2.warmConnectionManagers(ssc.sc) + moreWarmup(ssc.sc) + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray + ) + + + val words = sentences.flatMap(_.split(" ")) + + val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 10) + counts.foreachRDD(_.collect()) + /*words.foreachRDD(_.countByValue())*/ + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2_Special.scala new file mode 100644 index 0000000000..503033a8e5 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2_Special.scala @@ -0,0 +1,83 @@ +package spark.streaming + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import scala.collection.JavaConversions.mapAsScalaMap +import scala.util.Sorting +import java.lang.{Long => JLong} + +object SimpleWordCount2_Special { + + def moreWarmup(sc: SparkContext) { + (0 until 20).foreach {i => + sc.parallelize(1 to 20000000, 500) + .map(_ % 100).map(_.toString) + .map(x => (x, 1)).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SimpleWordCount2 <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "SimpleWordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + GrepCount2.warmConnectionManagers(ssc.sc) + moreWarmup(ssc.sc) + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 400)).toArray + ) + + + def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { + val map = new java.util.HashMap[String, JLong] + var i = 0 + var j = 0 + while (iter.hasNext) { + val s = iter.next() + i = 0 + while (i < s.length) { + j = i + while (j < s.length && s.charAt(j) != ' ') { + j += 1 + } + if (j > i) { + val w = s.substring(i, j) + val c = map.get(w) + if (c == null) { + map.put(w, 1) + } else { + map.put(w, c + 1) + } + } + i = j + while (i < s.length && s.charAt(i) == ' ') { + i += 1 + } + } + } + map.toIterator + } + + + /*val words = sentences.flatMap(_.split(" "))*/ + /*val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 10)*/ + val counts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) + counts.foreachRDD(_.collect()) + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/streaming/examples/TopContentCount.scala b/streaming/src/main/scala/spark/streaming/examples/TopContentCount.scala new file mode 100644 index 0000000000..031e989c87 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/TopContentCount.scala @@ -0,0 +1,97 @@ +package spark.streaming + +import SparkStreamContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting + +object TopContentCount { + + case class Event(val country: String, val content: String) + + object Event { + def create(string: String): Event = { + val parts = string.split(":") + new Event(parts(0), parts(1)) + } + } + + def main(args: Array[String]) { + + if (args.length < 2) { + println ("Usage: GrepCount2 <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "TopContentCount") + val sc = ssc.sc + val dummy = sc.parallelize(0 to 1000, 100).persist(StorageLevel.DISK_AND_MEMORY) + sc.runJob(dummy, (_: Iterator[Int]) => {}) + + + val numEventStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + val eventStrings = new UnifiedRDS( + (1 to numEventStreams).map(i => ssc.readTestStream("Events-" + i, 1000)).toArray + ) + + def parse(string: String) = { + val parts = string.split(":") + (parts(0), parts(1)) + } + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + + val events = eventStrings.map(x => parse(x)) + /*events.print*/ + + val parallelism = 8 + val counts_per_content_per_country = events + .map(x => (x, 1)) + .reduceByKey(_ + _) + /*.reduceByKeyAndWindow(add _, subtract _, Seconds(5), Seconds(1), parallelism)*/ + /*counts_per_content_per_country.print*/ + + /* + counts_per_content_per_country.persist( + StorageLevel.MEMORY_ONLY_DESER, + StorageLevel.MEMORY_ONLY_DESER_2, + Seconds(1) + )*/ + + val counts_per_country = counts_per_content_per_country + .map(x => (x._1._1, (x._1._2, x._2))) + .groupByKey() + counts_per_country.print + + + def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { + implicit val countOrdering = new Ordering[(String, Int)] { + override def compare(count1: (String, Int), count2: (String, Int)): Int = { + count2._2 - count1._2 + } + } + val array = data.toArray + Sorting.quickSort(array) + val taken = array.take(k) + taken + } + + val k = 10 + val topKContents_per_country = counts_per_country + .map(x => (x._1, topK(x._2, k))) + .map(x => (x._1, x._2.map(_.toString).reduceLeft(_ + ", " + _))) + + topKContents_per_country.print + + ssc.run + } +} + + + diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2.scala new file mode 100644 index 0000000000..679ed0a7ef --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2.scala @@ -0,0 +1,103 @@ +package spark.streaming + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting + +object TopKWordCount2 { + + def moreWarmup(sc: SparkContext) { + (0 until 20).foreach {i => + sc.parallelize(1 to 20000000, 500) + .map(_ % 100).map(_.toString) + .map(x => (x, 1)).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SparkStreamContext <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + GrepCount2.warmConnectionManagers(ssc.sc) + moreWarmup(ssc.sc) + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray + ) + + val words = sentences.flatMap(_.split(" ")) + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + + val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 10) + windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) + + def topK(data: Iterator[(String, Int)], k: Int): Iterator[(String, Int)] = { + val taken = new Array[(String, Int)](k) + + var i = 0 + var len = 0 + var done = false + var value: (String, Int) = null + var swap: (String, Int) = null + var count = 0 + + while(data.hasNext) { + value = data.next + count += 1 + println("count = " + count) + if (len == 0) { + taken(0) = value + len = 1 + } else if (len < k || value._2 > taken(len - 1)._2) { + if (len < k) { + len += 1 + } + taken(len - 1) = value + i = len - 1 + while(i > 0 && taken(i - 1)._2 < taken(i)._2) { + swap = taken(i) + taken(i) = taken(i-1) + taken(i - 1) = swap + i -= 1 + } + } + } + println("Took " + len + " out of " + count + " items") + return taken.toIterator + } + + val k = 10 + val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) + partialTopKWindowedCounts.foreachRDD(rdd => { + val collectedCounts = rdd.collect + println("Collected " + collectedCounts.size + " items") + topK(collectedCounts.toIterator, k).foreach(println) + }) + + /* + windowedCounts.filter(_ == null).foreachRDD(rdd => { + val count = rdd.count + println("# of nulls = " + count) + })*/ + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2_Special.scala new file mode 100644 index 0000000000..c873fbd0f0 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2_Special.scala @@ -0,0 +1,142 @@ +package spark.streaming + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting +import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.mutable.Queue + +import java.lang.{Long => JLong} + +object TopKWordCount2_Special { + + def moreWarmup(sc: SparkContext) { + (0 until 20).foreach {i => + sc.parallelize(1 to 20000000, 500) + .map(_ % 100).map(_.toString) + .map(x => (x, 1)).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SparkStreamContext <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "TopKWordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + GrepCount2.warmConnectionManagers(ssc.sc) + /*moreWarmup(ssc.sc)*/ + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 500)).toArray + ) + + /*val words = sentences.flatMap(_.split(" "))*/ + + /*def add(v1: Int, v2: Int) = (v1 + v2) */ + /*def subtract(v1: Int, v2: Int) = (v1 - v2) */ + + def add(v1: JLong, v2: JLong) = (v1 + v2) + def subtract(v1: JLong, v2: JLong) = (v1 - v2) + + def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { + val map = new java.util.HashMap[String, JLong] + var i = 0 + var j = 0 + while (iter.hasNext) { + val s = iter.next() + i = 0 + while (i < s.length) { + j = i + while (j < s.length && s.charAt(j) != ' ') { + j += 1 + } + if (j > i) { + val w = s.substring(i, j) + val c = map.get(w) + if (c == null) { + map.put(w, 1) + } else { + map.put(w, c + 1) + } + } + i = j + while (i < s.length && s.charAt(i) == ' ') { + i += 1 + } + } + } + map.toIterator + } + + + val windowedCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Milliseconds(500), 10) + /*windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1))*/ + windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY, Milliseconds(500)) + + def topK(data: Iterator[(String, JLong)], k: Int): Iterator[(String, JLong)] = { + val taken = new Array[(String, JLong)](k) + + var i = 0 + var len = 0 + var done = false + var value: (String, JLong) = null + var swap: (String, JLong) = null + var count = 0 + + while(data.hasNext) { + value = data.next + count += 1 + println("count = " + count) + if (len == 0) { + taken(0) = value + len = 1 + } else if (len < k || value._2 > taken(len - 1)._2) { + if (len < k) { + len += 1 + } + taken(len - 1) = value + i = len - 1 + while(i > 0 && taken(i - 1)._2 < taken(i)._2) { + swap = taken(i) + taken(i) = taken(i-1) + taken(i - 1) = swap + i -= 1 + } + } + } + println("Took " + len + " out of " + count + " items") + return taken.toIterator + } + + val k = 50 + val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) + partialTopKWindowedCounts.foreachRDD(rdd => { + val collectedCounts = rdd.collect + println("Collected " + collectedCounts.size + " items") + topK(collectedCounts.toIterator, k).foreach(println) + }) + + /* + windowedCounts.filter(_ == null).foreachRDD(rdd => { + val count = rdd.count + println("# of nulls = " + count) + })*/ + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount.scala new file mode 100644 index 0000000000..fb5508ffcc --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount.scala @@ -0,0 +1,62 @@ +package spark.streaming + +import SparkStreamContext._ + +import scala.util.Sorting + +import spark.SparkContext +import spark.storage.StorageLevel + +object WordCount { + var inputFile : String = null + var HDFS : String = null + var idealPartitions : Int = 0 + + def main (args: Array[String]) { + + if (args.length != 4) { + println ("Usage: WordCount ") + System.exit(1) + } + + HDFS = args(1) + inputFile = HDFS + args(2) + idealPartitions = args(3).toInt + println ("Input file: " + inputFile) + + val ssc = new SparkStreamContext(args(0), "WordCountWindow") + + SparkContext.idealPartitions = idealPartitions + SparkContext.inputFile = inputFile + + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) + //sentences.print + + val words = sentences.flatMap(_.split(" ")) + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + + //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(2000), + // System.getProperty("spark.default.parallelism", "1").toInt) + //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.DISK_AND_MEMORY_DESER_2, Seconds(5)) + //windowedCounts.print + + val parallelism = System.getProperty("spark.default.parallelism", "1").toInt + + //val localCounts = words.map(x => (x, 1)).reduceByKey(add _, parallelism) + //localCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(6)) + //val windowedCounts = localCounts.window(Seconds(30), Seconds(2)).reduceByKey(_ + _) + + val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(2), + parallelism) + windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(10)) + + //windowedCounts.print + windowedCounts.register + //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => print(x+ " "))) + //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => x)) + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount1.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount1.scala new file mode 100644 index 0000000000..42d985920a --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount1.scala @@ -0,0 +1,46 @@ +package spark.streaming + +import SparkStreamContext._ + +import scala.util.Sorting + +import spark.SparkContext +import spark.storage.StorageLevel + +object WordCount1 { + var inputFile : String = null + var HDFS : String = null + var idealPartitions : Int = 0 + + def main (args: Array[String]) { + + if (args.length != 4) { + println ("Usage: WordCount ") + System.exit(1) + } + + HDFS = args(1) + inputFile = HDFS + args(2) + idealPartitions = args(3).toInt + println ("Input file: " + inputFile) + + val ssc = new SparkStreamContext(args(0), "WordCountWindow") + + SparkContext.idealPartitions = idealPartitions + SparkContext.inputFile = inputFile + + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) + //sentences.print + + val words = sentences.flatMap(_.split(" ")) + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + + val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 10) + windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) + windowedCounts.foreachRDD(_.collect) + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala new file mode 100644 index 0000000000..9168a2fe2f --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -0,0 +1,55 @@ +package spark.streaming + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting + +object WordCount2 { + + def moreWarmup(sc: SparkContext) { + (0 until 20).foreach {i => + sc.parallelize(1 to 20000000, 500) + .map(_ % 100).map(_.toString) + .map(x => (x, 1)).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SparkStreamContext <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + if (args.length > 2) { + ssc.setTempDir(args(2)) + } + + GrepCount2.warmConnectionManagers(ssc.sc) + /*moreWarmup(ssc.sc)*/ + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray + ) + + val words = sentences.flatMap(_.split(" ")) + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + + val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 6) + windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) + windowedCounts.foreachRDD(_.collect) + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2_Special.scala new file mode 100644 index 0000000000..1920915af7 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2_Special.scala @@ -0,0 +1,94 @@ +package spark.streaming + +import spark.SparkContext +import SparkContext._ +import SparkStreamContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.Queue +import scala.collection.JavaConversions.mapAsScalaMap + +import java.lang.{Long => JLong} +import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} + + +object WordCount2_ExtraFunctions { + + def add(v1: JLong, v2: JLong) = (v1 + v2) + + def subtract(v1: JLong, v2: JLong) = (v1 - v2) + + def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { + val map = new java.util.HashMap[String, JLong] + var i = 0 + var j = 0 + while (iter.hasNext) { + val s = iter.next() + i = 0 + while (i < s.length) { + j = i + while (j < s.length && s.charAt(j) != ' ') { + j += 1 + } + if (j > i) { + val w = s.substring(i, j) + val c = map.get(w) + if (c == null) { + map.put(w, 1) + } else { + map.put(w, c + 1) + } + } + i = j + while (i < s.length && s.charAt(i) == ' ') { + i += 1 + } + } + } + map.toIterator + } +} + +object WordCount2_Special { + + def moreWarmup(sc: SparkContext) { + (0 until 40).foreach {i => + sc.parallelize(1 to 20000000, 1000) + .map(_ % 1331).map(_.toString) + .mapPartitions(WordCount2_ExtraFunctions.splitAndCountPartitions).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SparkStreamContext <# sentence streams>") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCount2") + + val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 + + GrepCount2.warmConnectionManagers(ssc.sc) + /*moreWarmup(ssc.sc)*/ + + val sentences = new UnifiedRDS( + (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 500)).toArray + ) + + val windowedCounts = sentences + .mapPartitions(WordCount2_ExtraFunctions.splitAndCountPartitions) + .reduceByKeyAndWindow(WordCount2_ExtraFunctions.add _, WordCount2_ExtraFunctions.subtract _, Seconds(10), Milliseconds(500), 10) + windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY, Milliseconds(500)) + windowedCounts.foreachRDD(_.collect) + + ssc.run + } +} + diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount3.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount3.scala new file mode 100644 index 0000000000..018c19a509 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount3.scala @@ -0,0 +1,49 @@ +package spark.streaming + +import SparkStreamContext._ + +import scala.util.Sorting + +object WordCount3 { + + def main (args: Array[String]) { + + if (args.length < 1) { + println ("Usage: SparkStreamContext []") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCount") + if (args.length > 1) { + ssc.setTempDir(args(1)) + } + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) + /*sentences.print*/ + + val words = sentences.flatMap(_.split(" ")) + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + + /*val windowedCounts = words.map(x => (x, 1)).window(Seconds(5), Seconds(1)).reduceByKey(add _, 1)*/ + val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(5), Seconds(1), 1) + /*windowedCounts.print */ + + def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { + implicit val countOrdering = new Ordering[(String, Int)] { + override def compare(count1: (String, Int), count2: (String, Int)): Int = { + count2._2 - count1._2 + } + } + val array = data.toArray + Sorting.quickSort(array) + array.take(k) + } + + val k = 10 + val topKWindowedCounts = windowedCounts.glom.flatMap(topK(_, k)).collect.flatMap(topK(_, k)) + topKWindowedCounts.print + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountEc2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountEc2.scala new file mode 100644 index 0000000000..82b9fa781d --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountEc2.scala @@ -0,0 +1,41 @@ +package spark.streaming + +import SparkStreamContext._ +import spark.SparkContext + +object WordCountEc2 { + var inputFile : String = null + var HDFS : String = null + var idealPartitions : Int = 0 + + def main (args: Array[String]) { + + if (args.length != 4) { + println ("Usage: SparkStreamContext ") + System.exit(1) + } + + HDFS = args(1) + inputFile = HDFS + args(2) + idealPartitions = args(3).toInt + println ("Input file: " + inputFile) + + SparkContext.idealPartitions = idealPartitions + SparkContext.inputFile = inputFile + + val ssc = new SparkStreamContext(args(0), "Test") + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) + /*sentences.foreach(println)*/ + + val words = sentences.flatMap(_.split(" ")) + /*words.foreach(println)*/ + + val counts = words.map(x => (x, 1)).reduceByKey(_ + _) + /*counts.foreach(println)*/ + + counts.foreachRDD(rdd => rdd.collect.foreach(x => x)) + /*counts.register*/ + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountTrivialWindow.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountTrivialWindow.scala new file mode 100644 index 0000000000..114dd144f1 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountTrivialWindow.scala @@ -0,0 +1,51 @@ +package spark.streaming + +import SparkStreamContext._ + +import scala.util.Sorting + +object WordCountTrivialWindow { + + def main (args: Array[String]) { + + if (args.length < 1) { + println ("Usage: SparkStreamContext []") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCountTrivialWindow") + if (args.length > 1) { + ssc.setTempDir(args(1)) + } + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) + /*sentences.print*/ + + val words = sentences.flatMap(_.split(" ")) + + /*val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 1)*/ + /*counts.print*/ + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + + val windowedCounts = words.map(x => (x, 1)).window(Seconds(5), Seconds(1)).reduceByKey(add _, 1) + /*windowedCounts.print */ + + def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { + implicit val countOrdering = new Ordering[(String, Int)] { + override def compare(count1: (String, Int), count2: (String, Int)): Int = { + count2._2 - count1._2 + } + } + val array = data.toArray + Sorting.quickSort(array) + array.take(k) + } + + val k = 10 + val topKWindowedCounts = windowedCounts.glom.flatMap(topK(_, k)).collect.flatMap(topK(_, k)) + topKWindowedCounts.print + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordMax.scala b/streaming/src/main/scala/spark/streaming/examples/WordMax.scala new file mode 100644 index 0000000000..fbfc48030f --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/WordMax.scala @@ -0,0 +1,64 @@ +package spark.streaming + +import SparkStreamContext._ + +import scala.util.Sorting + +import spark.SparkContext +import spark.storage.StorageLevel + +object WordMax { + var inputFile : String = null + var HDFS : String = null + var idealPartitions : Int = 0 + + def main (args: Array[String]) { + + if (args.length != 4) { + println ("Usage: WordCount ") + System.exit(1) + } + + HDFS = args(1) + inputFile = HDFS + args(2) + idealPartitions = args(3).toInt + println ("Input file: " + inputFile) + + val ssc = new SparkStreamContext(args(0), "WordCountWindow") + + SparkContext.idealPartitions = idealPartitions + SparkContext.inputFile = inputFile + + val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) + //sentences.print + + val words = sentences.flatMap(_.split(" ")) + + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + def max(v1: Int, v2: Int) = (if (v1 > v2) v1 else v2) + + //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(2000), + // System.getProperty("spark.default.parallelism", "1").toInt) + //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.DISK_AND_MEMORY_DESER_2, Seconds(5)) + //windowedCounts.print + + val parallelism = System.getProperty("spark.default.parallelism", "1").toInt + + val localCounts = words.map(x => (x, 1)).reduceByKey(add _, parallelism) + //localCounts.persist(StorageLevel.MEMORY_ONLY_DESER) + localCounts.persist(StorageLevel.MEMORY_ONLY_DESER_2) + val windowedCounts = localCounts.window(Seconds(30), Seconds(2)).reduceByKey(max _) + + //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(2), + // parallelism) + //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(6)) + + //windowedCounts.print + windowedCounts.register + //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => print(x+ " "))) + //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => x)) + + ssc.run + } +} diff --git a/streaming/src/main/scala/spark/streaming/utils/SenGeneratorForPerformanceTest.scala b/streaming/src/main/scala/spark/streaming/utils/SenGeneratorForPerformanceTest.scala new file mode 100644 index 0000000000..bb32089ae2 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/utils/SenGeneratorForPerformanceTest.scala @@ -0,0 +1,78 @@ +package spark.streaming + +import scala.util.Random +import scala.io.Source +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ +import scala.actors.remote.RemoteActor._ + +import java.net.InetSocketAddress + +/*import akka.actor.Actor._*/ +/*import akka.actor.ActorRef*/ + + +object SenGeneratorForPerformanceTest { + + def printUsage () { + println ("Usage: SentenceGenerator []") + System.exit(0) + } + + def main (args: Array[String]) { + if (args.length < 3) { + printUsage + } + + val inputManagerIP = args(0) + val inputManagerPort = args(1).toInt + val sentenceFile = args(2) + val sentencesPerSecond = { + if (args.length > 3) args(3).toInt + else 10 + } + + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n") + source.close () + + try { + /*val inputManager = remote.actorFor("InputReceiver-Sentences",*/ + /* inputManagerIP, inputManagerPort)*/ + val inputManager = select(Node(inputManagerIP, inputManagerPort), Symbol("InputReceiver-Sentences")) + val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 + val random = new Random () + println ("Sending " + sentencesPerSecond + " sentences per second to " + inputManagerIP + ":" + inputManagerPort) + var lastPrintTime = System.currentTimeMillis() + var count = 0 + + while (true) { + /*if (!inputManager.tryTell (lines (random.nextInt (lines.length))))*/ + /*throw new Exception ("disconnected")*/ +// inputManager ! lines (random.nextInt (lines.length)) + for (i <- 0 to sentencesPerSecond) inputManager ! lines (0) + println(System.currentTimeMillis / 1000 + " s") +/* count += 1 + + if (System.currentTimeMillis - lastPrintTime >= 1000) { + println (count + " sentences sent last second") + count = 0 + lastPrintTime = System.currentTimeMillis + } + + Thread.sleep (sleepBetweenSentences.toLong) +*/ + val currentMs = System.currentTimeMillis / 1000; + Thread.sleep ((currentMs * 1000 + 1000) - System.currentTimeMillis) + } + } catch { + case e: Exception => + /*Thread.sleep (1000)*/ + } + } +} + + + + diff --git a/streaming/src/main/scala/spark/streaming/utils/SenderReceiverTest.scala b/streaming/src/main/scala/spark/streaming/utils/SenderReceiverTest.scala new file mode 100644 index 0000000000..6af270298a --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/utils/SenderReceiverTest.scala @@ -0,0 +1,63 @@ +package spark.streaming +import java.net.{Socket, ServerSocket} +import java.io.{ByteArrayOutputStream, DataOutputStream, DataInputStream, BufferedInputStream} + +object Receiver { + def main(args: Array[String]) { + val port = args(0).toInt + val lsocket = new ServerSocket(port) + println("Listening on port " + port ) + while(true) { + val socket = lsocket.accept() + (new Thread() { + override def run() { + val buffer = new Array[Byte](100000) + var count = 0 + val time = System.currentTimeMillis + try { + val is = new DataInputStream(new BufferedInputStream(socket.getInputStream)) + var loop = true + var string: String = null + while((string = is.readUTF) != null) { + count += 28 + } + } catch { + case e: Exception => e.printStackTrace + } + val timeTaken = System.currentTimeMillis - time + val tput = (count / 1024.0) / (timeTaken / 1000.0) + println("Data = " + count + " bytes\nTime = " + timeTaken + " ms\nTput = " + tput + " KB/s") + } + }).start() + } + } + +} + +object Sender { + + def main(args: Array[String]) { + try { + val host = args(0) + val port = args(1).toInt + val size = args(2).toInt + + val byteStream = new ByteArrayOutputStream() + val stringDataStream = new DataOutputStream(byteStream) + (0 until size).foreach(_ => stringDataStream.writeUTF("abcdedfghijklmnopqrstuvwxy")) + val bytes = byteStream.toByteArray() + println("Generated array of " + bytes.length + " bytes") + + /*val bytes = new Array[Byte](size)*/ + val socket = new Socket(host, port) + val os = socket.getOutputStream + os.write(bytes) + os.flush + socket.close() + + } catch { + case e: Exception => e.printStackTrace + } + } +} + diff --git a/streaming/src/main/scala/spark/streaming/utils/SentenceFileGenerator.scala b/streaming/src/main/scala/spark/streaming/utils/SentenceFileGenerator.scala new file mode 100644 index 0000000000..15858f59e3 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/utils/SentenceFileGenerator.scala @@ -0,0 +1,92 @@ +package spark.streaming + +import spark._ + +import scala.collection.mutable.ArrayBuffer +import scala.util.Random +import scala.io.Source + +import java.net.InetSocketAddress + +import org.apache.hadoop.fs._ +import org.apache.hadoop.conf._ +import org.apache.hadoop.io._ +import org.apache.hadoop.mapred._ +import org.apache.hadoop.util._ + +object SentenceFileGenerator { + + def printUsage () { + println ("Usage: SentenceFileGenerator <# partitions> []") + System.exit(0) + } + + def main (args: Array[String]) { + if (args.length < 4) { + printUsage + } + + val master = args(0) + val fs = new Path(args(1)).getFileSystem(new Configuration()) + val targetDirectory = new Path(args(1)).makeQualified(fs) + val numPartitions = args(2).toInt + val sentenceFile = args(3) + val sentencesPerSecond = { + if (args.length > 4) args(4).toInt + else 10 + } + + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n").toArray + source.close () + println("Read " + lines.length + " lines from file " + sentenceFile) + + val sentences = { + val buffer = ArrayBuffer[String]() + val random = new Random() + var i = 0 + while (i < sentencesPerSecond) { + buffer += lines(random.nextInt(lines.length)) + i += 1 + } + buffer.toArray + } + println("Generated " + sentences.length + " sentences") + + val sc = new SparkContext(master, "SentenceFileGenerator") + val sentencesRDD = sc.parallelize(sentences, numPartitions) + + val tempDirectory = new Path(targetDirectory, "_tmp") + + fs.mkdirs(targetDirectory) + fs.mkdirs(tempDirectory) + + var saveTimeMillis = System.currentTimeMillis + try { + while (true) { + val newDir = new Path(targetDirectory, "Sentences-" + saveTimeMillis) + val tmpNewDir = new Path(tempDirectory, "Sentences-" + saveTimeMillis) + println("Writing to file " + newDir) + sentencesRDD.saveAsTextFile(tmpNewDir.toString) + fs.rename(tmpNewDir, newDir) + saveTimeMillis += 1000 + val sleepTimeMillis = { + val currentTimeMillis = System.currentTimeMillis + if (saveTimeMillis < currentTimeMillis) { + 0 + } else { + saveTimeMillis - currentTimeMillis + } + } + println("Sleeping for " + sleepTimeMillis + " ms") + Thread.sleep(sleepTimeMillis) + } + } catch { + case e: Exception => + } + } +} + + + + diff --git a/streaming/src/main/scala/spark/streaming/utils/SentenceGenerator.scala b/streaming/src/main/scala/spark/streaming/utils/SentenceGenerator.scala new file mode 100644 index 0000000000..a9f124d2d7 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/utils/SentenceGenerator.scala @@ -0,0 +1,103 @@ +package spark.streaming + +import scala.util.Random +import scala.io.Source +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ +import scala.actors.remote.RemoteActor._ + +import java.net.InetSocketAddress + + +object SentenceGenerator { + + def printUsage { + println ("Usage: SentenceGenerator []") + System.exit(0) + } + + def generateRandomSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { + val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 + val random = new Random () + + try { + var lastPrintTime = System.currentTimeMillis() + var count = 0 + while(true) { + streamReceiver ! lines(random.nextInt(lines.length)) + count += 1 + if (System.currentTimeMillis - lastPrintTime >= 1000) { + println (count + " sentences sent last second") + count = 0 + lastPrintTime = System.currentTimeMillis + } + Thread.sleep(sleepBetweenSentences.toLong) + } + } catch { + case e: Exception => + } + } + + def generateSameSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { + try { + val numSentences = if (sentencesPerSecond <= 0) { + lines.length + } else { + sentencesPerSecond + } + var nextSendingTime = System.currentTimeMillis() + val pingInterval = if (System.getenv("INTERVAL") != null) { + System.getenv("INTERVAL").toInt + } else { + 2000 + } + while(true) { + (0 until numSentences).foreach(i => { + streamReceiver ! lines(i % lines.length) + }) + println ("Sent " + numSentences + " sentences") + nextSendingTime += pingInterval + val sleepTime = nextSendingTime - System.currentTimeMillis + if (sleepTime > 0) { + println ("Sleeping for " + sleepTime + " ms") + Thread.sleep(sleepTime) + } + } + } catch { + case e: Exception => + } + } + + def main(args: Array[String]) { + if (args.length < 3) { + printUsage + } + + val generateRandomly = false + + val streamReceiverIP = args(0) + val streamReceiverPort = args(1).toInt + val sentenceFile = args(2) + val sentencesPerSecond = if (args.length > 3) args(3).toInt else 10 + val sentenceInputName = if (args.length > 4) args(4) else "Sentences" + + println("Sending " + sentencesPerSecond + " sentences per second to " + + streamReceiverIP + ":" + streamReceiverPort + "/NetworkStreamReceiver-" + sentenceInputName) + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n") + source.close () + + val streamReceiver = select( + Node(streamReceiverIP, streamReceiverPort), + Symbol("NetworkStreamReceiver-" + sentenceInputName)) + if (generateRandomly) { + generateRandomSentences(lines, sentencesPerSecond, streamReceiver) + } else { + generateSameSentences(lines, sentencesPerSecond, streamReceiver) + } + } +} + + + diff --git a/streaming/src/main/scala/spark/streaming/utils/ShuffleTest.scala b/streaming/src/main/scala/spark/streaming/utils/ShuffleTest.scala new file mode 100644 index 0000000000..32aa4144a0 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/utils/ShuffleTest.scala @@ -0,0 +1,22 @@ +package spark.streaming +import spark.SparkContext +import SparkContext._ + +object ShuffleTest { + def main(args: Array[String]) { + + if (args.length < 1) { + println ("Usage: ShuffleTest ") + System.exit(1) + } + + val sc = new spark.SparkContext(args(0), "ShuffleTest") + val rdd = sc.parallelize(1 to 1000, 500).cache + + def time(f: => Unit) { val start = System.nanoTime; f; println((System.nanoTime - start) * 1.0e-6) } + + time { for (i <- 0 until 50) time { rdd.map(x => (x % 100, x)).reduceByKey(_ + _, 10).count } } + System.exit(0) + } +} + -- cgit v1.2.3 From 3be54c2a8afcb2a3abf1cf22934123fae3419278 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 1 Aug 2012 22:09:27 -0700 Subject: 1. Refactored SparkStreamContext, Scheduler, InputRDS, FileInputRDS and a few other files. 2. Modified Time class to represent milliseconds (long) directly, instead of LongTime. 3. Added new files QueueInputRDS, RecurringTimer, etc. 4. Added RDDSuite as the skeleton for testcases. 5. Added two examples in spark.streaming.examples. 6. Removed all past examples and a few unnecessary files. Moved a number of files to spark.streaming.util. --- core/src/main/scala/spark/Utils.scala | 2 +- .../src/main/scala/spark/streaming/BlockID.scala | 20 -- .../main/scala/spark/streaming/FileInputRDS.scala | 163 +++++++++++ .../scala/spark/streaming/FileStreamReceiver.scala | 70 ----- .../scala/spark/streaming/IdealPerformance.scala | 36 --- .../src/main/scala/spark/streaming/Interval.scala | 6 +- streaming/src/main/scala/spark/streaming/Job.scala | 9 +- .../main/scala/spark/streaming/JobManager.scala | 23 +- .../spark/streaming/NetworkStreamReceiver.scala | 184 ------------- .../scala/spark/streaming/PairRDSFunctions.scala | 72 +++++ .../main/scala/spark/streaming/QueueInputRDS.scala | 36 +++ streaming/src/main/scala/spark/streaming/RDS.scala | 305 ++++++--------------- .../src/main/scala/spark/streaming/Scheduler.scala | 171 ++---------- .../scala/spark/streaming/SparkStreamContext.scala | 150 +++++++--- .../spark/streaming/TestInputBlockTracker.scala | 42 --- .../spark/streaming/TestStreamReceiver3.scala | 18 +- .../spark/streaming/TestStreamReceiver4.scala | 16 +- .../src/main/scala/spark/streaming/Time.scala | 100 ++++--- .../examples/DumbTopKWordCount2_Special.scala | 138 ---------- .../examples/DumbWordCount2_Special.scala | 92 ------- .../spark/streaming/examples/ExampleOne.scala | 41 +++ .../spark/streaming/examples/ExampleTwo.scala | 47 ++++ .../scala/spark/streaming/examples/GrepCount.scala | 39 --- .../spark/streaming/examples/GrepCount2.scala | 113 -------- .../spark/streaming/examples/GrepCountApprox.scala | 54 ---- .../spark/streaming/examples/SimpleWordCount.scala | 30 -- .../streaming/examples/SimpleWordCount2.scala | 51 ---- .../examples/SimpleWordCount2_Special.scala | 83 ------ .../spark/streaming/examples/TopContentCount.scala | 97 ------- .../spark/streaming/examples/TopKWordCount2.scala | 103 ------- .../examples/TopKWordCount2_Special.scala | 142 ---------- .../scala/spark/streaming/examples/WordCount.scala | 62 ----- .../spark/streaming/examples/WordCount1.scala | 46 ---- .../spark/streaming/examples/WordCount2.scala | 55 ---- .../streaming/examples/WordCount2_Special.scala | 94 ------- .../spark/streaming/examples/WordCount3.scala | 49 ---- .../spark/streaming/examples/WordCountEc2.scala | 41 --- .../examples/WordCountTrivialWindow.scala | 51 ---- .../scala/spark/streaming/examples/WordMax.scala | 64 ----- .../spark/streaming/util/RecurringTimer.scala | 52 ++++ .../spark/streaming/util/SenderReceiverTest.scala | 64 +++++ .../streaming/util/SentenceFileGenerator.scala | 92 +++++++ .../scala/spark/streaming/util/ShuffleTest.scala | 23 ++ .../main/scala/spark/streaming/util/Utils.scala | 9 + .../utils/SenGeneratorForPerformanceTest.scala | 78 ------ .../spark/streaming/utils/SenderReceiverTest.scala | 63 ----- .../streaming/utils/SentenceFileGenerator.scala | 92 ------- .../spark/streaming/utils/SentenceGenerator.scala | 103 ------- .../scala/spark/streaming/utils/ShuffleTest.scala | 22 -- .../src/test/scala/spark/streaming/RDSSuite.scala | 65 +++++ 50 files changed, 971 insertions(+), 2607 deletions(-) delete mode 100644 streaming/src/main/scala/spark/streaming/BlockID.scala create mode 100644 streaming/src/main/scala/spark/streaming/FileInputRDS.scala delete mode 100644 streaming/src/main/scala/spark/streaming/FileStreamReceiver.scala delete mode 100644 streaming/src/main/scala/spark/streaming/IdealPerformance.scala delete mode 100644 streaming/src/main/scala/spark/streaming/NetworkStreamReceiver.scala create mode 100644 streaming/src/main/scala/spark/streaming/PairRDSFunctions.scala create mode 100644 streaming/src/main/scala/spark/streaming/QueueInputRDS.scala delete mode 100644 streaming/src/main/scala/spark/streaming/TestInputBlockTracker.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/DumbTopKWordCount2_Special.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/DumbWordCount2_Special.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/GrepCount.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/GrepCount2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/GrepCountApprox.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/SimpleWordCount.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2_Special.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/TopContentCount.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/TopKWordCount2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/TopKWordCount2_Special.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount1.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount2_Special.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount3.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCountEc2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCountTrivialWindow.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordMax.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/Utils.scala delete mode 100644 streaming/src/main/scala/spark/streaming/utils/SenGeneratorForPerformanceTest.scala delete mode 100644 streaming/src/main/scala/spark/streaming/utils/SenderReceiverTest.scala delete mode 100644 streaming/src/main/scala/spark/streaming/utils/SentenceFileGenerator.scala delete mode 100644 streaming/src/main/scala/spark/streaming/utils/SentenceGenerator.scala delete mode 100644 streaming/src/main/scala/spark/streaming/utils/ShuffleTest.scala create mode 100644 streaming/src/test/scala/spark/streaming/RDSSuite.scala diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 5eda1011f9..1d33f7d6b3 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -185,7 +185,7 @@ object Utils { * millisecond. */ def getUsedTimeMs(startTimeMs: Long): String = { - return " " + (System.currentTimeMillis - startTimeMs) + " ms " + return " " + (System.currentTimeMillis - startTimeMs) + " ms" } /** diff --git a/streaming/src/main/scala/spark/streaming/BlockID.scala b/streaming/src/main/scala/spark/streaming/BlockID.scala deleted file mode 100644 index 16aacfda18..0000000000 --- a/streaming/src/main/scala/spark/streaming/BlockID.scala +++ /dev/null @@ -1,20 +0,0 @@ -package spark.streaming - -case class BlockID(sRds: String, sInterval: Interval, sPartition: Int) { - override def toString : String = ( - sRds + BlockID.sConnector + - sInterval.beginTime + BlockID.sConnector + - sInterval.endTime + BlockID.sConnector + - sPartition - ) -} - -object BlockID { - val sConnector = '-' - - def parse(name : String) = BlockID( - name.split(BlockID.sConnector)(0), - new Interval(name.split(BlockID.sConnector)(1).toLong, - name.split(BlockID.sConnector)(2).toLong), - name.split(BlockID.sConnector)(3).toInt) -} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/FileInputRDS.scala b/streaming/src/main/scala/spark/streaming/FileInputRDS.scala new file mode 100644 index 0000000000..dde80cd27a --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/FileInputRDS.scala @@ -0,0 +1,163 @@ +package spark.streaming + +import spark.SparkContext +import spark.RDD +import spark.BlockRDD +import spark.UnionRDD +import spark.storage.StorageLevel +import spark.streaming._ + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap + +import java.net.InetSocketAddress + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.PathFilter +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} + + +class FileInputRDS[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( + ssc: SparkStreamContext, + directory: Path, + filter: PathFilter = FileInputRDS.defaultPathFilter, + newFilesOnly: Boolean = true) + extends InputRDS[(K, V)](ssc) { + + val fs = directory.getFileSystem(new Configuration()) + var lastModTime: Long = 0 + + override def start() { + if (newFilesOnly) { + lastModTime = System.currentTimeMillis() + } else { + lastModTime = 0 + } + } + + override def stop() { } + + override def compute(validTime: Time): Option[RDD[(K, V)]] = { + val newFilter = new PathFilter() { + var latestModTime = 0L + + def accept(path: Path): Boolean = { + + if (!filter.accept(path)) { + return false + } else { + val modTime = fs.getFileStatus(path).getModificationTime() + if (modTime < lastModTime) { + return false + } + if (modTime > latestModTime) { + latestModTime = modTime + } + return true + } + } + } + + val newFiles = fs.listStatus(directory, newFilter) + lastModTime = newFilter.latestModTime + val newRDD = new UnionRDD(ssc.sc, newFiles.map(file => + ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString)) + ) + Some(newRDD) + } +} + +object FileInputRDS { + val defaultPathFilter = new PathFilter { + def accept(path: Path): Boolean = { + val file = path.getName() + if (file.startsWith(".") || file.endsWith("_tmp")) { + return false + } else { + return true + } + } + } +} + +/* +class NetworkInputRDS[T: ClassManifest]( + val networkInputName: String, + val addresses: Array[InetSocketAddress], + batchDuration: Time, + ssc: SparkStreamContext) +extends InputRDS[T](networkInputName, batchDuration, ssc) { + + + // TODO(Haoyuan): This is for the performance test. + @transient var rdd: RDD[T] = null + + if (System.getProperty("spark.fake", "false") == "true") { + logInfo("Running initial count to cache fake RDD") + rdd = ssc.sc.textFile(SparkContext.inputFile, + SparkContext.idealPartitions).asInstanceOf[RDD[T]] + val fakeCacheLevel = System.getProperty("spark.fake.cache", "") + if (fakeCacheLevel == "MEMORY_ONLY_2") { + rdd.persist(StorageLevel.MEMORY_ONLY_2) + } else if (fakeCacheLevel == "MEMORY_ONLY_DESER_2") { + rdd.persist(StorageLevel.MEMORY_ONLY_2) + } else if (fakeCacheLevel != "") { + logError("Invalid fake cache level: " + fakeCacheLevel) + System.exit(1) + } + rdd.count() + } + + @transient val references = new HashMap[Time,String] + + override def compute(validTime: Time): Option[RDD[T]] = { + if (System.getProperty("spark.fake", "false") == "true") { + logInfo("Returning fake RDD at " + validTime) + return Some(rdd) + } + references.get(validTime) match { + case Some(reference) => + if (reference.startsWith("file") || reference.startsWith("hdfs")) { + logInfo("Reading from file " + reference + " for time " + validTime) + Some(ssc.sc.textFile(reference).asInstanceOf[RDD[T]]) + } else { + logInfo("Getting from BlockManager " + reference + " for time " + validTime) + Some(new BlockRDD(ssc.sc, Array(reference))) + } + case None => + throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") + None + } + } + + def setReference(time: Time, reference: AnyRef) { + references += ((time, reference.toString)) + logInfo("Reference added for time " + time + " - " + reference.toString) + } +} + + +class TestInputRDS( + val testInputName: String, + batchDuration: Time, + ssc: SparkStreamContext) +extends InputRDS[String](testInputName, batchDuration, ssc) { + + @transient val references = new HashMap[Time,Array[String]] + + override def compute(validTime: Time): Option[RDD[String]] = { + references.get(validTime) match { + case Some(reference) => + Some(new BlockRDD[String](ssc.sc, reference)) + case None => + throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") + None + } + } + + def setReference(time: Time, reference: AnyRef) { + references += ((time, reference.asInstanceOf[Array[String]])) + } +} +*/ diff --git a/streaming/src/main/scala/spark/streaming/FileStreamReceiver.scala b/streaming/src/main/scala/spark/streaming/FileStreamReceiver.scala deleted file mode 100644 index 92c7cfe00c..0000000000 --- a/streaming/src/main/scala/spark/streaming/FileStreamReceiver.scala +++ /dev/null @@ -1,70 +0,0 @@ -package spark.streaming - -import spark.Logging - -import scala.collection.mutable.HashSet -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import org.apache.hadoop.fs._ -import org.apache.hadoop.conf._ -import org.apache.hadoop.io._ -import org.apache.hadoop.mapred._ -import org.apache.hadoop.util._ - -class FileStreamReceiver ( - inputName: String, - rootDirectory: String, - intervalDuration: Long) - extends Logging { - - val pollInterval = 100 - val sparkstreamScheduler = { - val host = System.getProperty("spark.master.host") - val port = System.getProperty("spark.master.port").toInt + 1 - RemoteActor.select(Node(host, port), 'SparkStreamScheduler) - } - val directory = new Path(rootDirectory) - val fs = directory.getFileSystem(new Configuration()) - val files = new HashSet[String]() - var time: Long = 0 - - def start() { - fs.mkdirs(directory) - files ++= getFiles() - - actor { - logInfo("Monitoring directory - " + rootDirectory) - while(true) { - testFiles(getFiles()) - Thread.sleep(pollInterval) - } - } - } - - def getFiles(): Iterable[String] = { - fs.listStatus(directory).map(_.getPath.toString) - } - - def testFiles(fileList: Iterable[String]) { - fileList.foreach(file => { - if (!files.contains(file)) { - if (!file.endsWith("_tmp")) { - notifyFile(file) - } - files += file - } - }) - } - - def notifyFile(file: String) { - logInfo("Notifying file " + file) - time += intervalDuration - val interval = Interval(LongTime(time), LongTime(time + intervalDuration)) - sparkstreamScheduler ! InputGenerated(inputName, interval, file) - } -} - - diff --git a/streaming/src/main/scala/spark/streaming/IdealPerformance.scala b/streaming/src/main/scala/spark/streaming/IdealPerformance.scala deleted file mode 100644 index 303d4e7ae6..0000000000 --- a/streaming/src/main/scala/spark/streaming/IdealPerformance.scala +++ /dev/null @@ -1,36 +0,0 @@ -package spark.streaming - -import scala.collection.mutable.Map - -object IdealPerformance { - val base: String = "The medium researcher counts around the pinched troop The empire breaks " + - "Matei Matei announces HY with a theorem " - - def main (args: Array[String]) { - val sentences: String = base * 100000 - - for (i <- 1 to 30) { - val start = System.nanoTime - - val words = sentences.split(" ") - - val pairs = words.map(word => (word, 1)) - - val counts = Map[String, Int]() - - println("Job " + i + " position A at " + (System.nanoTime - start) / 1e9) - - pairs.foreach((pair) => { - var t = counts.getOrElse(pair._1, 0) - counts(pair._1) = t + pair._2 - }) - println("Job " + i + " position B at " + (System.nanoTime - start) / 1e9) - - for ((word, count) <- counts) { - print(word + " " + count + "; ") - } - println - println("Job " + i + " finished in " + (System.nanoTime - start) / 1e9) - } - } -} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/Interval.scala b/streaming/src/main/scala/spark/streaming/Interval.scala index 9a61d85274..1960097216 100644 --- a/streaming/src/main/scala/spark/streaming/Interval.scala +++ b/streaming/src/main/scala/spark/streaming/Interval.scala @@ -2,7 +2,7 @@ package spark.streaming case class Interval (val beginTime: Time, val endTime: Time) { - def this(beginMs: Long, endMs: Long) = this(new LongTime(beginMs), new LongTime(endMs)) + def this(beginMs: Long, endMs: Long) = this(Time(beginMs), new Time(endMs)) def duration(): Time = endTime - beginTime @@ -44,8 +44,8 @@ object Interval { def zero() = new Interval (Time.zero, Time.zero) - def currentInterval(intervalDuration: LongTime): Interval = { - val time = LongTime(System.currentTimeMillis) + def currentInterval(intervalDuration: Time): Interval = { + val time = Time(System.currentTimeMillis) val intervalBegin = time.floor(intervalDuration) Interval(intervalBegin, intervalBegin + intervalDuration) } diff --git a/streaming/src/main/scala/spark/streaming/Job.scala b/streaming/src/main/scala/spark/streaming/Job.scala index f7654dff79..36958dafe1 100644 --- a/streaming/src/main/scala/spark/streaming/Job.scala +++ b/streaming/src/main/scala/spark/streaming/Job.scala @@ -1,13 +1,14 @@ package spark.streaming +import spark.streaming.util.Utils + class Job(val time: Time, func: () => _) { val id = Job.getNewId() - - def run() { - func() + def run(): Long = { + Utils.time { func() } } - override def toString = "SparkStream Job " + id + ":" + time + override def toString = "streaming job " + id + " @ " + time } object Job { diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index d7d88a7000..43d167f7db 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -1,6 +1,7 @@ package spark.streaming -import spark.{Logging, SparkEnv} +import spark.Logging +import spark.SparkEnv import java.util.concurrent.Executors @@ -10,19 +11,14 @@ class JobManager(ssc: SparkStreamContext, numThreads: Int = 1) extends Logging { def run() { SparkEnv.set(ssc.env) try { - logInfo("Starting " + job) - job.run() - logInfo("Finished " + job) - if (job.time.isInstanceOf[LongTime]) { - val longTime = job.time.asInstanceOf[LongTime] - logInfo("Total notification + skew + processing delay for " + longTime + " is " + - (System.currentTimeMillis - longTime.milliseconds) / 1000.0 + " s") - if (System.getProperty("spark.stream.distributed", "false") == "true") { - TestInputBlockTracker.setEndTime(job.time) - } - } + val timeTaken = job.run() + logInfo( + "Runnning " + job + " took " + timeTaken + " ms, " + + "total delay was " + (System.currentTimeMillis - job.time) + " ms" + ) } catch { - case e: Exception => logError("SparkStream job failed", e) + case e: Exception => + logError("Running " + job + " failed", e) } } } @@ -33,5 +29,6 @@ class JobManager(ssc: SparkStreamContext, numThreads: Int = 1) extends Logging { def runJob(job: Job) { jobExecutor.execute(new JobHandler(ssc, job)) + logInfo("Added " + job + " to queue") } } diff --git a/streaming/src/main/scala/spark/streaming/NetworkStreamReceiver.scala b/streaming/src/main/scala/spark/streaming/NetworkStreamReceiver.scala deleted file mode 100644 index efd4689cf0..0000000000 --- a/streaming/src/main/scala/spark/streaming/NetworkStreamReceiver.scala +++ /dev/null @@ -1,184 +0,0 @@ -package spark.streaming - -import spark.Logging -import spark.storage.StorageLevel - -import scala.math._ -import scala.collection.mutable.{Queue, HashMap, ArrayBuffer} -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import java.io.BufferedWriter -import java.io.OutputStreamWriter - -import org.apache.hadoop.fs._ -import org.apache.hadoop.conf._ -import org.apache.hadoop.io._ -import org.apache.hadoop.mapred._ -import org.apache.hadoop.util._ - -/*import akka.actor.Actor._*/ - -class NetworkStreamReceiver[T: ClassManifest] ( - inputName: String, - intervalDuration: Time, - splitId: Int, - ssc: SparkStreamContext, - tempDirectory: String) - extends DaemonActor - with Logging { - - /** - * Assume all data coming in has non-decreasing timestamp. - */ - final class Inbox[T: ClassManifest] (intervalDuration: Time) { - var currentBucket: (Interval, ArrayBuffer[T]) = null - val filledBuckets = new Queue[(Interval, ArrayBuffer[T])]() - - def += (tuple: (Time, T)) = addTuple(tuple) - - def addTuple(tuple: (Time, T)) { - val (time, data) = tuple - val interval = getInterval (time) - - filledBuckets.synchronized { - if (currentBucket == null) { - currentBucket = (interval, new ArrayBuffer[T]()) - } - - if (interval != currentBucket._1) { - filledBuckets += currentBucket - currentBucket = (interval, new ArrayBuffer[T]()) - } - - currentBucket._2 += data - } - } - - def getInterval(time: Time): Interval = { - val intervalBegin = time.floor(intervalDuration) - Interval (intervalBegin, intervalBegin + intervalDuration) - } - - def hasFilledBuckets(): Boolean = { - filledBuckets.synchronized { - return filledBuckets.size > 0 - } - } - - def popFilledBucket(): (Interval, ArrayBuffer[T]) = { - filledBuckets.synchronized { - if (filledBuckets.size == 0) { - return null - } - return filledBuckets.dequeue() - } - } - } - - val inbox = new Inbox[T](intervalDuration) - lazy val sparkstreamScheduler = { - val host = System.getProperty("spark.master.host") - val port = System.getProperty("spark.master.port").toInt - val url = "akka://spark@%s:%s/user/SparkStreamScheduler".format(host, port) - ssc.actorSystem.actorFor(url) - } - /*sparkstreamScheduler ! Test()*/ - - val intervalDurationMillis = intervalDuration.asInstanceOf[LongTime].milliseconds - val useBlockManager = true - - initLogging() - - override def act() { - // register the InputReceiver - val port = 7078 - RemoteActor.alive(port) - RemoteActor.register(Symbol("NetworkStreamReceiver-"+inputName), self) - logInfo("Registered actor on port " + port) - - loop { - reactWithin (getSleepTime) { - case TIMEOUT => - flushInbox() - case data => - val t = data.asInstanceOf[T] - inbox += (getTimeFromData(t), t) - } - } - } - - def getSleepTime(): Long = { - (System.currentTimeMillis / intervalDurationMillis + 1) * - intervalDurationMillis - System.currentTimeMillis - } - - def getTimeFromData(data: T): Time = { - LongTime(System.currentTimeMillis) - } - - def flushInbox() { - while (inbox.hasFilledBuckets) { - inbox.synchronized { - val (interval, data) = inbox.popFilledBucket() - val dataArray = data.toArray - logInfo("Received " + dataArray.length + " items at interval " + interval) - val reference = { - if (useBlockManager) { - writeToBlockManager(dataArray, interval) - } else { - writeToDisk(dataArray, interval) - } - } - if (reference != null) { - logInfo("Notifying scheduler") - sparkstreamScheduler ! InputGenerated(inputName, interval, reference.toString) - } - } - } - } - - def writeToDisk(data: Array[T], interval: Interval): String = { - try { - // TODO(Haoyuan): For current test, the following writing to file lines could be - // commented. - val fs = new Path(tempDirectory).getFileSystem(new Configuration()) - val inputDir = new Path( - tempDirectory, - inputName + "-" + interval.toFormattedString) - val inputFile = new Path(inputDir, "part-" + splitId) - logInfo("Writing to file " + inputFile) - if (System.getProperty("spark.fake", "false") != "true") { - val writer = new BufferedWriter(new OutputStreamWriter(fs.create(inputFile, true))) - data.foreach(x => writer.write(x.toString + "\n")) - writer.close() - } else { - logInfo("Fake file") - } - inputFile.toString - }catch { - case e: Exception => - logError("Exception writing to file at interval " + interval + ": " + e.getMessage, e) - null - } - } - - def writeToBlockManager(data: Array[T], interval: Interval): String = { - try{ - val blockId = inputName + "-" + interval.toFormattedString + "-" + splitId - if (System.getProperty("spark.fake", "false") != "true") { - logInfo("Writing as block " + blockId ) - ssc.env.blockManager.put(blockId.toString, data.toIterator, StorageLevel.DISK_AND_MEMORY) - } else { - logInfo("Fake block") - } - blockId - } catch { - case e: Exception => - logError("Exception writing to block manager at interval " + interval + ": " + e.getMessage, e) - null - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/PairRDSFunctions.scala b/streaming/src/main/scala/spark/streaming/PairRDSFunctions.scala new file mode 100644 index 0000000000..403ae233a5 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/PairRDSFunctions.scala @@ -0,0 +1,72 @@ +package spark.streaming + +import scala.collection.mutable.ArrayBuffer +import spark.streaming.SparkStreamContext._ + +class PairRDSFunctions[K: ClassManifest, V: ClassManifest](rds: RDS[(K,V)]) +extends Serializable { + + def ssc = rds.ssc + + /* ---------------------------------- */ + /* RDS operations for key-value pairs */ + /* ---------------------------------- */ + + def groupByKey(numPartitions: Int = 0): ShuffledRDS[K, V, ArrayBuffer[V]] = { + def createCombiner(v: V) = ArrayBuffer[V](v) + def mergeValue(c: ArrayBuffer[V], v: V) = (c += v) + def mergeCombiner(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = (c1 ++ c2) + combineByKey[ArrayBuffer[V]](createCombiner, mergeValue, mergeCombiner, numPartitions) + } + + def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int = 0): ShuffledRDS[K, V, V] = { + val cleanedReduceFunc = ssc.sc.clean(reduceFunc) + combineByKey[V]((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, numPartitions) + } + + private def combineByKey[C: ClassManifest]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiner: (C, C) => C, + numPartitions: Int) : ShuffledRDS[K, V, C] = { + new ShuffledRDS[K, V, C](rds, createCombiner, mergeValue, mergeCombiner, numPartitions) + } + + def groupByKeyAndWindow( + windowTime: Time, + slideTime: Time, + numPartitions: Int = 0): ShuffledRDS[K, V, ArrayBuffer[V]] = { + rds.window(windowTime, slideTime).groupByKey(numPartitions) + } + + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time, + numPartitions: Int = 0): ShuffledRDS[K, V, V] = { + rds.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), numPartitions) + } + + // This method is the efficient sliding window reduce operation, + // which requires the specification of an inverse reduce function, + // so that new elements introduced in the window can be "added" using + // reduceFunc to the previous window's result and old elements can be + // "subtracted using invReduceFunc. + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time, + numPartitions: Int): ReducedWindowedRDS[K, V] = { + + new ReducedWindowedRDS[K, V]( + rds, + ssc.sc.clean(reduceFunc), + ssc.sc.clean(invReduceFunc), + windowTime, + slideTime, + numPartitions) + } +} + + diff --git a/streaming/src/main/scala/spark/streaming/QueueInputRDS.scala b/streaming/src/main/scala/spark/streaming/QueueInputRDS.scala new file mode 100644 index 0000000000..31e6a64e21 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/QueueInputRDS.scala @@ -0,0 +1,36 @@ +package spark.streaming + +import spark.RDD +import spark.UnionRDD + +import scala.collection.mutable.Queue +import scala.collection.mutable.ArrayBuffer + +class QueueInputRDS[T: ClassManifest]( + ssc: SparkStreamContext, + val queue: Queue[RDD[T]], + oneAtATime: Boolean, + defaultRDD: RDD[T] + ) extends InputRDS[T](ssc) { + + override def start() { } + + override def stop() { } + + override def compute(validTime: Time): Option[RDD[T]] = { + val buffer = new ArrayBuffer[RDD[T]]() + if (oneAtATime && queue.size > 0) { + buffer += queue.dequeue() + } else { + buffer ++= queue + } + if (buffer.size > 0) { + Some(new UnionRDD(ssc.sc, buffer.toSeq)) + } else if (defaultRDD != null) { + Some(defaultRDD) + } else { + None + } + } + +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/RDS.scala b/streaming/src/main/scala/spark/streaming/RDS.scala index c8dd1015ed..fd923929e7 100644 --- a/streaming/src/main/scala/spark/streaming/RDS.scala +++ b/streaming/src/main/scala/spark/streaming/RDS.scala @@ -13,16 +13,18 @@ import spark.storage.StorageLevel import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap -import java.net.InetSocketAddress +import java.util.concurrent.ArrayBlockingQueue abstract class RDS[T: ClassManifest] (@transient val ssc: SparkStreamContext) extends Logging with Serializable { initLogging() - /* ---------------------------------------------- */ - /* Methods that must be implemented by subclasses */ - /* ---------------------------------------------- */ + /** + * ---------------------------------------------- + * Methods that must be implemented by subclasses + * ---------------------------------------------- + */ // Time by which the window slides in this RDS def slideTime: Time @@ -33,9 +35,11 @@ extends Logging with Serializable { // Key method that computes RDD for a valid time def compute (validTime: Time): Option[RDD[T]] - /* --------------------------------------- */ - /* Other general fields and methods of RDS */ - /* --------------------------------------- */ + /** + * --------------------------------------- + * Other general fields and methods of RDS + * --------------------------------------- + */ // Variable to store the RDDs generated earlier in time @transient private val generatedRDDs = new HashMap[Time, RDD[T]] () @@ -66,9 +70,9 @@ extends Logging with Serializable { this } + // Set caching level for the RDDs created by this RDS def persist(newLevel: StorageLevel): RDS[T] = persist(newLevel, StorageLevel.NONE, null) - // Turn on the default caching level for this RDD def persist(): RDS[T] = persist(StorageLevel.MEMORY_ONLY_DESER) // Turn on the default caching level for this RDD @@ -76,18 +80,20 @@ extends Logging with Serializable { def isInitialized = (zeroTime != null) - // This method initializes the RDS by setting the "zero" time, based on which - // the validity of future times is calculated. This method also recursively initializes - // its parent RDSs. - def initialize(firstInterval: Interval) { + /** + * This method initializes the RDS by setting the "zero" time, based on which + * the validity of future times is calculated. This method also recursively initializes + * its parent RDSs. + */ + def initialize(time: Time) { if (zeroTime == null) { - zeroTime = firstInterval.beginTime + zeroTime = time } logInfo(this + " initialized") - dependencies.foreach(_.initialize(firstInterval)) + dependencies.foreach(_.initialize(zeroTime)) } - // This method checks whether the 'time' is valid wrt slideTime for generating RDD + /** This method checks whether the 'time' is valid wrt slideTime for generating RDD */ private def isTimeValid (time: Time): Boolean = { if (!isInitialized) throw new Exception (this.toString + " has not been initialized") @@ -98,11 +104,13 @@ extends Logging with Serializable { } } - // This method either retrieves a precomputed RDD of this RDS, - // or computes the RDD (if the time is valid) + /** + * This method either retrieves a precomputed RDD of this RDS, + * or computes the RDD (if the time is valid) + */ def getOrCompute(time: Time): Option[RDD[T]] = { - - // if RDD was already generated, then retrieve it from HashMap + // If this RDS was not initialized (i.e., zeroTime not set), then do it + // If RDD was already generated, then retrieve it from HashMap generatedRDDs.get(time) match { // If an RDD was already generated and is being reused, then @@ -115,15 +123,12 @@ extends Logging with Serializable { if (isTimeValid(time)) { compute(time) match { case Some(newRDD) => - if (System.getProperty("spark.fake", "false") != "true" || - newRDD.getStorageLevel == StorageLevel.NONE) { - if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { - newRDD.persist(checkpointLevel) - logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) - } else if (storageLevel != StorageLevel.NONE) { - newRDD.persist(storageLevel) - logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) - } + if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { + newRDD.persist(checkpointLevel) + logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) + } else if (storageLevel != StorageLevel.NONE) { + newRDD.persist(storageLevel) + logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) } generatedRDDs.put(time.copy(), newRDD) Some(newRDD) @@ -136,8 +141,10 @@ extends Logging with Serializable { } } - // This method generates a SparkStream job for the given time - // and may require to be overriden by subclasses + /** + * This method generates a SparkStream job for the given time + * and may require to be overriden by subclasses + */ def generateJob(time: Time): Option[Job] = { getOrCompute(time) match { case Some(rdd) => { @@ -151,9 +158,11 @@ extends Logging with Serializable { } } - /* -------------- */ - /* RDS operations */ - /* -------------- */ + /** + * -------------- + * RDS operations + * -------------- + */ def map[U: ClassManifest](mapFunc: T => U) = new MappedRDS(this, ssc.sc.clean(mapFunc)) @@ -185,6 +194,15 @@ extends Logging with Serializable { newrds } + private[streaming] def toQueue() = { + val queue = new ArrayBlockingQueue[RDD[T]](10000) + this.foreachRDD(rdd => { + println("Added RDD " + rdd.id) + queue.add(rdd) + }) + queue + } + def print() = { def foreachFunc = (rdd: RDD[T], time: Time) => { val first11 = rdd.take(11) @@ -229,198 +247,23 @@ extends Logging with Serializable { } -class PairRDSFunctions[K: ClassManifest, V: ClassManifest](rds: RDS[(K,V)]) -extends Serializable { - - def ssc = rds.ssc - - /* ---------------------------------- */ - /* RDS operations for key-value pairs */ - /* ---------------------------------- */ - - def groupByKey(numPartitions: Int = 0): ShuffledRDS[K, V, ArrayBuffer[V]] = { - def createCombiner(v: V) = ArrayBuffer[V](v) - def mergeValue(c: ArrayBuffer[V], v: V) = (c += v) - def mergeCombiner(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = (c1 ++ c2) - combineByKey[ArrayBuffer[V]](createCombiner, mergeValue, mergeCombiner, numPartitions) - } - - def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int = 0): ShuffledRDS[K, V, V] = { - val cleanedReduceFunc = ssc.sc.clean(reduceFunc) - combineByKey[V]((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, numPartitions) - } - - private def combineByKey[C: ClassManifest]( - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiner: (C, C) => C, - numPartitions: Int) : ShuffledRDS[K, V, C] = { - new ShuffledRDS[K, V, C](rds, createCombiner, mergeValue, mergeCombiner, numPartitions) - } - - def groupByKeyAndWindow( - windowTime: Time, - slideTime: Time, - numPartitions: Int = 0): ShuffledRDS[K, V, ArrayBuffer[V]] = { - rds.window(windowTime, slideTime).groupByKey(numPartitions) - } - - def reduceByKeyAndWindow( - reduceFunc: (V, V) => V, - windowTime: Time, - slideTime: Time, - numPartitions: Int = 0): ShuffledRDS[K, V, V] = { - rds.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), numPartitions) - } - - // This method is the efficient sliding window reduce operation, - // which requires the specification of an inverse reduce function, - // so that new elements introduced in the window can be "added" using - // reduceFunc to the previous window's result and old elements can be - // "subtracted using invReduceFunc. - def reduceByKeyAndWindow( - reduceFunc: (V, V) => V, - invReduceFunc: (V, V) => V, - windowTime: Time, - slideTime: Time, - numPartitions: Int): ReducedWindowedRDS[K, V] = { - - new ReducedWindowedRDS[K, V]( - rds, - ssc.sc.clean(reduceFunc), - ssc.sc.clean(invReduceFunc), - windowTime, - slideTime, - numPartitions) - } -} - - abstract class InputRDS[T: ClassManifest] ( - val inputName: String, - val batchDuration: Time, ssc: SparkStreamContext) extends RDS[T](ssc) { override def dependencies = List() - override def slideTime = batchDuration + override def slideTime = ssc.batchDuration - def setReference(time: Time, reference: AnyRef) -} - - -class FileInputRDS( - val fileInputName: String, - val directory: String, - ssc: SparkStreamContext) -extends InputRDS[String](fileInputName, LongTime(1000), ssc) { - - @transient val generatedFiles = new HashMap[Time,String] - - // TODO(Haoyuan): This is for the performance test. - @transient - val rdd = ssc.sc.textFile(SparkContext.inputFile, - SparkContext.idealPartitions).asInstanceOf[RDD[String]] + def start() - override def compute(validTime: Time): Option[RDD[String]] = { - generatedFiles.get(validTime) match { - case Some(file) => - logInfo("Reading from file " + file + " for time " + validTime) - // Some(ssc.sc.textFile(file).asInstanceOf[RDD[String]]) - // The following line is for HDFS performance test. Sould comment out the above line. - Some(rdd) - case None => - throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") - None - } - } - - def setReference(time: Time, reference: AnyRef) { - generatedFiles += ((time, reference.toString)) - logInfo("Reference added for time " + time + " - " + reference.toString) - } -} - -class NetworkInputRDS[T: ClassManifest]( - val networkInputName: String, - val addresses: Array[InetSocketAddress], - batchDuration: Time, - ssc: SparkStreamContext) -extends InputRDS[T](networkInputName, batchDuration, ssc) { - - - // TODO(Haoyuan): This is for the performance test. - @transient var rdd: RDD[T] = null - - if (System.getProperty("spark.fake", "false") == "true") { - logInfo("Running initial count to cache fake RDD") - rdd = ssc.sc.textFile(SparkContext.inputFile, - SparkContext.idealPartitions).asInstanceOf[RDD[T]] - val fakeCacheLevel = System.getProperty("spark.fake.cache", "") - if (fakeCacheLevel == "MEMORY_ONLY_2") { - rdd.persist(StorageLevel.MEMORY_ONLY_2) - } else if (fakeCacheLevel == "MEMORY_ONLY_DESER_2") { - rdd.persist(StorageLevel.MEMORY_ONLY_2) - } else if (fakeCacheLevel != "") { - logError("Invalid fake cache level: " + fakeCacheLevel) - System.exit(1) - } - rdd.count() - } - - @transient val references = new HashMap[Time,String] - - override def compute(validTime: Time): Option[RDD[T]] = { - if (System.getProperty("spark.fake", "false") == "true") { - logInfo("Returning fake RDD at " + validTime) - return Some(rdd) - } - references.get(validTime) match { - case Some(reference) => - if (reference.startsWith("file") || reference.startsWith("hdfs")) { - logInfo("Reading from file " + reference + " for time " + validTime) - Some(ssc.sc.textFile(reference).asInstanceOf[RDD[T]]) - } else { - logInfo("Getting from BlockManager " + reference + " for time " + validTime) - Some(new BlockRDD(ssc.sc, Array(reference))) - } - case None => - throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") - None - } - } - - def setReference(time: Time, reference: AnyRef) { - references += ((time, reference.toString)) - logInfo("Reference added for time " + time + " - " + reference.toString) - } + def stop() } -class TestInputRDS( - val testInputName: String, - batchDuration: Time, - ssc: SparkStreamContext) -extends InputRDS[String](testInputName, batchDuration, ssc) { - - @transient val references = new HashMap[Time,Array[String]] - - override def compute(validTime: Time): Option[RDD[String]] = { - references.get(validTime) match { - case Some(reference) => - Some(new BlockRDD[String](ssc.sc, reference)) - case None => - throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") - None - } - } - - def setReference(time: Time, reference: AnyRef) { - references += ((time, reference.asInstanceOf[Array[String]])) - } -} - +/** + * TODO + */ class MappedRDS[T: ClassManifest, U: ClassManifest] ( parent: RDS[T], @@ -437,6 +280,10 @@ extends RDS[U](parent.ssc) { } +/** + * TODO + */ + class FlatMappedRDS[T: ClassManifest, U: ClassManifest]( parent: RDS[T], flatMapFunc: T => Traversable[U]) @@ -452,6 +299,10 @@ extends RDS[U](parent.ssc) { } +/** + * TODO + */ + class FilteredRDS[T: ClassManifest](parent: RDS[T], filterFunc: T => Boolean) extends RDS[T](parent.ssc) { @@ -464,6 +315,11 @@ extends RDS[T](parent.ssc) { } } + +/** + * TODO + */ + class MapPartitionedRDS[T: ClassManifest, U: ClassManifest]( parent: RDS[T], mapPartFunc: Iterator[T] => Iterator[U]) @@ -478,6 +334,11 @@ extends RDS[U](parent.ssc) { } } + +/** + * TODO + */ + class GlommedRDS[T: ClassManifest](parent: RDS[T]) extends RDS[Array[T]](parent.ssc) { override def dependencies = List(parent) @@ -490,6 +351,10 @@ class GlommedRDS[T: ClassManifest](parent: RDS[T]) extends RDS[Array[T]](parent. } +/** + * TODO + */ + class ShuffledRDS[K: ClassManifest, V: ClassManifest, C: ClassManifest]( parent: RDS[(K,V)], createCombiner: V => C, @@ -519,6 +384,10 @@ class ShuffledRDS[K: ClassManifest, V: ClassManifest, C: ClassManifest]( } +/** + * TODO + */ + class UnifiedRDS[T: ClassManifest](parents: Array[RDS[T]]) extends RDS[T](parents(0).ssc) { @@ -553,6 +422,10 @@ extends RDS[T](parents(0).ssc) { } +/** + * TODO + */ + class PerElementForEachRDS[T: ClassManifest] ( parent: RDS[T], foreachFunc: T => Unit) @@ -580,6 +453,10 @@ extends RDS[Unit](parent.ssc) { } +/** + * TODO + */ + class PerRDDForEachRDS[T: ClassManifest] ( parent: RDS[T], foreachFunc: (RDD[T], Time) => Unit) diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index 8df346559c..83f874e550 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -1,16 +1,11 @@ package spark.streaming +import spark.streaming.util.RecurringTimer import spark.SparkEnv import spark.Logging import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.collection.mutable.ArrayBuffer -import akka.actor._ -import akka.actor.Actor -import akka.actor.Actor._ -import akka.util.duration._ sealed trait SchedulerMessage case class InputGenerated(inputName: String, interval: Interval, reference: AnyRef = null) extends SchedulerMessage @@ -20,162 +15,42 @@ class Scheduler( ssc: SparkStreamContext, inputRDSs: Array[InputRDS[_]], outputRDSs: Array[RDS[_]]) -extends Actor with Logging { - - class InputState (inputNames: Array[String]) { - val inputsLeft = new HashSet[String]() - inputsLeft ++= inputNames - - val startTime = System.currentTimeMillis - - def delay() = System.currentTimeMillis - startTime - - def addGeneratedInput(inputName: String) = inputsLeft -= inputName - - def areAllInputsGenerated() = (inputsLeft.size == 0) - - override def toString(): String = { - val left = if (inputsLeft.size == 0) "" else inputsLeft.reduceLeft(_ + ", " + _) - return "Inputs left = [ " + left + " ]" - } - } - +extends Logging { initLogging() - val inputNames = inputRDSs.map(_.inputName).toArray - val inputStates = new HashMap[Interval, InputState]() - val currentJobs = System.getProperty("spark.stream.currentJobs", "1").toInt - val jobManager = new JobManager(ssc, currentJobs) - - // TODO(Haoyuan): The following line is for performance test only. - var cnt: Int = System.getProperty("spark.stream.fake.cnt", "60").toInt - var lastInterval: Interval = null - - - /*remote.register("SparkStreamScheduler", actorOf[Scheduler])*/ - logInfo("Registered actor on port ") - - /*jobManager.start()*/ - startStreamReceivers() - - def receive = { - case InputGenerated(inputName, interval, reference) => { - addGeneratedInput(inputName, interval, reference) - } - case Test() => logInfo("TEST PASSED") - } - - def addGeneratedInput(inputName: String, interval: Interval, reference: AnyRef = null) { - logInfo("Input " + inputName + " generated for interval " + interval) - inputStates.get(interval) match { - case None => inputStates.put(interval, new InputState(inputNames)) - case _ => - } - inputStates(interval).addGeneratedInput(inputName) + val concurrentJobs = System.getProperty("spark.stream.concurrentJobs", "1").toInt + val jobManager = new JobManager(ssc, concurrentJobs) + val timer = new RecurringTimer(ssc.batchDuration, generateRDDs(_)) - inputRDSs.filter(_.inputName == inputName).foreach(inputRDS => { - inputRDS.setReference(interval.endTime, reference) - if (inputRDS.isInstanceOf[TestInputRDS]) { - TestInputBlockTracker.addBlocks(interval.endTime, reference) - } - } - ) - - def getNextInterval(): Option[Interval] = { - logDebug("Last interval is " + lastInterval) - val readyIntervals = inputStates.filter(_._2.areAllInputsGenerated).keys - /*inputState.foreach(println) */ - logDebug("InputState has " + inputStates.size + " intervals, " + readyIntervals.size + " ready intervals") - return readyIntervals.find(lastInterval == null || _.beginTime == lastInterval.endTime) - } - - var nextInterval = getNextInterval() - var count = 0 - while(nextInterval.isDefined) { - val inputState = inputStates.get(nextInterval.get).get - generateRDDsForInterval(nextInterval.get) - logInfo("Skew delay for " + nextInterval.get.endTime + " is " + (inputState.delay / 1000.0) + " s") - inputStates.remove(nextInterval.get) - lastInterval = nextInterval.get - nextInterval = getNextInterval() - count += 1 - /*if (nextInterval.size == 0 && inputState.size > 0) { - logDebug("Next interval not ready, pending intervals " + inputState.size) - }*/ - } - logDebug("RDDs generated for " + count + " intervals") - - /* - if (inputState(interval).areAllInputsGenerated) { - generateRDDsForInterval(interval) - lastInterval = interval - inputState.remove(interval) - } else { - logInfo("All inputs not generated for interval " + interval) - } - */ + def start() { + + val zeroTime = Time(timer.start()) + outputRDSs.foreach(_.initialize(zeroTime)) + inputRDSs.par.foreach(_.start()) + logInfo("Scheduler started") } - - def generateRDDsForInterval (interval: Interval) { - logInfo("Generating RDDs for interval " + interval) + + def stop() { + timer.stop() + inputRDSs.par.foreach(_.stop()) + logInfo("Scheduler stopped") + } + + def generateRDDs (time: Time) { + logInfo("Generating RDDs for time " + time) outputRDSs.foreach(outputRDS => { - if (!outputRDS.isInitialized) outputRDS.initialize(interval) - outputRDS.generateJob(interval.endTime) match { + outputRDS.generateJob(time) match { case Some(job) => submitJob(job) case None => } } ) - // TODO(Haoyuan): This comment is for performance test only. - if (System.getProperty("spark.fake", "false") == "true" || System.getProperty("spark.stream.fake", "false") == "true") { - cnt -= 1 - if (cnt <= 0) { - logInfo("My time is up! " + cnt) - System.exit(1) - } - } + logInfo("Generated RDDs for time " + time) } - def submitJob(job: Job) { - logInfo("Submitting " + job + " to JobManager") - /*jobManager ! RunJob(job)*/ + def submitJob(job: Job) { jobManager.runJob(job) } - - def startStreamReceivers() { - val testStreamReceiverNames = new ArrayBuffer[(String, Long)]() - inputRDSs.foreach (inputRDS => { - inputRDS match { - case fileInputRDS: FileInputRDS => { - val fileStreamReceiver = new FileStreamReceiver( - fileInputRDS.inputName, - fileInputRDS.directory, - fileInputRDS.batchDuration.asInstanceOf[LongTime].milliseconds) - fileStreamReceiver.start() - } - case networkInputRDS: NetworkInputRDS[_] => { - val networkStreamReceiver = new NetworkStreamReceiver( - networkInputRDS.inputName, - networkInputRDS.batchDuration, - 0, - ssc, - if (ssc.tempDir == null) null else ssc.tempDir.toString) - networkStreamReceiver.start() - } - case testInputRDS: TestInputRDS => { - testStreamReceiverNames += - ((testInputRDS.inputName, testInputRDS.batchDuration.asInstanceOf[LongTime].milliseconds)) - } - } - }) - if (testStreamReceiverNames.size > 0) { - /*val testStreamCoordinator = new TestStreamCoordinator(testStreamReceiverNames.toArray)*/ - /*testStreamCoordinator.start()*/ - val actor = ssc.actorSystem.actorOf( - Props(new TestStreamCoordinator(testStreamReceiverNames.toArray)), - name = "TestStreamCoordinator") - } - } } diff --git a/streaming/src/main/scala/spark/streaming/SparkStreamContext.scala b/streaming/src/main/scala/spark/streaming/SparkStreamContext.scala index 51f8193740..d32f6d588c 100644 --- a/streaming/src/main/scala/spark/streaming/SparkStreamContext.scala +++ b/streaming/src/main/scala/spark/streaming/SparkStreamContext.scala @@ -1,22 +1,23 @@ package spark.streaming -import spark.SparkContext -import spark.SparkEnv -import spark.Utils +import spark.RDD import spark.Logging +import spark.SparkEnv +import spark.SparkContext import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.Queue -import java.net.InetSocketAddress import java.io.IOException -import java.util.UUID +import java.net.InetSocketAddress +import java.util.concurrent.atomic.AtomicInteger import org.apache.hadoop.fs.Path import org.apache.hadoop.conf.Configuration - -import akka.actor._ -import akka.actor.Actor -import akka.util.duration._ +import org.apache.hadoop.io.LongWritable +import org.apache.hadoop.io.Text +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat class SparkStreamContext ( master: String, @@ -24,30 +25,37 @@ class SparkStreamContext ( val sparkHome: String = null, val jars: Seq[String] = Nil) extends Logging { - + initLogging() val sc = new SparkContext(master, frameworkName, sparkHome, jars) val env = SparkEnv.get - val actorSystem = env.actorSystem - - @transient val inputRDSs = new ArrayBuffer[InputRDS[_]]() - @transient val outputRDSs = new ArrayBuffer[RDS[_]]() - var tempDirRoot: String = null - var tempDir: Path = null - - def readNetworkStream[T: ClassManifest]( + val inputRDSs = new ArrayBuffer[InputRDS[_]]() + val outputRDSs = new ArrayBuffer[RDS[_]]() + var batchDuration: Time = null + var scheduler: Scheduler = null + + def setBatchDuration(duration: Long) { + setBatchDuration(Time(duration)) + } + + def setBatchDuration(duration: Time) { + batchDuration = duration + } + + /* + def createNetworkStream[T: ClassManifest]( name: String, addresses: Array[InetSocketAddress], batchDuration: Time): RDS[T] = { - val inputRDS = new NetworkInputRDS[T](name, addresses, batchDuration, this) + val inputRDS = new NetworkInputRDS[T](this, addresses) inputRDSs += inputRDS inputRDS - } + } - def readNetworkStream[T: ClassManifest]( + def createNetworkStream[T: ClassManifest]( name: String, addresses: Array[String], batchDuration: Long): RDS[T] = { @@ -65,40 +73,100 @@ class SparkStreamContext ( addresses.map(stringToInetSocketAddress).toArray, LongTime(batchDuration)) } - - def readFileStream(name: String, directory: String): RDS[String] = { - val path = new Path(directory) - val fs = path.getFileSystem(new Configuration()) - val qualPath = path.makeQualified(fs) - val inputRDS = new FileInputRDS(name, qualPath.toString, this) + */ + + /** + * This function creates a input stream that monitors a Hadoop-compatible + * for new files and executes the necessary processing on them. + */ + def createFileStream[ + K: ClassManifest, + V: ClassManifest, + F <: NewInputFormat[K, V]: ClassManifest + ](directory: String): RDS[(K, V)] = { + val inputRDS = new FileInputRDS[K, V, F](this, new Path(directory)) inputRDSs += inputRDS inputRDS } - def readTestStream(name: String, batchDuration: Long): RDS[String] = { - val inputRDS = new TestInputRDS(name, LongTime(batchDuration), this) - inputRDSs += inputRDS + def createTextFileStream(directory: String): RDS[String] = { + createFileStream[LongWritable, Text, TextInputFormat](directory).map(_._2.toString) + } + + /** + * This function create a input stream from an queue of RDDs. In each batch, + * it will process either one or all of the RDDs returned by the queue + */ + def createQueueStream[T: ClassManifest]( + queue: Queue[RDD[T]], + oneAtATime: Boolean = true, + defaultRDD: RDD[T] = null + ): RDS[T] = { + val inputRDS = new QueueInputRDS(this, queue, oneAtATime, defaultRDD) + inputRDSs += inputRDS inputRDS } + + def createQueueStream[T: ClassManifest](iterator: Iterator[RDD[T]]): RDS[T] = { + val queue = new Queue[RDD[T]] + val inputRDS = createQueueStream(queue, true, null) + queue ++= iterator + inputRDS + } + + /** + * This function registers a RDS as an output stream that will be + * computed every interval. + */ def registerOutputStream (outputRDS: RDS[_]) { outputRDSs += outputRDS } - - def setTempDir(dir: String) { - tempDirRoot = dir + + /** + * This function verify whether the stream computation is eligible to be executed. + */ + def verify() { + if (batchDuration == null) { + throw new Exception("Batch duration has not been set") + } + if (batchDuration < Milliseconds(100)) { + logWarning("Batch duration of " + batchDuration + " is very low") + } + if (inputRDSs.size == 0) { + throw new Exception("No input RDSes created, so nothing to take input from") + } + if (outputRDSs.size == 0) { + throw new Exception("No output RDSes registered, so nothing to execute") + } + } - - def run () { - val ctxt = this - val actor = actorSystem.actorOf( - Props(new Scheduler(ctxt, inputRDSs.toArray, outputRDSs.toArray)), - name = "SparkStreamScheduler") - logInfo("Registered actor") - actorSystem.awaitTermination() + + /** + * This function starts the execution of the streams. + */ + def start() { + verify() + scheduler = new Scheduler(this, inputRDSs.toArray, outputRDSs.toArray) + scheduler.start() + } + + /** + * This function starts the execution of the streams. + */ + def stop() { + try { + scheduler.stop() + sc.stop() + } catch { + case e: Exception => logWarning("Error while stopping", e) + } + + logInfo("SparkStreamContext stopped") } } + object SparkStreamContext { implicit def rdsToPairRdsFunctions [K: ClassManifest, V: ClassManifest] (rds: RDS[(K,V)]) = new PairRDSFunctions (rds) diff --git a/streaming/src/main/scala/spark/streaming/TestInputBlockTracker.scala b/streaming/src/main/scala/spark/streaming/TestInputBlockTracker.scala deleted file mode 100644 index 7e23b7bb82..0000000000 --- a/streaming/src/main/scala/spark/streaming/TestInputBlockTracker.scala +++ /dev/null @@ -1,42 +0,0 @@ -package spark.streaming -import spark.Logging -import scala.collection.mutable.{ArrayBuffer, HashMap} - -object TestInputBlockTracker extends Logging { - initLogging() - val allBlockIds = new HashMap[Time, ArrayBuffer[String]]() - - def addBlocks(intervalEndTime: Time, reference: AnyRef) { - allBlockIds.getOrElseUpdate(intervalEndTime, new ArrayBuffer[String]()) ++= reference.asInstanceOf[Array[String]] - } - - def setEndTime(intervalEndTime: Time) { - try { - val endTime = System.currentTimeMillis - allBlockIds.get(intervalEndTime) match { - case Some(blockIds) => { - val numBlocks = blockIds.size - var totalDelay = 0d - blockIds.foreach(blockId => { - val inputTime = getInputTime(blockId) - val delay = (endTime - inputTime) / 1000.0 - totalDelay += delay - logInfo("End-to-end delay for block " + blockId + " is " + delay + " s") - }) - logInfo("Average end-to-end delay for time " + intervalEndTime + " is " + (totalDelay / numBlocks) + " s") - allBlockIds -= intervalEndTime - } - case None => throw new Exception("Unexpected") - } - } catch { - case e: Exception => logError(e.toString) - } - } - - def getInputTime(blockId: String): Long = { - val parts = blockId.split("-") - /*logInfo(blockId + " -> " + parts(4)) */ - parts(4).toLong - } -} - diff --git a/streaming/src/main/scala/spark/streaming/TestStreamReceiver3.scala b/streaming/src/main/scala/spark/streaming/TestStreamReceiver3.scala index a7a5635aa5..bbf2c7bf5e 100644 --- a/streaming/src/main/scala/spark/streaming/TestStreamReceiver3.scala +++ b/streaming/src/main/scala/spark/streaming/TestStreamReceiver3.scala @@ -34,8 +34,8 @@ extends Thread with Logging { class DataHandler( inputName: String, - longIntervalDuration: LongTime, - shortIntervalDuration: LongTime, + longIntervalDuration: Time, + shortIntervalDuration: Time, blockManager: BlockManager ) extends Logging { @@ -61,8 +61,8 @@ extends Thread with Logging { initLogging() - val shortIntervalDurationMillis = shortIntervalDuration.asInstanceOf[LongTime].milliseconds - val longIntervalDurationMillis = longIntervalDuration.asInstanceOf[LongTime].milliseconds + val shortIntervalDurationMillis = shortIntervalDuration.toLong + val longIntervalDurationMillis = longIntervalDuration.toLong var currentBlock: Block = null var currentBucket: Bucket = null @@ -101,7 +101,7 @@ extends Thread with Logging { def updateCurrentBlock() { /*logInfo("Updating current block")*/ - val currentTime: LongTime = LongTime(System.currentTimeMillis) + val currentTime = Time(System.currentTimeMillis) val shortInterval = getShortInterval(currentTime) val longInterval = getLongInterval(shortInterval) @@ -318,12 +318,12 @@ extends Thread with Logging { val inputName = streamDetails.name val intervalDurationMillis = streamDetails.duration - val intervalDuration = LongTime(intervalDurationMillis) + val intervalDuration = Time(intervalDurationMillis) val dataHandler = new DataHandler( inputName, intervalDuration, - LongTime(TestStreamReceiver3.SHORT_INTERVAL_MILLIS), + Time(TestStreamReceiver3.SHORT_INTERVAL_MILLIS), blockManager) val connListener = new ConnectionListener(TestStreamReceiver3.PORT, dataHandler) @@ -382,7 +382,7 @@ extends Thread with Logging { def waitFor(time: Time) { val currentTimeMillis = System.currentTimeMillis - val targetTimeMillis = time.asInstanceOf[LongTime].milliseconds + val targetTimeMillis = time.milliseconds if (currentTimeMillis < targetTimeMillis) { val sleepTime = (targetTimeMillis - currentTimeMillis) Thread.sleep(sleepTime + 1) @@ -392,7 +392,7 @@ extends Thread with Logging { def notifyScheduler(interval: Interval, blockIds: Array[String]) { try { sparkstreamScheduler ! InputGenerated(inputName, interval, blockIds.toArray) - val time = interval.endTime.asInstanceOf[LongTime] + val time = interval.endTime val delay = (System.currentTimeMillis - time.milliseconds) / 1000.0 logInfo("Pushing delay for " + time + " is " + delay + " s") } catch { diff --git a/streaming/src/main/scala/spark/streaming/TestStreamReceiver4.scala b/streaming/src/main/scala/spark/streaming/TestStreamReceiver4.scala index 2c3f5d1b9d..a2babb23f4 100644 --- a/streaming/src/main/scala/spark/streaming/TestStreamReceiver4.scala +++ b/streaming/src/main/scala/spark/streaming/TestStreamReceiver4.scala @@ -24,8 +24,8 @@ extends Thread with Logging { class DataHandler( inputName: String, - longIntervalDuration: LongTime, - shortIntervalDuration: LongTime, + longIntervalDuration: Time, + shortIntervalDuration: Time, blockManager: BlockManager ) extends Logging { @@ -50,8 +50,8 @@ extends Thread with Logging { val syncOnLastShortInterval = true - val shortIntervalDurationMillis = shortIntervalDuration.asInstanceOf[LongTime].milliseconds - val longIntervalDurationMillis = longIntervalDuration.asInstanceOf[LongTime].milliseconds + val shortIntervalDurationMillis = shortIntervalDuration.milliseconds + val longIntervalDurationMillis = longIntervalDuration.milliseconds val buffer = ByteBuffer.allocateDirect(100 * 1024 * 1024) var currentShortInterval = Interval.currentInterval(shortIntervalDuration) @@ -145,7 +145,7 @@ extends Thread with Logging { if (syncOnLastShortInterval) { bucket += newBlock } - logDebug("Created " + newBlock + " with " + newBuffer.remaining + " bytes, creation delay is " + (System.currentTimeMillis - currentShortInterval.endTime.asInstanceOf[LongTime].milliseconds) / 1000.0 + " s" ) + logDebug("Created " + newBlock + " with " + newBuffer.remaining + " bytes, creation delay is " + (System.currentTimeMillis - currentShortInterval.endTime.milliseconds) / 1000.0 + " s" ) blockPushingExecutor.execute(new Runnable() { def run() { pushAndNotifyBlock(newBlock) } }) } } @@ -175,7 +175,7 @@ extends Thread with Logging { try{ if (blockManager != null) { val startTime = System.currentTimeMillis - logInfo(block + " put start delay is " + (startTime - block.shortInterval.endTime.asInstanceOf[LongTime].milliseconds) + " ms") + logInfo(block + " put start delay is " + (startTime - block.shortInterval.endTime.milliseconds) + " ms") /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY)*/ /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_2)*/ blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.MEMORY_ONLY_2) @@ -343,7 +343,7 @@ extends Thread with Logging { def waitFor(time: Time) { val currentTimeMillis = System.currentTimeMillis - val targetTimeMillis = time.asInstanceOf[LongTime].milliseconds + val targetTimeMillis = time.milliseconds if (currentTimeMillis < targetTimeMillis) { val sleepTime = (targetTimeMillis - currentTimeMillis) Thread.sleep(sleepTime + 1) @@ -353,7 +353,7 @@ extends Thread with Logging { def notifyScheduler(interval: Interval, blockIds: Array[String]) { try { sparkstreamScheduler ! InputGenerated(inputName, interval, blockIds.toArray) - val time = interval.endTime.asInstanceOf[LongTime] + val time = interval.endTime val delay = (System.currentTimeMillis - time.milliseconds) logInfo("Notification delay for " + time + " is " + delay + " ms") } catch { diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index b932fe9258..c4573137ae 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -1,19 +1,34 @@ package spark.streaming -abstract case class Time { +class Time(private var millis: Long) { - // basic operations that must be overridden - def copy(): Time - def zero: Time - def < (that: Time): Boolean - def += (that: Time): Time - def -= (that: Time): Time - def floor(that: Time): Time - def isMultipleOf(that: Time): Boolean + def copy() = new Time(this.millis) + + def zero = Time.zero + + def < (that: Time): Boolean = + (this.millis < that.millis) + + def <= (that: Time) = (this < that || this == that) + + def > (that: Time) = !(this <= that) + + def >= (that: Time) = !(this < that) + + def += (that: Time): Time = { + this.millis += that.millis + this + } + + def -= (that: Time): Time = { + this.millis -= that.millis + this + } - // derived operations composed of basic operations def + (that: Time) = this.copy() += that + def - (that: Time) = this.copy() -= that + def * (times: Int) = { var count = 0 var result = this.copy() @@ -23,63 +38,44 @@ abstract case class Time { } result } - def <= (that: Time) = (this < that || this == that) - def > (that: Time) = !(this <= that) - def >= (that: Time) = !(this < that) - def isZero = (this == zero) - def toFormattedString = toString -} - -object Time { - def Milliseconds(milliseconds: Long) = LongTime(milliseconds) - - def zero = LongTime(0) -} - -case class LongTime(var milliseconds: Long) extends Time { - - override def copy() = LongTime(this.milliseconds) - - override def zero = LongTime(0) - - override def < (that: Time): Boolean = - (this.milliseconds < that.asInstanceOf[LongTime].milliseconds) - - override def += (that: Time): Time = { - this.milliseconds += that.asInstanceOf[LongTime].milliseconds - this - } - override def -= (that: Time): Time = { - this.milliseconds -= that.asInstanceOf[LongTime].milliseconds - this + def floor(that: Time): Time = { + val t = that.millis + val m = math.floor(this.millis / t).toLong + new Time(m * t) } - override def floor(that: Time): Time = { - val t = that.asInstanceOf[LongTime].milliseconds - val m = this.milliseconds / t - LongTime(m.toLong * t) - } + def isMultipleOf(that: Time): Boolean = + (this.millis % that.millis == 0) - override def isMultipleOf(that: Time): Boolean = - (this.milliseconds % that.asInstanceOf[LongTime].milliseconds == 0) + def isZero = (this.millis == 0) - override def isZero = (this.milliseconds == 0) + override def toString() = (millis.toString + " ms") - override def toString = (milliseconds.toString + "ms") + def toFormattedString() = millis.toString + + def milliseconds() = millis +} - override def toFormattedString = milliseconds.toString +object Time { + val zero = new Time(0) + + def apply(milliseconds: Long) = new Time(milliseconds) + + implicit def toTime(long: Long) = Time(long) + + implicit def toLong(time: Time) = time.milliseconds } object Milliseconds { - def apply(milliseconds: Long) = LongTime(milliseconds) + def apply(milliseconds: Long) = Time(milliseconds) } object Seconds { - def apply(seconds: Long) = LongTime(seconds * 1000) + def apply(seconds: Long) = Time(seconds * 1000) } object Minutes { - def apply(minutes: Long) = LongTime(minutes * 60000) + def apply(minutes: Long) = Time(minutes * 60000) } diff --git a/streaming/src/main/scala/spark/streaming/examples/DumbTopKWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/examples/DumbTopKWordCount2_Special.scala deleted file mode 100644 index 2ca72da79f..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/DumbTopKWordCount2_Special.scala +++ /dev/null @@ -1,138 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.JavaConversions.mapAsScalaMap -import scala.collection.mutable.Queue - -import java.lang.{Long => JLong} - -object DumbTopKWordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - - def add(v1: JLong, v2: JLong) = (v1 + v2) - def subtract(v1: JLong, v2: JLong) = (v1 - v2) - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - map.toIterator - } - - - val wordCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) - wordCounts.persist(StorageLevel.MEMORY_ONLY) - val windowedCounts = wordCounts.window(Seconds(10), Seconds(1)).reduceByKey(_ + _, 10) - - def topK(data: Iterator[(String, JLong)], k: Int): Iterator[(String, JLong)] = { - val taken = new Array[(String, JLong)](k) - - var i = 0 - var len = 0 - var done = false - var value: (String, JLong) = null - var swap: (String, JLong) = null - var count = 0 - - while(data.hasNext) { - value = data.next - count += 1 - /*println("count = " + count)*/ - if (len == 0) { - taken(0) = value - len = 1 - } else if (len < k || value._2 > taken(len - 1)._2) { - if (len < k) { - len += 1 - } - taken(len - 1) = value - i = len - 1 - while(i > 0 && taken(i - 1)._2 < taken(i)._2) { - swap = taken(i) - taken(i) = taken(i-1) - taken(i - 1) = swap - i -= 1 - } - } - } - println("Took " + len + " out of " + count + " items") - return taken.toIterator - } - - val k = 10 - val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) - partialTopKWindowedCounts.foreachRDD(rdd => { - val collectedCounts = rdd.collect - println("Collected " + collectedCounts.size + " items") - topK(collectedCounts.toIterator, k).foreach(println) - }) - - /* - windowedCounts.filter(_ == null).foreachRDD(rdd => { - val count = rdd.count - println("# of nulls = " + count) - })*/ - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/DumbWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/examples/DumbWordCount2_Special.scala deleted file mode 100644 index 34e7edfda9..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/DumbWordCount2_Special.scala +++ /dev/null @@ -1,92 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.Queue -import scala.collection.JavaConversions.mapAsScalaMap - -import java.lang.{Long => JLong} -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} - -object DumbWordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - def add(v1: JLong, v2: JLong) = (v1 + v2) - def subtract(v1: JLong, v2: JLong) = (v1 - v2) - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - - map.toIterator - } - - val wordCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) - wordCounts.persist(StorageLevel.MEMORY_ONLY) - val windowedCounts = wordCounts.window(Seconds(10), Seconds(1)).reduceByKey(_ + _, 10) - windowedCounts.foreachRDD(_.collect) - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala b/streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala new file mode 100644 index 0000000000..d56fdcdf29 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala @@ -0,0 +1,41 @@ +package spark.streaming.examples + +import spark.RDD +import spark.streaming.SparkStreamContext +import spark.streaming.SparkStreamContext._ +import spark.streaming.Seconds + +import scala.collection.mutable.SynchronizedQueue + +object ExampleOne { + + def main(args: Array[String]) { + if (args.length < 1) { + System.err.println("Usage: ExampleOne ") + System.exit(1) + } + + // Create the context and set the batch size + val ssc = new SparkStreamContext(args(0), "ExampleOne") + ssc.setBatchDuration(Seconds(1)) + + // Create the queue through which RDDs can be pushed to + // a QueueInputRDS + val rddQueue = new SynchronizedQueue[RDD[Int]]() + + // Create the QueueInputRDs and use it do some processing + val inputStream = ssc.createQueueStream(rddQueue) + val mappedStream = inputStream.map(x => (x % 10, 1)) + val reducedStream = mappedStream.reduceByKey(_ + _) + reducedStream.print() + ssc.start() + + // Create and push some RDDs into + for (i <- 1 to 30) { + rddQueue += ssc.sc.makeRDD(1 to 1000, 10) + Thread.sleep(1000) + } + ssc.stop() + System.exit(0) + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala b/streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala new file mode 100644 index 0000000000..4b8f6d609d --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala @@ -0,0 +1,47 @@ +package spark.streaming.examples + +import spark.streaming.SparkStreamContext +import spark.streaming.SparkStreamContext._ +import spark.streaming.Seconds +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration + + +object ExampleTwo { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: ExampleOne ") + System.exit(1) + } + + // Create the context and set the batch size + val ssc = new SparkStreamContext(args(0), "ExampleTwo") + ssc.setBatchDuration(Seconds(2)) + + // Create the new directory + val directory = new Path(args(1)) + val fs = directory.getFileSystem(new Configuration()) + if (fs.exists(directory)) throw new Exception("This directory already exists") + fs.mkdirs(directory) + + // Create the FileInputRDS on the directory and use the + // stream to count words in new files created + val inputRDS = ssc.createTextFileStream(directory.toString) + val wordsRDS = inputRDS.flatMap(_.split(" ")) + val wordCountsRDS = wordsRDS.map(x => (x, 1)).reduceByKey(_ + _) + wordCountsRDS.print + ssc.start() + + // Creating new files in the directory + val text = "This is a text file" + for (i <- 1 to 30) { + ssc.sc.parallelize((1 to (i * 10)).map(_ => text), 10) + .saveAsTextFile(new Path(directory, i.toString).toString) + Thread.sleep(1000) + } + Thread.sleep(5000) // Waiting for the file to be processed + ssc.stop() + fs.delete(directory) + System.exit(0) + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepCount.scala b/streaming/src/main/scala/spark/streaming/examples/GrepCount.scala deleted file mode 100644 index ec3e70f258..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/GrepCount.scala +++ /dev/null @@ -1,39 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object GrepCount { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: GrepCount ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "GrepCount") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - //sentences.print - val matching = sentences.filter(_.contains("light")) - matching.foreachRDD(rdd => println(rdd.count)) - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepCount2.scala b/streaming/src/main/scala/spark/streaming/examples/GrepCount2.scala deleted file mode 100644 index 27ecced1c0..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/GrepCount2.scala +++ /dev/null @@ -1,113 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkEnv -import spark.SparkContext -import spark.storage.StorageLevel -import spark.network.Message -import spark.network.ConnectionManagerId - -import java.nio.ByteBuffer - -object GrepCount2 { - - def startSparkEnvs(sc: SparkContext) { - - val dummy = sc.parallelize(0 to 1000, 100).persist(StorageLevel.DISK_AND_MEMORY) - sc.runJob(dummy, (_: Iterator[Int]) => {}) - - println("SparkEnvs started") - Thread.sleep(1000) - /*sc.runJob(sc.parallelize(0 to 1000, 100), (_: Iterator[Int]) => {})*/ - } - - def warmConnectionManagers(sc: SparkContext) { - val slaveConnManagerIds = sc.parallelize(0 to 100, 100).map( - i => SparkEnv.get.connectionManager.id).collect().distinct - println("\nSlave ConnectionManagerIds") - slaveConnManagerIds.foreach(println) - println - - Thread.sleep(1000) - val numSlaves = slaveConnManagerIds.size - val count = 3 - val size = 5 * 1024 * 1024 - val iterations = (500 * 1024 * 1024 / (numSlaves * size)).toInt - println("count = " + count + ", size = " + size + ", iterations = " + iterations) - - (0 until count).foreach(i => { - val resultStrs = sc.parallelize(0 until numSlaves, numSlaves).map(i => { - val connManager = SparkEnv.get.connectionManager - val thisConnManagerId = connManager.id - /*connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - println("Received [" + msg + "] from [" + id + "]") - None - })*/ - - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - val futures = (0 until iterations).map(i => { - slaveConnManagerIds.filter(_ != thisConnManagerId).map(slaveConnManagerId => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - println("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") - connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) - }) - }).flatMap(x => x) - val results = futures.map(f => f()) - val finishTime = System.currentTimeMillis - - - val mb = size * results.size / 1024.0 / 1024.0 - val ms = finishTime - startTime - - val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s" - println(resultStr) - System.gc() - resultStr - }).collect() - - println("---------------------") - println("Run " + i) - resultStrs.foreach(println) - println("---------------------") - }) - } - - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: GrepCount2 <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "GrepCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - /*startSparkEnvs(ssc.sc)*/ - warmConnectionManagers(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-"+i, 500)).toArray - ) - - val matching = sentences.filter(_.contains("light")) - matching.foreachRDD(rdd => println(rdd.count)) - - ssc.run - } -} - - - - diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepCountApprox.scala b/streaming/src/main/scala/spark/streaming/examples/GrepCountApprox.scala deleted file mode 100644 index f9674136fe..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/GrepCountApprox.scala +++ /dev/null @@ -1,54 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object GrepCountApprox { - var inputFile : String = null - var hdfs : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 5) { - println ("Usage: GrepCountApprox ") - System.exit(1) - } - - hdfs = args(1) - inputFile = hdfs + args(2) - idealPartitions = args(3).toInt - val timeout = args(4).toLong - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "GrepCount") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - ssc.setTempDir(hdfs + "/tmp") - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - //sentences.print - val matching = sentences.filter(_.contains("light")) - var i = 0 - val startTime = System.currentTimeMillis - matching.foreachRDD { rdd => - val myNum = i - val result = rdd.countApprox(timeout) - val initialTime = (System.currentTimeMillis - startTime) / 1000.0 - printf("APPROX\t%.2f\t%d\tinitial\t%.1f\t%.1f\n", initialTime, myNum, result.initialValue.mean, - result.initialValue.high - result.initialValue.low) - result.onComplete { r => - val finalTime = (System.currentTimeMillis - startTime) / 1000.0 - printf("APPROX\t%.2f\t%d\tfinal\t%.1f\t0.0\t%.1f\n", finalTime, myNum, r.mean, finalTime - initialTime) - } - i += 1 - } - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount.scala b/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount.scala deleted file mode 100644 index a75ccd3a56..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount.scala +++ /dev/null @@ -1,30 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -object SimpleWordCount { - - def main (args: Array[String]) { - - if (args.length < 1) { - println ("Usage: SparkStreamContext []") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount") - if (args.length > 1) { - ssc.setTempDir(args(1)) - } - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) - /*sentences.print*/ - - val words = sentences.flatMap(_.split(" ")) - - val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 1) - counts.print - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2.scala deleted file mode 100644 index 9672e64b13..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2.scala +++ /dev/null @@ -1,51 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import scala.util.Sorting - -object SimpleWordCount2 { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SimpleWordCount2 <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "SimpleWordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - - val words = sentences.flatMap(_.split(" ")) - - val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 10) - counts.foreachRDD(_.collect()) - /*words.foreachRDD(_.countByValue())*/ - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2_Special.scala deleted file mode 100644 index 503033a8e5..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/SimpleWordCount2_Special.scala +++ /dev/null @@ -1,83 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import scala.collection.JavaConversions.mapAsScalaMap -import scala.util.Sorting -import java.lang.{Long => JLong} - -object SimpleWordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SimpleWordCount2 <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "SimpleWordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 400)).toArray - ) - - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - map.toIterator - } - - - /*val words = sentences.flatMap(_.split(" "))*/ - /*val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 10)*/ - val counts = sentences.mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) - counts.foreachRDD(_.collect()) - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/TopContentCount.scala b/streaming/src/main/scala/spark/streaming/examples/TopContentCount.scala deleted file mode 100644 index 031e989c87..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/TopContentCount.scala +++ /dev/null @@ -1,97 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting - -object TopContentCount { - - case class Event(val country: String, val content: String) - - object Event { - def create(string: String): Event = { - val parts = string.split(":") - new Event(parts(0), parts(1)) - } - } - - def main(args: Array[String]) { - - if (args.length < 2) { - println ("Usage: GrepCount2 <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "TopContentCount") - val sc = ssc.sc - val dummy = sc.parallelize(0 to 1000, 100).persist(StorageLevel.DISK_AND_MEMORY) - sc.runJob(dummy, (_: Iterator[Int]) => {}) - - - val numEventStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - val eventStrings = new UnifiedRDS( - (1 to numEventStreams).map(i => ssc.readTestStream("Events-" + i, 1000)).toArray - ) - - def parse(string: String) = { - val parts = string.split(":") - (parts(0), parts(1)) - } - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val events = eventStrings.map(x => parse(x)) - /*events.print*/ - - val parallelism = 8 - val counts_per_content_per_country = events - .map(x => (x, 1)) - .reduceByKey(_ + _) - /*.reduceByKeyAndWindow(add _, subtract _, Seconds(5), Seconds(1), parallelism)*/ - /*counts_per_content_per_country.print*/ - - /* - counts_per_content_per_country.persist( - StorageLevel.MEMORY_ONLY_DESER, - StorageLevel.MEMORY_ONLY_DESER_2, - Seconds(1) - )*/ - - val counts_per_country = counts_per_content_per_country - .map(x => (x._1._1, (x._1._2, x._2))) - .groupByKey() - counts_per_country.print - - - def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { - implicit val countOrdering = new Ordering[(String, Int)] { - override def compare(count1: (String, Int), count2: (String, Int)): Int = { - count2._2 - count1._2 - } - } - val array = data.toArray - Sorting.quickSort(array) - val taken = array.take(k) - taken - } - - val k = 10 - val topKContents_per_country = counts_per_country - .map(x => (x._1, topK(x._2, k))) - .map(x => (x._1, x._2.map(_.toString).reduceLeft(_ + ", " + _))) - - topKContents_per_country.print - - ssc.run - } -} - - - diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2.scala deleted file mode 100644 index 679ed0a7ef..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2.scala +++ /dev/null @@ -1,103 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting - -object TopKWordCount2 { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - moreWarmup(ssc.sc) - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 10) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) - - def topK(data: Iterator[(String, Int)], k: Int): Iterator[(String, Int)] = { - val taken = new Array[(String, Int)](k) - - var i = 0 - var len = 0 - var done = false - var value: (String, Int) = null - var swap: (String, Int) = null - var count = 0 - - while(data.hasNext) { - value = data.next - count += 1 - println("count = " + count) - if (len == 0) { - taken(0) = value - len = 1 - } else if (len < k || value._2 > taken(len - 1)._2) { - if (len < k) { - len += 1 - } - taken(len - 1) = value - i = len - 1 - while(i > 0 && taken(i - 1)._2 < taken(i)._2) { - swap = taken(i) - taken(i) = taken(i-1) - taken(i - 1) = swap - i -= 1 - } - } - } - println("Took " + len + " out of " + count + " items") - return taken.toIterator - } - - val k = 10 - val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) - partialTopKWindowedCounts.foreachRDD(rdd => { - val collectedCounts = rdd.collect - println("Collected " + collectedCounts.size + " items") - topK(collectedCounts.toIterator, k).foreach(println) - }) - - /* - windowedCounts.filter(_ == null).foreachRDD(rdd => { - val count = rdd.count - println("# of nulls = " + count) - })*/ - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2_Special.scala deleted file mode 100644 index c873fbd0f0..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCount2_Special.scala +++ /dev/null @@ -1,142 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.JavaConversions.mapAsScalaMap -import scala.collection.mutable.Queue - -import java.lang.{Long => JLong} - -object TopKWordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "TopKWordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - /*moreWarmup(ssc.sc)*/ - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 500)).toArray - ) - - /*val words = sentences.flatMap(_.split(" "))*/ - - /*def add(v1: Int, v2: Int) = (v1 + v2) */ - /*def subtract(v1: Int, v2: Int) = (v1 - v2) */ - - def add(v1: JLong, v2: JLong) = (v1 + v2) - def subtract(v1: JLong, v2: JLong) = (v1 - v2) - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - map.toIterator - } - - - val windowedCounts = sentences.mapPartitions(splitAndCountPartitions).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Milliseconds(500), 10) - /*windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1))*/ - windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY, Milliseconds(500)) - - def topK(data: Iterator[(String, JLong)], k: Int): Iterator[(String, JLong)] = { - val taken = new Array[(String, JLong)](k) - - var i = 0 - var len = 0 - var done = false - var value: (String, JLong) = null - var swap: (String, JLong) = null - var count = 0 - - while(data.hasNext) { - value = data.next - count += 1 - println("count = " + count) - if (len == 0) { - taken(0) = value - len = 1 - } else if (len < k || value._2 > taken(len - 1)._2) { - if (len < k) { - len += 1 - } - taken(len - 1) = value - i = len - 1 - while(i > 0 && taken(i - 1)._2 < taken(i)._2) { - swap = taken(i) - taken(i) = taken(i-1) - taken(i - 1) = swap - i -= 1 - } - } - } - println("Took " + len + " out of " + count + " items") - return taken.toIterator - } - - val k = 50 - val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) - partialTopKWindowedCounts.foreachRDD(rdd => { - val collectedCounts = rdd.collect - println("Collected " + collectedCounts.size + " items") - topK(collectedCounts.toIterator, k).foreach(println) - }) - - /* - windowedCounts.filter(_ == null).foreachRDD(rdd => { - val count = rdd.count - println("# of nulls = " + count) - })*/ - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount.scala deleted file mode 100644 index fb5508ffcc..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount.scala +++ /dev/null @@ -1,62 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object WordCount { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: WordCount ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "WordCountWindow") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) - //sentences.print - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(2000), - // System.getProperty("spark.default.parallelism", "1").toInt) - //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.DISK_AND_MEMORY_DESER_2, Seconds(5)) - //windowedCounts.print - - val parallelism = System.getProperty("spark.default.parallelism", "1").toInt - - //val localCounts = words.map(x => (x, 1)).reduceByKey(add _, parallelism) - //localCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(6)) - //val windowedCounts = localCounts.window(Seconds(30), Seconds(2)).reduceByKey(_ + _) - - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(2), - parallelism) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(10)) - - //windowedCounts.print - windowedCounts.register - //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => print(x+ " "))) - //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => x)) - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount1.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount1.scala deleted file mode 100644 index 42d985920a..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount1.scala +++ /dev/null @@ -1,46 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object WordCount1 { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: WordCount ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "WordCountWindow") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - //sentences.print - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 10) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) - windowedCounts.foreachRDD(_.collect) - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala deleted file mode 100644 index 9168a2fe2f..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ /dev/null @@ -1,55 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting - -object WordCount2 { - - def moreWarmup(sc: SparkContext) { - (0 until 20).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(_ % 100).map(_.toString) - .map(x => (x, 1)).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - if (args.length > 2) { - ssc.setTempDir(args(2)) - } - - GrepCount2.warmConnectionManagers(ssc.sc) - /*moreWarmup(ssc.sc)*/ - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 1000)).toArray - ) - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 6) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(1)) - windowedCounts.foreachRDD(_.collect) - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2_Special.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2_Special.scala deleted file mode 100644 index 1920915af7..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2_Special.scala +++ /dev/null @@ -1,94 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import SparkContext._ -import SparkStreamContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.Queue -import scala.collection.JavaConversions.mapAsScalaMap - -import java.lang.{Long => JLong} -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} - - -object WordCount2_ExtraFunctions { - - def add(v1: JLong, v2: JLong) = (v1 + v2) - - def subtract(v1: JLong, v2: JLong) = (v1 - v2) - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, JLong)] = { - val map = new java.util.HashMap[String, JLong] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.get(w) - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - map.toIterator - } -} - -object WordCount2_Special { - - def moreWarmup(sc: SparkContext) { - (0 until 40).foreach {i => - sc.parallelize(1 to 20000000, 1000) - .map(_ % 1331).map(_.toString) - .mapPartitions(WordCount2_ExtraFunctions.splitAndCountPartitions).reduceByKey(_ + _, 10) - .collect() - } - } - - def main (args: Array[String]) { - - if (args.length < 2) { - println ("Usage: SparkStreamContext <# sentence streams>") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount2") - - val numSentenceStreams = if (args.length > 1) args(1).toInt else 1 - - GrepCount2.warmConnectionManagers(ssc.sc) - /*moreWarmup(ssc.sc)*/ - - val sentences = new UnifiedRDS( - (1 to numSentenceStreams).map(i => ssc.readTestStream("Sentences-" + i, 500)).toArray - ) - - val windowedCounts = sentences - .mapPartitions(WordCount2_ExtraFunctions.splitAndCountPartitions) - .reduceByKeyAndWindow(WordCount2_ExtraFunctions.add _, WordCount2_ExtraFunctions.subtract _, Seconds(10), Milliseconds(500), 10) - windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY, Milliseconds(500)) - windowedCounts.foreachRDD(_.collect) - - ssc.run - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount3.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount3.scala deleted file mode 100644 index 018c19a509..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount3.scala +++ /dev/null @@ -1,49 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -object WordCount3 { - - def main (args: Array[String]) { - - if (args.length < 1) { - println ("Usage: SparkStreamContext []") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCount") - if (args.length > 1) { - ssc.setTempDir(args(1)) - } - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - /*sentences.print*/ - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - /*val windowedCounts = words.map(x => (x, 1)).window(Seconds(5), Seconds(1)).reduceByKey(add _, 1)*/ - val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(5), Seconds(1), 1) - /*windowedCounts.print */ - - def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { - implicit val countOrdering = new Ordering[(String, Int)] { - override def compare(count1: (String, Int), count2: (String, Int)): Int = { - count2._2 - count1._2 - } - } - val array = data.toArray - Sorting.quickSort(array) - array.take(k) - } - - val k = 10 - val topKWindowedCounts = windowedCounts.glom.flatMap(topK(_, k)).collect.flatMap(topK(_, k)) - topKWindowedCounts.print - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountEc2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountEc2.scala deleted file mode 100644 index 82b9fa781d..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountEc2.scala +++ /dev/null @@ -1,41 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ -import spark.SparkContext - -object WordCountEc2 { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: SparkStreamContext ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val ssc = new SparkStreamContext(args(0), "Test") - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - /*sentences.foreach(println)*/ - - val words = sentences.flatMap(_.split(" ")) - /*words.foreach(println)*/ - - val counts = words.map(x => (x, 1)).reduceByKey(_ + _) - /*counts.foreach(println)*/ - - counts.foreachRDD(rdd => rdd.collect.foreach(x => x)) - /*counts.register*/ - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountTrivialWindow.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountTrivialWindow.scala deleted file mode 100644 index 114dd144f1..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountTrivialWindow.scala +++ /dev/null @@ -1,51 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -object WordCountTrivialWindow { - - def main (args: Array[String]) { - - if (args.length < 1) { - println ("Usage: SparkStreamContext []") - System.exit(1) - } - - val ssc = new SparkStreamContext(args(0), "WordCountTrivialWindow") - if (args.length > 1) { - ssc.setTempDir(args(1)) - } - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 1000) - /*sentences.print*/ - - val words = sentences.flatMap(_.split(" ")) - - /*val counts = words.map(x => (x, 1)).reduceByKey(_ + _, 1)*/ - /*counts.print*/ - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - - val windowedCounts = words.map(x => (x, 1)).window(Seconds(5), Seconds(1)).reduceByKey(add _, 1) - /*windowedCounts.print */ - - def topK(data: Seq[(String, Int)], k: Int): Array[(String, Int)] = { - implicit val countOrdering = new Ordering[(String, Int)] { - override def compare(count1: (String, Int), count2: (String, Int)): Int = { - count2._2 - count1._2 - } - } - val array = data.toArray - Sorting.quickSort(array) - array.take(k) - } - - val k = 10 - val topKWindowedCounts = windowedCounts.glom.flatMap(topK(_, k)).collect.flatMap(topK(_, k)) - topKWindowedCounts.print - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordMax.scala b/streaming/src/main/scala/spark/streaming/examples/WordMax.scala deleted file mode 100644 index fbfc48030f..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordMax.scala +++ /dev/null @@ -1,64 +0,0 @@ -package spark.streaming - -import SparkStreamContext._ - -import scala.util.Sorting - -import spark.SparkContext -import spark.storage.StorageLevel - -object WordMax { - var inputFile : String = null - var HDFS : String = null - var idealPartitions : Int = 0 - - def main (args: Array[String]) { - - if (args.length != 4) { - println ("Usage: WordCount ") - System.exit(1) - } - - HDFS = args(1) - inputFile = HDFS + args(2) - idealPartitions = args(3).toInt - println ("Input file: " + inputFile) - - val ssc = new SparkStreamContext(args(0), "WordCountWindow") - - SparkContext.idealPartitions = idealPartitions - SparkContext.inputFile = inputFile - - val sentences = ssc.readNetworkStream[String]("Sentences", Array("localhost:55119"), 2000) - //sentences.print - - val words = sentences.flatMap(_.split(" ")) - - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - def max(v1: Int, v2: Int) = (if (v1 > v2) v1 else v2) - - //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(2000), - // System.getProperty("spark.default.parallelism", "1").toInt) - //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.DISK_AND_MEMORY_DESER_2, Seconds(5)) - //windowedCounts.print - - val parallelism = System.getProperty("spark.default.parallelism", "1").toInt - - val localCounts = words.map(x => (x, 1)).reduceByKey(add _, parallelism) - //localCounts.persist(StorageLevel.MEMORY_ONLY_DESER) - localCounts.persist(StorageLevel.MEMORY_ONLY_DESER_2) - val windowedCounts = localCounts.window(Seconds(30), Seconds(2)).reduceByKey(max _) - - //val windowedCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(2), - // parallelism) - //windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(6)) - - //windowedCounts.print - windowedCounts.register - //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => print(x+ " "))) - //windowedCounts.foreachRDD(rdd => rdd.collect.foreach(x => x)) - - ssc.run - } -} diff --git a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala new file mode 100644 index 0000000000..6125bb82eb --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala @@ -0,0 +1,52 @@ +package spark.streaming.util + +class RecurringTimer(period: Long, callback: (Long) => Unit) { + + val minPollTime = 25L + + val pollTime = { + if (period / 10.0 > minPollTime) { + (period / 10.0).toLong + } else { + minPollTime + } + } + + val thread = new Thread() { + override def run() { loop } + } + + var nextTime = 0L + + def start(): Long = { + nextTime = (math.floor(System.currentTimeMillis() / period) + 1).toLong * period + thread.start() + nextTime + } + + def stop() { + thread.interrupt() + } + + def loop() { + try { + while (true) { + val beforeSleepTime = System.currentTimeMillis() + while (beforeSleepTime >= nextTime) { + callback(nextTime) + nextTime += period + } + val sleepTime = if (nextTime - beforeSleepTime < 2 * pollTime) { + nextTime - beforeSleepTime + } else { + pollTime + } + Thread.sleep(sleepTime) + val afterSleepTime = System.currentTimeMillis() + } + } catch { + case e: InterruptedException => + } + } +} + diff --git a/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala b/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala new file mode 100644 index 0000000000..9925b1d07c --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala @@ -0,0 +1,64 @@ +package spark.streaming.util + +import java.net.{Socket, ServerSocket} +import java.io.{ByteArrayOutputStream, DataOutputStream, DataInputStream, BufferedInputStream} + +object Receiver { + def main(args: Array[String]) { + val port = args(0).toInt + val lsocket = new ServerSocket(port) + println("Listening on port " + port ) + while(true) { + val socket = lsocket.accept() + (new Thread() { + override def run() { + val buffer = new Array[Byte](100000) + var count = 0 + val time = System.currentTimeMillis + try { + val is = new DataInputStream(new BufferedInputStream(socket.getInputStream)) + var loop = true + var string: String = null + while((string = is.readUTF) != null) { + count += 28 + } + } catch { + case e: Exception => e.printStackTrace + } + val timeTaken = System.currentTimeMillis - time + val tput = (count / 1024.0) / (timeTaken / 1000.0) + println("Data = " + count + " bytes\nTime = " + timeTaken + " ms\nTput = " + tput + " KB/s") + } + }).start() + } + } + +} + +object Sender { + + def main(args: Array[String]) { + try { + val host = args(0) + val port = args(1).toInt + val size = args(2).toInt + + val byteStream = new ByteArrayOutputStream() + val stringDataStream = new DataOutputStream(byteStream) + (0 until size).foreach(_ => stringDataStream.writeUTF("abcdedfghijklmnopqrstuvwxy")) + val bytes = byteStream.toByteArray() + println("Generated array of " + bytes.length + " bytes") + + /*val bytes = new Array[Byte](size)*/ + val socket = new Socket(host, port) + val os = socket.getOutputStream + os.write(bytes) + os.flush + socket.close() + + } catch { + case e: Exception => e.printStackTrace + } + } +} + diff --git a/streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala b/streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala new file mode 100644 index 0000000000..94e8f7a849 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala @@ -0,0 +1,92 @@ +package spark.streaming.util + +import spark._ + +import scala.collection.mutable.ArrayBuffer +import scala.util.Random +import scala.io.Source + +import java.net.InetSocketAddress + +import org.apache.hadoop.fs._ +import org.apache.hadoop.conf._ +import org.apache.hadoop.io._ +import org.apache.hadoop.mapred._ +import org.apache.hadoop.util._ + +object SentenceFileGenerator { + + def printUsage () { + println ("Usage: SentenceFileGenerator <# partitions> []") + System.exit(0) + } + + def main (args: Array[String]) { + if (args.length < 4) { + printUsage + } + + val master = args(0) + val fs = new Path(args(1)).getFileSystem(new Configuration()) + val targetDirectory = new Path(args(1)).makeQualified(fs) + val numPartitions = args(2).toInt + val sentenceFile = args(3) + val sentencesPerSecond = { + if (args.length > 4) args(4).toInt + else 10 + } + + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n").toArray + source.close () + println("Read " + lines.length + " lines from file " + sentenceFile) + + val sentences = { + val buffer = ArrayBuffer[String]() + val random = new Random() + var i = 0 + while (i < sentencesPerSecond) { + buffer += lines(random.nextInt(lines.length)) + i += 1 + } + buffer.toArray + } + println("Generated " + sentences.length + " sentences") + + val sc = new SparkContext(master, "SentenceFileGenerator") + val sentencesRDD = sc.parallelize(sentences, numPartitions) + + val tempDirectory = new Path(targetDirectory, "_tmp") + + fs.mkdirs(targetDirectory) + fs.mkdirs(tempDirectory) + + var saveTimeMillis = System.currentTimeMillis + try { + while (true) { + val newDir = new Path(targetDirectory, "Sentences-" + saveTimeMillis) + val tmpNewDir = new Path(tempDirectory, "Sentences-" + saveTimeMillis) + println("Writing to file " + newDir) + sentencesRDD.saveAsTextFile(tmpNewDir.toString) + fs.rename(tmpNewDir, newDir) + saveTimeMillis += 1000 + val sleepTimeMillis = { + val currentTimeMillis = System.currentTimeMillis + if (saveTimeMillis < currentTimeMillis) { + 0 + } else { + saveTimeMillis - currentTimeMillis + } + } + println("Sleeping for " + sleepTimeMillis + " ms") + Thread.sleep(sleepTimeMillis) + } + } catch { + case e: Exception => + } + } +} + + + + diff --git a/streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala b/streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala new file mode 100644 index 0000000000..60085f4f88 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala @@ -0,0 +1,23 @@ +package spark.streaming.util + +import spark.SparkContext +import SparkContext._ + +object ShuffleTest { + def main(args: Array[String]) { + + if (args.length < 1) { + println ("Usage: ShuffleTest ") + System.exit(1) + } + + val sc = new spark.SparkContext(args(0), "ShuffleTest") + val rdd = sc.parallelize(1 to 1000, 500).cache + + def time(f: => Unit) { val start = System.nanoTime; f; println((System.nanoTime - start) * 1.0e-6) } + + time { for (i <- 0 until 50) time { rdd.map(x => (x % 100, x)).reduceByKey(_ + _, 10).count } } + System.exit(0) + } +} + diff --git a/streaming/src/main/scala/spark/streaming/util/Utils.scala b/streaming/src/main/scala/spark/streaming/util/Utils.scala new file mode 100644 index 0000000000..86a729fb49 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/Utils.scala @@ -0,0 +1,9 @@ +package spark.streaming.util + +object Utils { + def time(func: => Unit): Long = { + val t = System.currentTimeMillis + func + (System.currentTimeMillis - t) + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/utils/SenGeneratorForPerformanceTest.scala b/streaming/src/main/scala/spark/streaming/utils/SenGeneratorForPerformanceTest.scala deleted file mode 100644 index bb32089ae2..0000000000 --- a/streaming/src/main/scala/spark/streaming/utils/SenGeneratorForPerformanceTest.scala +++ /dev/null @@ -1,78 +0,0 @@ -package spark.streaming - -import scala.util.Random -import scala.io.Source -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import java.net.InetSocketAddress - -/*import akka.actor.Actor._*/ -/*import akka.actor.ActorRef*/ - - -object SenGeneratorForPerformanceTest { - - def printUsage () { - println ("Usage: SentenceGenerator []") - System.exit(0) - } - - def main (args: Array[String]) { - if (args.length < 3) { - printUsage - } - - val inputManagerIP = args(0) - val inputManagerPort = args(1).toInt - val sentenceFile = args(2) - val sentencesPerSecond = { - if (args.length > 3) args(3).toInt - else 10 - } - - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n") - source.close () - - try { - /*val inputManager = remote.actorFor("InputReceiver-Sentences",*/ - /* inputManagerIP, inputManagerPort)*/ - val inputManager = select(Node(inputManagerIP, inputManagerPort), Symbol("InputReceiver-Sentences")) - val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 - val random = new Random () - println ("Sending " + sentencesPerSecond + " sentences per second to " + inputManagerIP + ":" + inputManagerPort) - var lastPrintTime = System.currentTimeMillis() - var count = 0 - - while (true) { - /*if (!inputManager.tryTell (lines (random.nextInt (lines.length))))*/ - /*throw new Exception ("disconnected")*/ -// inputManager ! lines (random.nextInt (lines.length)) - for (i <- 0 to sentencesPerSecond) inputManager ! lines (0) - println(System.currentTimeMillis / 1000 + " s") -/* count += 1 - - if (System.currentTimeMillis - lastPrintTime >= 1000) { - println (count + " sentences sent last second") - count = 0 - lastPrintTime = System.currentTimeMillis - } - - Thread.sleep (sleepBetweenSentences.toLong) -*/ - val currentMs = System.currentTimeMillis / 1000; - Thread.sleep ((currentMs * 1000 + 1000) - System.currentTimeMillis) - } - } catch { - case e: Exception => - /*Thread.sleep (1000)*/ - } - } -} - - - - diff --git a/streaming/src/main/scala/spark/streaming/utils/SenderReceiverTest.scala b/streaming/src/main/scala/spark/streaming/utils/SenderReceiverTest.scala deleted file mode 100644 index 6af270298a..0000000000 --- a/streaming/src/main/scala/spark/streaming/utils/SenderReceiverTest.scala +++ /dev/null @@ -1,63 +0,0 @@ -package spark.streaming -import java.net.{Socket, ServerSocket} -import java.io.{ByteArrayOutputStream, DataOutputStream, DataInputStream, BufferedInputStream} - -object Receiver { - def main(args: Array[String]) { - val port = args(0).toInt - val lsocket = new ServerSocket(port) - println("Listening on port " + port ) - while(true) { - val socket = lsocket.accept() - (new Thread() { - override def run() { - val buffer = new Array[Byte](100000) - var count = 0 - val time = System.currentTimeMillis - try { - val is = new DataInputStream(new BufferedInputStream(socket.getInputStream)) - var loop = true - var string: String = null - while((string = is.readUTF) != null) { - count += 28 - } - } catch { - case e: Exception => e.printStackTrace - } - val timeTaken = System.currentTimeMillis - time - val tput = (count / 1024.0) / (timeTaken / 1000.0) - println("Data = " + count + " bytes\nTime = " + timeTaken + " ms\nTput = " + tput + " KB/s") - } - }).start() - } - } - -} - -object Sender { - - def main(args: Array[String]) { - try { - val host = args(0) - val port = args(1).toInt - val size = args(2).toInt - - val byteStream = new ByteArrayOutputStream() - val stringDataStream = new DataOutputStream(byteStream) - (0 until size).foreach(_ => stringDataStream.writeUTF("abcdedfghijklmnopqrstuvwxy")) - val bytes = byteStream.toByteArray() - println("Generated array of " + bytes.length + " bytes") - - /*val bytes = new Array[Byte](size)*/ - val socket = new Socket(host, port) - val os = socket.getOutputStream - os.write(bytes) - os.flush - socket.close() - - } catch { - case e: Exception => e.printStackTrace - } - } -} - diff --git a/streaming/src/main/scala/spark/streaming/utils/SentenceFileGenerator.scala b/streaming/src/main/scala/spark/streaming/utils/SentenceFileGenerator.scala deleted file mode 100644 index 15858f59e3..0000000000 --- a/streaming/src/main/scala/spark/streaming/utils/SentenceFileGenerator.scala +++ /dev/null @@ -1,92 +0,0 @@ -package spark.streaming - -import spark._ - -import scala.collection.mutable.ArrayBuffer -import scala.util.Random -import scala.io.Source - -import java.net.InetSocketAddress - -import org.apache.hadoop.fs._ -import org.apache.hadoop.conf._ -import org.apache.hadoop.io._ -import org.apache.hadoop.mapred._ -import org.apache.hadoop.util._ - -object SentenceFileGenerator { - - def printUsage () { - println ("Usage: SentenceFileGenerator <# partitions> []") - System.exit(0) - } - - def main (args: Array[String]) { - if (args.length < 4) { - printUsage - } - - val master = args(0) - val fs = new Path(args(1)).getFileSystem(new Configuration()) - val targetDirectory = new Path(args(1)).makeQualified(fs) - val numPartitions = args(2).toInt - val sentenceFile = args(3) - val sentencesPerSecond = { - if (args.length > 4) args(4).toInt - else 10 - } - - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n").toArray - source.close () - println("Read " + lines.length + " lines from file " + sentenceFile) - - val sentences = { - val buffer = ArrayBuffer[String]() - val random = new Random() - var i = 0 - while (i < sentencesPerSecond) { - buffer += lines(random.nextInt(lines.length)) - i += 1 - } - buffer.toArray - } - println("Generated " + sentences.length + " sentences") - - val sc = new SparkContext(master, "SentenceFileGenerator") - val sentencesRDD = sc.parallelize(sentences, numPartitions) - - val tempDirectory = new Path(targetDirectory, "_tmp") - - fs.mkdirs(targetDirectory) - fs.mkdirs(tempDirectory) - - var saveTimeMillis = System.currentTimeMillis - try { - while (true) { - val newDir = new Path(targetDirectory, "Sentences-" + saveTimeMillis) - val tmpNewDir = new Path(tempDirectory, "Sentences-" + saveTimeMillis) - println("Writing to file " + newDir) - sentencesRDD.saveAsTextFile(tmpNewDir.toString) - fs.rename(tmpNewDir, newDir) - saveTimeMillis += 1000 - val sleepTimeMillis = { - val currentTimeMillis = System.currentTimeMillis - if (saveTimeMillis < currentTimeMillis) { - 0 - } else { - saveTimeMillis - currentTimeMillis - } - } - println("Sleeping for " + sleepTimeMillis + " ms") - Thread.sleep(sleepTimeMillis) - } - } catch { - case e: Exception => - } - } -} - - - - diff --git a/streaming/src/main/scala/spark/streaming/utils/SentenceGenerator.scala b/streaming/src/main/scala/spark/streaming/utils/SentenceGenerator.scala deleted file mode 100644 index a9f124d2d7..0000000000 --- a/streaming/src/main/scala/spark/streaming/utils/SentenceGenerator.scala +++ /dev/null @@ -1,103 +0,0 @@ -package spark.streaming - -import scala.util.Random -import scala.io.Source -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import java.net.InetSocketAddress - - -object SentenceGenerator { - - def printUsage { - println ("Usage: SentenceGenerator []") - System.exit(0) - } - - def generateRandomSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { - val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 - val random = new Random () - - try { - var lastPrintTime = System.currentTimeMillis() - var count = 0 - while(true) { - streamReceiver ! lines(random.nextInt(lines.length)) - count += 1 - if (System.currentTimeMillis - lastPrintTime >= 1000) { - println (count + " sentences sent last second") - count = 0 - lastPrintTime = System.currentTimeMillis - } - Thread.sleep(sleepBetweenSentences.toLong) - } - } catch { - case e: Exception => - } - } - - def generateSameSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { - try { - val numSentences = if (sentencesPerSecond <= 0) { - lines.length - } else { - sentencesPerSecond - } - var nextSendingTime = System.currentTimeMillis() - val pingInterval = if (System.getenv("INTERVAL") != null) { - System.getenv("INTERVAL").toInt - } else { - 2000 - } - while(true) { - (0 until numSentences).foreach(i => { - streamReceiver ! lines(i % lines.length) - }) - println ("Sent " + numSentences + " sentences") - nextSendingTime += pingInterval - val sleepTime = nextSendingTime - System.currentTimeMillis - if (sleepTime > 0) { - println ("Sleeping for " + sleepTime + " ms") - Thread.sleep(sleepTime) - } - } - } catch { - case e: Exception => - } - } - - def main(args: Array[String]) { - if (args.length < 3) { - printUsage - } - - val generateRandomly = false - - val streamReceiverIP = args(0) - val streamReceiverPort = args(1).toInt - val sentenceFile = args(2) - val sentencesPerSecond = if (args.length > 3) args(3).toInt else 10 - val sentenceInputName = if (args.length > 4) args(4) else "Sentences" - - println("Sending " + sentencesPerSecond + " sentences per second to " + - streamReceiverIP + ":" + streamReceiverPort + "/NetworkStreamReceiver-" + sentenceInputName) - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n") - source.close () - - val streamReceiver = select( - Node(streamReceiverIP, streamReceiverPort), - Symbol("NetworkStreamReceiver-" + sentenceInputName)) - if (generateRandomly) { - generateRandomSentences(lines, sentencesPerSecond, streamReceiver) - } else { - generateSameSentences(lines, sentencesPerSecond, streamReceiver) - } - } -} - - - diff --git a/streaming/src/main/scala/spark/streaming/utils/ShuffleTest.scala b/streaming/src/main/scala/spark/streaming/utils/ShuffleTest.scala deleted file mode 100644 index 32aa4144a0..0000000000 --- a/streaming/src/main/scala/spark/streaming/utils/ShuffleTest.scala +++ /dev/null @@ -1,22 +0,0 @@ -package spark.streaming -import spark.SparkContext -import SparkContext._ - -object ShuffleTest { - def main(args: Array[String]) { - - if (args.length < 1) { - println ("Usage: ShuffleTest ") - System.exit(1) - } - - val sc = new spark.SparkContext(args(0), "ShuffleTest") - val rdd = sc.parallelize(1 to 1000, 500).cache - - def time(f: => Unit) { val start = System.nanoTime; f; println((System.nanoTime - start) * 1.0e-6) } - - time { for (i <- 0 until 50) time { rdd.map(x => (x % 100, x)).reduceByKey(_ + _, 10).count } } - System.exit(0) - } -} - diff --git a/streaming/src/test/scala/spark/streaming/RDSSuite.scala b/streaming/src/test/scala/spark/streaming/RDSSuite.scala new file mode 100644 index 0000000000..f51ea50a5d --- /dev/null +++ b/streaming/src/test/scala/spark/streaming/RDSSuite.scala @@ -0,0 +1,65 @@ +package spark.streaming + +import spark.RDD + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.SynchronizedQueue + +class RDSSuite extends FunSuite with BeforeAndAfter { + + var ssc: SparkStreamContext = null + val batchDurationMillis = 1000 + + def testOp[U: ClassManifest, V: ClassManifest]( + input: Seq[Seq[U]], + operation: RDS[U] => RDS[V], + expectedOutput: Seq[Seq[V]]) = { + try { + ssc = new SparkStreamContext("local", "test") + ssc.setBatchDuration(Milliseconds(batchDurationMillis)) + + val inputStream = ssc.createQueueStream(input.map(ssc.sc.makeRDD(_, 2)).toIterator) + val outputStream = operation(inputStream) + val outputQueue = outputStream.toQueue + + ssc.start() + Thread.sleep(batchDurationMillis * input.size) + + val output = new ArrayBuffer[Seq[V]]() + while(outputQueue.size > 0) { + val rdd = outputQueue.take() + println("Collecting RDD " + rdd.id + ", " + rdd.getClass().getSimpleName() + ", " + rdd.splits.size) + output += (rdd.collect()) + } + assert(output.size === expectedOutput.size) + for (i <- 0 until output.size) { + assert(output(i).toList === expectedOutput(i).toList) + } + } finally { + ssc.stop() + } + } + + test("basic operations") { + val inputData = Array(1 to 4, 5 to 8, 9 to 12) + + // map + testOp(inputData, (r: RDS[Int]) => r.map(_.toString), inputData.map(_.map(_.toString))) + + // flatMap + testOp(inputData, (r: RDS[Int]) => r.flatMap(x => Array(x, x * 2)), + inputData.map(_.flatMap(x => Array(x, x * 2))) + ) + } +} + +object RDSSuite { + def main(args: Array[String]) { + val r = new RDSSuite() + val inputData = Array(1 to 4, 5 to 8, 9 to 12) + r.testOp(inputData, (r: RDS[Int]) => r.map(_.toString), inputData.map(_.map(_.toString))) + } +} \ No newline at end of file -- cgit v1.2.3 From ed897ac5e1d1dd2ce144b232cf7a73db2d6679f9 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 1 Aug 2012 22:28:54 -0700 Subject: Moved streaming files not immediately necessary to spark.streaming.util. --- .../scala/spark/streaming/ConnectionHandler.scala | 157 -------- streaming/src/main/scala/spark/streaming/Job.scala | 7 +- .../main/scala/spark/streaming/TestGenerator.scala | 107 ------ .../scala/spark/streaming/TestGenerator2.scala | 119 ------ .../scala/spark/streaming/TestGenerator4.scala | 244 ------------ .../spark/streaming/TestStreamCoordinator.scala | 38 -- .../spark/streaming/TestStreamReceiver3.scala | 420 -------------------- .../spark/streaming/TestStreamReceiver4.scala | 373 ------------------ .../spark/streaming/util/ConnectionHandler.scala | 157 ++++++++ .../scala/spark/streaming/util/TestGenerator.scala | 107 ++++++ .../spark/streaming/util/TestGenerator2.scala | 119 ++++++ .../spark/streaming/util/TestGenerator4.scala | 244 ++++++++++++ .../streaming/util/TestStreamCoordinator.scala | 39 ++ .../spark/streaming/util/TestStreamReceiver3.scala | 421 +++++++++++++++++++++ .../spark/streaming/util/TestStreamReceiver4.scala | 374 ++++++++++++++++++ .../main/scala/spark/streaming/util/Utils.scala | 9 - 16 files changed, 1465 insertions(+), 1470 deletions(-) delete mode 100644 streaming/src/main/scala/spark/streaming/ConnectionHandler.scala delete mode 100644 streaming/src/main/scala/spark/streaming/TestGenerator.scala delete mode 100644 streaming/src/main/scala/spark/streaming/TestGenerator2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/TestGenerator4.scala delete mode 100644 streaming/src/main/scala/spark/streaming/TestStreamCoordinator.scala delete mode 100644 streaming/src/main/scala/spark/streaming/TestStreamReceiver3.scala delete mode 100644 streaming/src/main/scala/spark/streaming/TestStreamReceiver4.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/ConnectionHandler.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/TestGenerator.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/TestGenerator2.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/TestGenerator4.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/TestStreamCoordinator.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/TestStreamReceiver3.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/TestStreamReceiver4.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/Utils.scala diff --git a/streaming/src/main/scala/spark/streaming/ConnectionHandler.scala b/streaming/src/main/scala/spark/streaming/ConnectionHandler.scala deleted file mode 100644 index a4f454632f..0000000000 --- a/streaming/src/main/scala/spark/streaming/ConnectionHandler.scala +++ /dev/null @@ -1,157 +0,0 @@ -package spark.streaming - -import spark.Logging - -import scala.collection.mutable.{ArrayBuffer, SynchronizedQueue} - -import java.net._ -import java.io._ -import java.nio._ -import java.nio.charset._ -import java.nio.channels._ -import java.nio.channels.spi._ - -abstract class ConnectionHandler(host: String, port: Int, connect: Boolean) -extends Thread with Logging { - - val selector = SelectorProvider.provider.openSelector() - val interestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] - - initLogging() - - override def run() { - try { - if (connect) { - connect() - } else { - listen() - } - - var interrupted = false - while(!interrupted) { - - preSelect() - - while(!interestChangeRequests.isEmpty) { - val (key, ops) = interestChangeRequests.dequeue - val lastOps = key.interestOps() - key.interestOps(ops) - - def intToOpStr(op: Int): String = { - val opStrs = new ArrayBuffer[String]() - if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" - if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" - if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" - if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" - if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " - } - - logTrace("Changed ops from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") - } - - selector.select() - interrupted = Thread.currentThread.isInterrupted - - val selectedKeys = selector.selectedKeys().iterator() - while (selectedKeys.hasNext) { - val key = selectedKeys.next.asInstanceOf[SelectionKey] - selectedKeys.remove() - if (key.isValid) { - if (key.isAcceptable) { - accept(key) - } else if (key.isConnectable) { - finishConnect(key) - } else if (key.isReadable) { - read(key) - } else if (key.isWritable) { - write(key) - } - } - } - } - } catch { - case e: Exception => { - logError("Error in select loop", e) - } - } - } - - def connect() { - val socketAddress = new InetSocketAddress(host, port) - val channel = SocketChannel.open() - channel.configureBlocking(false) - channel.socket.setReuseAddress(true) - channel.socket.setTcpNoDelay(true) - channel.connect(socketAddress) - channel.register(selector, SelectionKey.OP_CONNECT) - logInfo("Initiating connection to [" + socketAddress + "]") - } - - def listen() { - val channel = ServerSocketChannel.open() - channel.configureBlocking(false) - channel.socket.setReuseAddress(true) - channel.socket.setReceiveBufferSize(256 * 1024) - channel.socket.bind(new InetSocketAddress(port)) - channel.register(selector, SelectionKey.OP_ACCEPT) - logInfo("Listening on port " + port) - } - - def finishConnect(key: SelectionKey) { - try { - val channel = key.channel.asInstanceOf[SocketChannel] - val address = channel.socket.getRemoteSocketAddress - channel.finishConnect() - logInfo("Connected to [" + host + ":" + port + "]") - ready(key) - } catch { - case e: IOException => { - logError("Error finishing connect to " + host + ":" + port) - close(key) - } - } - } - - def accept(key: SelectionKey) { - try { - val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] - val channel = serverChannel.accept() - val address = channel.socket.getRemoteSocketAddress - channel.configureBlocking(false) - logInfo("Accepted connection from [" + address + "]") - ready(channel.register(selector, 0)) - } catch { - case e: IOException => { - logError("Error accepting connection", e) - } - } - } - - def changeInterest(key: SelectionKey, ops: Int) { - logTrace("Added request to change ops to " + ops) - interestChangeRequests += ((key, ops)) - } - - def ready(key: SelectionKey) - - def preSelect() { - } - - def read(key: SelectionKey) { - throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString) - } - - def write(key: SelectionKey) { - throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString) - } - - def close(key: SelectionKey) { - try { - key.channel.close() - key.cancel() - Thread.currentThread.interrupt - } catch { - case e: Exception => logError("Error closing connection", e) - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/Job.scala b/streaming/src/main/scala/spark/streaming/Job.scala index 36958dafe1..2481a9a3ef 100644 --- a/streaming/src/main/scala/spark/streaming/Job.scala +++ b/streaming/src/main/scala/spark/streaming/Job.scala @@ -1,11 +1,12 @@ package spark.streaming -import spark.streaming.util.Utils - class Job(val time: Time, func: () => _) { val id = Job.getNewId() def run(): Long = { - Utils.time { func() } + val startTime = System.currentTimeMillis + func() + val stopTime = System.currentTimeMillis + (startTime - stopTime) } override def toString = "streaming job " + id + " @ " + time diff --git a/streaming/src/main/scala/spark/streaming/TestGenerator.scala b/streaming/src/main/scala/spark/streaming/TestGenerator.scala deleted file mode 100644 index 0ff6af61f2..0000000000 --- a/streaming/src/main/scala/spark/streaming/TestGenerator.scala +++ /dev/null @@ -1,107 +0,0 @@ -package spark.streaming - -import scala.util.Random -import scala.io.Source -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import java.net.InetSocketAddress - - -object TestGenerator { - - def printUsage { - println ("Usage: SentenceGenerator []") - System.exit(0) - } - /* - def generateRandomSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { - val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 - val random = new Random () - - try { - var lastPrintTime = System.currentTimeMillis() - var count = 0 - while(true) { - streamReceiver ! lines(random.nextInt(lines.length)) - count += 1 - if (System.currentTimeMillis - lastPrintTime >= 1000) { - println (count + " sentences sent last second") - count = 0 - lastPrintTime = System.currentTimeMillis - } - Thread.sleep(sleepBetweenSentences.toLong) - } - } catch { - case e: Exception => - } - }*/ - - def generateSameSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { - try { - val numSentences = if (sentencesPerSecond <= 0) { - lines.length - } else { - sentencesPerSecond - } - val sentences = lines.take(numSentences).toArray - - var nextSendingTime = System.currentTimeMillis() - val sendAsArray = true - while(true) { - if (sendAsArray) { - println("Sending as array") - streamReceiver !? sentences - } else { - println("Sending individually") - sentences.foreach(sentence => { - streamReceiver !? sentence - }) - } - println ("Sent " + numSentences + " sentences in " + (System.currentTimeMillis - nextSendingTime) + " ms") - nextSendingTime += 1000 - val sleepTime = nextSendingTime - System.currentTimeMillis - if (sleepTime > 0) { - println ("Sleeping for " + sleepTime + " ms") - Thread.sleep(sleepTime) - } - } - } catch { - case e: Exception => - } - } - - def main(args: Array[String]) { - if (args.length < 3) { - printUsage - } - - val generateRandomly = false - - val streamReceiverIP = args(0) - val streamReceiverPort = args(1).toInt - val sentenceFile = args(2) - val sentencesPerSecond = if (args.length > 3) args(3).toInt else 10 - val sentenceInputName = if (args.length > 4) args(4) else "Sentences" - - println("Sending " + sentencesPerSecond + " sentences per second to " + - streamReceiverIP + ":" + streamReceiverPort + "/NetworkStreamReceiver-" + sentenceInputName) - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n") - source.close () - - val streamReceiver = select( - Node(streamReceiverIP, streamReceiverPort), - Symbol("NetworkStreamReceiver-" + sentenceInputName)) - if (generateRandomly) { - /*generateRandomSentences(lines, sentencesPerSecond, streamReceiver)*/ - } else { - generateSameSentences(lines, sentencesPerSecond, streamReceiver) - } - } -} - - - diff --git a/streaming/src/main/scala/spark/streaming/TestGenerator2.scala b/streaming/src/main/scala/spark/streaming/TestGenerator2.scala deleted file mode 100644 index 00d43604d0..0000000000 --- a/streaming/src/main/scala/spark/streaming/TestGenerator2.scala +++ /dev/null @@ -1,119 +0,0 @@ -package spark.streaming - -import scala.util.Random -import scala.io.Source -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import java.io.{DataOutputStream, ByteArrayOutputStream, DataInputStream} -import java.net.Socket - -object TestGenerator2 { - - def printUsage { - println ("Usage: SentenceGenerator []") - System.exit(0) - } - - def sendSentences(streamReceiverHost: String, streamReceiverPort: Int, numSentences: Int, bytes: Array[Byte], intervalTime: Long){ - try { - println("Connecting to " + streamReceiverHost + ":" + streamReceiverPort) - val socket = new Socket(streamReceiverHost, streamReceiverPort) - - println("Sending " + numSentences+ " sentences / " + (bytes.length / 1024.0 / 1024.0) + " MB per " + intervalTime + " ms to " + streamReceiverHost + ":" + streamReceiverPort ) - val currentTime = System.currentTimeMillis - var targetTime = (currentTime / intervalTime + 1).toLong * intervalTime - Thread.sleep(targetTime - currentTime) - - while(true) { - val startTime = System.currentTimeMillis() - println("Sending at " + startTime + " ms with delay of " + (startTime - targetTime) + " ms") - val socketOutputStream = socket.getOutputStream - val parts = 10 - (0 until parts).foreach(i => { - val partStartTime = System.currentTimeMillis - - val offset = (i * bytes.length / parts).toInt - val len = math.min(((i + 1) * bytes.length / parts).toInt - offset, bytes.length) - socketOutputStream.write(bytes, offset, len) - socketOutputStream.flush() - val partFinishTime = System.currentTimeMillis - println("Sending part " + i + " of " + len + " bytes took " + (partFinishTime - partStartTime) + " ms") - val sleepTime = math.max(0, 1000 / parts - (partFinishTime - partStartTime) - 1) - Thread.sleep(sleepTime) - }) - - socketOutputStream.flush() - /*val socketInputStream = new DataInputStream(socket.getInputStream)*/ - /*val reply = socketInputStream.readUTF()*/ - val finishTime = System.currentTimeMillis() - println ("Sent " + bytes.length + " bytes in " + (finishTime - startTime) + " ms for interval [" + targetTime + ", " + (targetTime + intervalTime) + "]") - /*println("Received = " + reply)*/ - targetTime = targetTime + intervalTime - val sleepTime = (targetTime - finishTime) + 10 - if (sleepTime > 0) { - println("Sleeping for " + sleepTime + " ms") - Thread.sleep(sleepTime) - } else { - println("############################") - println("###### Skipping sleep ######") - println("############################") - } - } - } catch { - case e: Exception => println(e) - } - println("Stopped sending") - } - - def main(args: Array[String]) { - if (args.length < 4) { - printUsage - } - - val streamReceiverHost = args(0) - val streamReceiverPort = args(1).toInt - val sentenceFile = args(2) - val intervalTime = args(3).toLong - val sentencesPerInterval = if (args.length > 4) args(4).toInt else 0 - - println("Reading the file " + sentenceFile) - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n") - source.close() - - val numSentences = if (sentencesPerInterval <= 0) { - lines.length - } else { - sentencesPerInterval - } - - println("Generating sentences") - val sentences: Array[String] = if (numSentences <= lines.length) { - lines.take(numSentences).toArray - } else { - (0 until numSentences).map(i => lines(i % lines.length)).toArray - } - - println("Converting to byte array") - val byteStream = new ByteArrayOutputStream() - val stringDataStream = new DataOutputStream(byteStream) - /*stringDataStream.writeInt(sentences.size)*/ - sentences.foreach(stringDataStream.writeUTF) - val bytes = byteStream.toByteArray() - stringDataStream.close() - println("Generated array of " + bytes.length + " bytes") - - /*while(true) { */ - sendSentences(streamReceiverHost, streamReceiverPort, numSentences, bytes, intervalTime) - /*println("Sleeping for 5 seconds")*/ - /*Thread.sleep(5000)*/ - /*System.gc()*/ - /*}*/ - } -} - - - diff --git a/streaming/src/main/scala/spark/streaming/TestGenerator4.scala b/streaming/src/main/scala/spark/streaming/TestGenerator4.scala deleted file mode 100644 index 93c7f2f440..0000000000 --- a/streaming/src/main/scala/spark/streaming/TestGenerator4.scala +++ /dev/null @@ -1,244 +0,0 @@ -package spark.streaming - -import spark.Logging - -import scala.util.Random -import scala.io.Source -import scala.collection.mutable.{ArrayBuffer, Queue} - -import java.net._ -import java.io._ -import java.nio._ -import java.nio.charset._ -import java.nio.channels._ - -import it.unimi.dsi.fastutil.io._ - -class TestGenerator4(targetHost: String, targetPort: Int, sentenceFile: String, intervalDuration: Long, sentencesPerInterval: Int) -extends Logging { - - class SendingConnectionHandler(host: String, port: Int, generator: TestGenerator4) - extends ConnectionHandler(host, port, true) { - - val buffers = new ArrayBuffer[ByteBuffer] - val newBuffers = new Queue[ByteBuffer] - var activeKey: SelectionKey = null - - def send(buffer: ByteBuffer) { - logDebug("Sending: " + buffer) - newBuffers.synchronized { - newBuffers.enqueue(buffer) - } - selector.wakeup() - buffer.synchronized { - buffer.wait() - } - } - - override def ready(key: SelectionKey) { - logDebug("Ready") - activeKey = key - val channel = key.channel.asInstanceOf[SocketChannel] - channel.register(selector, SelectionKey.OP_WRITE) - generator.startSending() - } - - override def preSelect() { - newBuffers.synchronized { - while(!newBuffers.isEmpty) { - val buffer = newBuffers.dequeue - buffers += buffer - logDebug("Added: " + buffer) - changeInterest(activeKey, SelectionKey.OP_WRITE) - } - } - } - - override def write(key: SelectionKey) { - try { - /*while(true) {*/ - val channel = key.channel.asInstanceOf[SocketChannel] - if (buffers.size > 0) { - val buffer = buffers(0) - val newBuffer = buffer.slice() - newBuffer.limit(math.min(newBuffer.remaining, 32768)) - val bytesWritten = channel.write(newBuffer) - buffer.position(buffer.position + bytesWritten) - if (bytesWritten == 0) return - if (buffer.remaining == 0) { - buffers -= buffer - buffer.synchronized { - buffer.notify() - } - } - /*changeInterest(key, SelectionKey.OP_WRITE)*/ - } else { - changeInterest(key, 0) - } - /*}*/ - } catch { - case e: IOException => { - if (e.toString.contains("pipe") || e.toString.contains("reset")) { - logError("Connection broken") - } else { - logError("Connection error", e) - } - close(key) - } - } - } - - override def close(key: SelectionKey) { - buffers.clear() - super.close(key) - } - } - - initLogging() - - val connectionHandler = new SendingConnectionHandler(targetHost, targetPort, this) - var sendingThread: Thread = null - var sendCount = 0 - val sendBatches = 5 - - def run() { - logInfo("Connection handler started") - connectionHandler.start() - connectionHandler.join() - if (sendingThread != null && !sendingThread.isInterrupted) { - sendingThread.interrupt - } - logInfo("Connection handler stopped") - } - - def startSending() { - sendingThread = new Thread() { - override def run() { - logInfo("STARTING TO SEND") - sendSentences() - logInfo("SENDING STOPPED AFTER " + sendCount) - connectionHandler.interrupt() - } - } - sendingThread.start() - } - - def stopSending() { - sendingThread.interrupt() - } - - def sendSentences() { - logInfo("Reading the file " + sentenceFile) - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n") - source.close() - - val numSentences = if (sentencesPerInterval <= 0) { - lines.length - } else { - sentencesPerInterval - } - - logInfo("Generating sentence buffer") - val sentences: Array[String] = if (numSentences <= lines.length) { - lines.take(numSentences).toArray - } else { - (0 until numSentences).map(i => lines(i % lines.length)).toArray - } - - /* - val sentences: Array[String] = if (numSentences <= lines.length) { - lines.take((numSentences / sendBatches).toInt).toArray - } else { - (0 until (numSentences/sendBatches)).map(i => lines(i % lines.length)).toArray - }*/ - - - val serializer = new spark.KryoSerializer().newInstance() - val byteStream = new FastByteArrayOutputStream(100 * 1024 * 1024) - serializer.serializeStream(byteStream).writeAll(sentences.toIterator.asInstanceOf[Iterator[Any]]).close() - byteStream.trim() - val sentenceBuffer = ByteBuffer.wrap(byteStream.array) - - logInfo("Sending " + numSentences+ " sentences / " + sentenceBuffer.limit + " bytes per " + intervalDuration + " ms to " + targetHost + ":" + targetPort ) - val currentTime = System.currentTimeMillis - var targetTime = (currentTime / intervalDuration + 1).toLong * intervalDuration - Thread.sleep(targetTime - currentTime) - - val totalBytes = sentenceBuffer.limit - - while(true) { - val batchesInCurrentInterval = sendBatches // if (sendCount < 10) 1 else sendBatches - - val startTime = System.currentTimeMillis() - logDebug("Sending # " + sendCount + " at " + startTime + " ms with delay of " + (startTime - targetTime) + " ms") - - (0 until batchesInCurrentInterval).foreach(i => { - try { - val position = (i * totalBytes / sendBatches).toInt - val limit = if (i == sendBatches - 1) { - totalBytes - } else { - ((i + 1) * totalBytes / sendBatches).toInt - 1 - } - - val partStartTime = System.currentTimeMillis - sentenceBuffer.limit(limit) - connectionHandler.send(sentenceBuffer) - val partFinishTime = System.currentTimeMillis - val sleepTime = math.max(0, intervalDuration / sendBatches - (partFinishTime - partStartTime) - 1) - Thread.sleep(sleepTime) - - } catch { - case ie: InterruptedException => return - case e: Exception => e.printStackTrace() - } - }) - sentenceBuffer.rewind() - - val finishTime = System.currentTimeMillis() - /*logInfo ("Sent " + sentenceBuffer.limit + " bytes in " + (finishTime - startTime) + " ms")*/ - targetTime = targetTime + intervalDuration //+ (if (sendCount < 3) 1000 else 0) - - val sleepTime = (targetTime - finishTime) + 20 - if (sleepTime > 0) { - logInfo("Sleeping for " + sleepTime + " ms") - Thread.sleep(sleepTime) - } else { - logInfo("###### Skipping sleep ######") - } - if (Thread.currentThread.isInterrupted) { - return - } - sendCount += 1 - } - } -} - -object TestGenerator4 { - def printUsage { - println("Usage: TestGenerator4 []") - System.exit(0) - } - - def main(args: Array[String]) { - println("GENERATOR STARTED") - if (args.length < 4) { - printUsage - } - - - val streamReceiverHost = args(0) - val streamReceiverPort = args(1).toInt - val sentenceFile = args(2) - val intervalDuration = args(3).toLong - val sentencesPerInterval = if (args.length > 4) args(4).toInt else 0 - - while(true) { - val generator = new TestGenerator4(streamReceiverHost, streamReceiverPort, sentenceFile, intervalDuration, sentencesPerInterval) - generator.run() - Thread.sleep(2000) - } - println("GENERATOR STOPPED") - } -} diff --git a/streaming/src/main/scala/spark/streaming/TestStreamCoordinator.scala b/streaming/src/main/scala/spark/streaming/TestStreamCoordinator.scala deleted file mode 100644 index c658a036f9..0000000000 --- a/streaming/src/main/scala/spark/streaming/TestStreamCoordinator.scala +++ /dev/null @@ -1,38 +0,0 @@ -package spark.streaming - -import spark.Logging - -import akka.actor._ -import akka.actor.Actor -import akka.actor.Actor._ - -sealed trait TestStreamCoordinatorMessage -case class GetStreamDetails extends TestStreamCoordinatorMessage -case class GotStreamDetails(name: String, duration: Long) extends TestStreamCoordinatorMessage -case class TestStarted extends TestStreamCoordinatorMessage - -class TestStreamCoordinator(streamDetails: Array[(String, Long)]) extends Actor with Logging { - - var index = 0 - - initLogging() - - logInfo("Created") - - def receive = { - case TestStarted => { - sender ! "OK" - } - - case GetStreamDetails => { - val streamDetail = if (index >= streamDetails.length) null else streamDetails(index) - sender ! GotStreamDetails(streamDetail._1, streamDetail._2) - index += 1 - if (streamDetail != null) { - logInfo("Allocated " + streamDetail._1 + " (" + index + "/" + streamDetails.length + ")" ) - } - } - } - -} - diff --git a/streaming/src/main/scala/spark/streaming/TestStreamReceiver3.scala b/streaming/src/main/scala/spark/streaming/TestStreamReceiver3.scala deleted file mode 100644 index bbf2c7bf5e..0000000000 --- a/streaming/src/main/scala/spark/streaming/TestStreamReceiver3.scala +++ /dev/null @@ -1,420 +0,0 @@ -package spark.streaming - -import spark._ -import spark.storage._ -import spark.util.AkkaUtils - -import scala.math._ -import scala.collection.mutable.{Queue, HashMap, ArrayBuffer, SynchronizedMap} - -import akka.actor._ -import akka.actor.Actor -import akka.dispatch._ -import akka.pattern.ask -import akka.util.duration._ - -import java.io.DataInputStream -import java.io.BufferedInputStream -import java.net.Socket -import java.net.ServerSocket -import java.util.LinkedHashMap - -import org.apache.hadoop.fs._ -import org.apache.hadoop.conf._ -import org.apache.hadoop.io._ -import org.apache.hadoop.mapred._ -import org.apache.hadoop.util._ - -import spark.Utils - - -class TestStreamReceiver3(actorSystem: ActorSystem, blockManager: BlockManager) -extends Thread with Logging { - - - class DataHandler( - inputName: String, - longIntervalDuration: Time, - shortIntervalDuration: Time, - blockManager: BlockManager - ) - extends Logging { - - class Block(var id: String, var shortInterval: Interval) { - val data = ArrayBuffer[String]() - var pushed = false - def longInterval = getLongInterval(shortInterval) - def empty() = (data.size == 0) - def += (str: String) = (data += str) - override def toString() = "Block " + id - } - - class Bucket(val longInterval: Interval) { - val blocks = new ArrayBuffer[Block]() - var filled = false - def += (block: Block) = blocks += block - def empty() = (blocks.size == 0) - def ready() = (filled && !blocks.exists(! _.pushed)) - def blockIds() = blocks.map(_.id).toArray - override def toString() = "Bucket [" + longInterval + ", " + blocks.size + " blocks]" - } - - initLogging() - - val shortIntervalDurationMillis = shortIntervalDuration.toLong - val longIntervalDurationMillis = longIntervalDuration.toLong - - var currentBlock: Block = null - var currentBucket: Bucket = null - - val blocksForPushing = new Queue[Block]() - val buckets = new HashMap[Interval, Bucket]() with SynchronizedMap[Interval, Bucket] - - val blockUpdatingThread = new Thread() { override def run() { keepUpdatingCurrentBlock() } } - val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } - - def start() { - blockUpdatingThread.start() - blockPushingThread.start() - } - - def += (data: String) = addData(data) - - def addData(data: String) { - if (currentBlock == null) { - updateCurrentBlock() - } - currentBlock.synchronized { - currentBlock += data - } - } - - def getShortInterval(time: Time): Interval = { - val intervalBegin = time.floor(shortIntervalDuration) - Interval(intervalBegin, intervalBegin + shortIntervalDuration) - } - - def getLongInterval(shortInterval: Interval): Interval = { - val intervalBegin = shortInterval.beginTime.floor(longIntervalDuration) - Interval(intervalBegin, intervalBegin + longIntervalDuration) - } - - def updateCurrentBlock() { - /*logInfo("Updating current block")*/ - val currentTime = Time(System.currentTimeMillis) - val shortInterval = getShortInterval(currentTime) - val longInterval = getLongInterval(shortInterval) - - def createBlock(reuseCurrentBlock: Boolean = false) { - val newBlockId = inputName + "-" + longInterval.toFormattedString + "-" + currentBucket.blocks.size - if (!reuseCurrentBlock) { - val newBlock = new Block(newBlockId, shortInterval) - /*logInfo("Created " + currentBlock)*/ - currentBlock = newBlock - } else { - currentBlock.shortInterval = shortInterval - currentBlock.id = newBlockId - } - } - - def createBucket() { - val newBucket = new Bucket(longInterval) - buckets += ((longInterval, newBucket)) - currentBucket = newBucket - /*logInfo("Created " + currentBucket + ", " + buckets.size + " buckets")*/ - } - - if (currentBlock == null || currentBucket == null) { - createBucket() - currentBucket.synchronized { - createBlock() - } - return - } - - currentBlock.synchronized { - var reuseCurrentBlock = false - - if (shortInterval != currentBlock.shortInterval) { - if (!currentBlock.empty) { - blocksForPushing.synchronized { - blocksForPushing += currentBlock - blocksForPushing.notifyAll() - } - } - - currentBucket.synchronized { - if (currentBlock.empty) { - reuseCurrentBlock = true - } else { - currentBucket += currentBlock - } - - if (longInterval != currentBucket.longInterval) { - currentBucket.filled = true - if (currentBucket.ready) { - currentBucket.notifyAll() - } - createBucket() - } - } - - createBlock(reuseCurrentBlock) - } - } - } - - def pushBlock(block: Block) { - try{ - if (blockManager != null) { - logInfo("Pushing block") - val startTime = System.currentTimeMillis - - val bytes = blockManager.dataSerialize(block.data.toIterator) - val finishTime = System.currentTimeMillis - logInfo(block + " serialization delay is " + (finishTime - startTime) / 1000.0 + " s") - - blockManager.putBytes(block.id.toString, bytes, StorageLevel.DISK_AND_MEMORY_2) - /*blockManager.putBytes(block.id.toString, bytes, StorageLevel.DISK_AND_MEMORY_DESER_2)*/ - /*blockManager.put(block.id.toString, block.data.toIterator, StorageLevel.DISK_AND_MEMORY_DESER)*/ - /*blockManager.put(block.id.toString, block.data.toIterator, StorageLevel.DISK_AND_MEMORY)*/ - val finishTime1 = System.currentTimeMillis - logInfo(block + " put delay is " + (finishTime1 - startTime) / 1000.0 + " s") - } else { - logWarning(block + " not put as block manager is null") - } - } catch { - case e: Exception => logError("Exception writing " + block + " to blockmanager" , e) - } - } - - def getBucket(longInterval: Interval): Option[Bucket] = { - buckets.get(longInterval) - } - - def clearBucket(longInterval: Interval) { - buckets.remove(longInterval) - } - - def keepUpdatingCurrentBlock() { - logInfo("Thread to update current block started") - while(true) { - updateCurrentBlock() - val currentTimeMillis = System.currentTimeMillis - val sleepTimeMillis = (currentTimeMillis / shortIntervalDurationMillis + 1) * - shortIntervalDurationMillis - currentTimeMillis + 1 - Thread.sleep(sleepTimeMillis) - } - } - - def keepPushingBlocks() { - var loop = true - logInfo("Thread to push blocks started") - while(loop) { - val block = blocksForPushing.synchronized { - if (blocksForPushing.size == 0) { - blocksForPushing.wait() - } - blocksForPushing.dequeue - } - pushBlock(block) - block.pushed = true - block.data.clear() - - val bucket = buckets(block.longInterval) - bucket.synchronized { - if (bucket.ready) { - bucket.notifyAll() - } - } - } - } - } - - - class ConnectionListener(port: Int, dataHandler: DataHandler) - extends Thread with Logging { - initLogging() - override def run { - try { - val listener = new ServerSocket(port) - logInfo("Listening on port " + port) - while (true) { - new ConnectionHandler(listener.accept(), dataHandler).start(); - } - listener.close() - } catch { - case e: Exception => logError("", e); - } - } - } - - class ConnectionHandler(socket: Socket, dataHandler: DataHandler) extends Thread with Logging { - initLogging() - override def run { - logInfo("New connection from " + socket.getInetAddress() + ":" + socket.getPort) - val bytes = new Array[Byte](100 * 1024 * 1024) - try { - - val inputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream, 1024 * 1024)) - /*val inputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream))*/ - var str: String = null - str = inputStream.readUTF - while(str != null) { - dataHandler += str - str = inputStream.readUTF() - } - - /* - var loop = true - while(loop) { - val numRead = inputStream.read(bytes) - if (numRead < 0) { - loop = false - } - inbox += ((LongTime(SystemTime.currentTimeMillis), "test")) - }*/ - - inputStream.close() - } catch { - case e => logError("Error receiving data", e) - } - socket.close() - } - } - - initLogging() - - val masterHost = System.getProperty("spark.master.host") - val masterPort = System.getProperty("spark.master.port").toInt - - val akkaPath = "akka://spark@%s:%s/user/".format(masterHost, masterPort) - val sparkstreamScheduler = actorSystem.actorFor(akkaPath + "/SparkStreamScheduler") - val testStreamCoordinator = actorSystem.actorFor(akkaPath + "/TestStreamCoordinator") - - logInfo("Getting stream details from master " + masterHost + ":" + masterPort) - - val timeout = 50 millis - - var started = false - while (!started) { - askActor[String](testStreamCoordinator, TestStarted) match { - case Some(str) => { - started = true - logInfo("TestStreamCoordinator started") - } - case None => { - logInfo("TestStreamCoordinator not started yet") - Thread.sleep(200) - } - } - } - - val streamDetails = askActor[GotStreamDetails](testStreamCoordinator, GetStreamDetails) match { - case Some(details) => details - case None => throw new Exception("Could not get stream details") - } - logInfo("Stream details received: " + streamDetails) - - val inputName = streamDetails.name - val intervalDurationMillis = streamDetails.duration - val intervalDuration = Time(intervalDurationMillis) - - val dataHandler = new DataHandler( - inputName, - intervalDuration, - Time(TestStreamReceiver3.SHORT_INTERVAL_MILLIS), - blockManager) - - val connListener = new ConnectionListener(TestStreamReceiver3.PORT, dataHandler) - - // Send a message to an actor and return an option with its reply, or None if this times out - def askActor[T](actor: ActorRef, message: Any): Option[T] = { - try { - val future = actor.ask(message)(timeout) - return Some(Await.result(future, timeout).asInstanceOf[T]) - } catch { - case e: Exception => - logInfo("Error communicating with " + actor, e) - return None - } - } - - override def run() { - connListener.start() - dataHandler.start() - - var interval = Interval.currentInterval(intervalDuration) - var dataStarted = false - - while(true) { - waitFor(interval.endTime) - logInfo("Woken up at " + System.currentTimeMillis + " for " + interval) - dataHandler.getBucket(interval) match { - case Some(bucket) => { - logInfo("Found " + bucket + " for " + interval) - bucket.synchronized { - if (!bucket.ready) { - logInfo("Waiting for " + bucket) - bucket.wait() - logInfo("Wait over for " + bucket) - } - if (dataStarted || !bucket.empty) { - logInfo("Notifying " + bucket) - notifyScheduler(interval, bucket.blockIds) - dataStarted = true - } - bucket.blocks.clear() - dataHandler.clearBucket(interval) - } - } - case None => { - logInfo("Found none for " + interval) - if (dataStarted) { - logInfo("Notifying none") - notifyScheduler(interval, Array[String]()) - } - } - } - interval = interval.next - } - } - - def waitFor(time: Time) { - val currentTimeMillis = System.currentTimeMillis - val targetTimeMillis = time.milliseconds - if (currentTimeMillis < targetTimeMillis) { - val sleepTime = (targetTimeMillis - currentTimeMillis) - Thread.sleep(sleepTime + 1) - } - } - - def notifyScheduler(interval: Interval, blockIds: Array[String]) { - try { - sparkstreamScheduler ! InputGenerated(inputName, interval, blockIds.toArray) - val time = interval.endTime - val delay = (System.currentTimeMillis - time.milliseconds) / 1000.0 - logInfo("Pushing delay for " + time + " is " + delay + " s") - } catch { - case _ => logError("Exception notifying scheduler at interval " + interval) - } - } -} - -object TestStreamReceiver3 { - - val PORT = 9999 - val SHORT_INTERVAL_MILLIS = 100 - - def main(args: Array[String]) { - System.setProperty("spark.master.host", Utils.localHostName) - System.setProperty("spark.master.port", "7078") - val details = Array(("Sentences", 2000L)) - val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localHostName, 7078) - actorSystem.actorOf(Props(new TestStreamCoordinator(details)), name = "TestStreamCoordinator") - new TestStreamReceiver3(actorSystem, null).start() - } -} - - - diff --git a/streaming/src/main/scala/spark/streaming/TestStreamReceiver4.scala b/streaming/src/main/scala/spark/streaming/TestStreamReceiver4.scala deleted file mode 100644 index a2babb23f4..0000000000 --- a/streaming/src/main/scala/spark/streaming/TestStreamReceiver4.scala +++ /dev/null @@ -1,373 +0,0 @@ -package spark.streaming - -import spark._ -import spark.storage._ -import spark.util.AkkaUtils - -import scala.math._ -import scala.collection.mutable.{Queue, HashMap, ArrayBuffer, SynchronizedMap} - -import java.io._ -import java.nio._ -import java.nio.charset._ -import java.nio.channels._ -import java.util.concurrent.Executors - -import akka.actor._ -import akka.actor.Actor -import akka.dispatch._ -import akka.pattern.ask -import akka.util.duration._ - -class TestStreamReceiver4(actorSystem: ActorSystem, blockManager: BlockManager) -extends Thread with Logging { - - class DataHandler( - inputName: String, - longIntervalDuration: Time, - shortIntervalDuration: Time, - blockManager: BlockManager - ) - extends Logging { - - class Block(val id: String, val shortInterval: Interval, val buffer: ByteBuffer) { - var pushed = false - def longInterval = getLongInterval(shortInterval) - override def toString() = "Block " + id - } - - class Bucket(val longInterval: Interval) { - val blocks = new ArrayBuffer[Block]() - var filled = false - def += (block: Block) = blocks += block - def empty() = (blocks.size == 0) - def ready() = (filled && !blocks.exists(! _.pushed)) - def blockIds() = blocks.map(_.id).toArray - override def toString() = "Bucket [" + longInterval + ", " + blocks.size + " blocks]" - } - - initLogging() - - val syncOnLastShortInterval = true - - val shortIntervalDurationMillis = shortIntervalDuration.milliseconds - val longIntervalDurationMillis = longIntervalDuration.milliseconds - - val buffer = ByteBuffer.allocateDirect(100 * 1024 * 1024) - var currentShortInterval = Interval.currentInterval(shortIntervalDuration) - - val blocksForPushing = new Queue[Block]() - val buckets = new HashMap[Interval, Bucket]() with SynchronizedMap[Interval, Bucket] - - val bufferProcessingThread = new Thread() { override def run() { keepProcessingBuffers() } } - val blockPushingExecutor = Executors.newFixedThreadPool(5) - - - def start() { - buffer.clear() - if (buffer.remaining == 0) { - throw new Exception("Buffer initialization error") - } - bufferProcessingThread.start() - } - - def readDataToBuffer(func: ByteBuffer => Int): Int = { - buffer.synchronized { - if (buffer.remaining == 0) { - logInfo("Received first data for interval " + currentShortInterval) - } - func(buffer) - } - } - - def getLongInterval(shortInterval: Interval): Interval = { - val intervalBegin = shortInterval.beginTime.floor(longIntervalDuration) - Interval(intervalBegin, intervalBegin + longIntervalDuration) - } - - def processBuffer() { - - def readInt(buffer: ByteBuffer): Int = { - var offset = 0 - var result = 0 - while (offset < 32) { - val b = buffer.get() - result |= ((b & 0x7F) << offset) - if ((b & 0x80) == 0) { - return result - } - offset += 7 - } - throw new Exception("Malformed zigzag-encoded integer") - } - - val currentLongInterval = getLongInterval(currentShortInterval) - val startTime = System.currentTimeMillis - val newBuffer: ByteBuffer = buffer.synchronized { - buffer.flip() - if (buffer.remaining == 0) { - buffer.clear() - null - } else { - logDebug("Processing interval " + currentShortInterval + " with delay of " + (System.currentTimeMillis - startTime) + " ms") - val startTime1 = System.currentTimeMillis - var loop = true - var count = 0 - while(loop) { - buffer.mark() - try { - val len = readInt(buffer) - buffer.position(buffer.position + len) - count += 1 - } catch { - case e: Exception => { - buffer.reset() - loop = false - } - } - } - val bytesToCopy = buffer.position - val newBuf = ByteBuffer.allocate(bytesToCopy) - buffer.position(0) - newBuf.put(buffer.slice().limit(bytesToCopy).asInstanceOf[ByteBuffer]) - newBuf.flip() - buffer.position(bytesToCopy) - buffer.compact() - newBuf - } - } - - if (newBuffer != null) { - val bucket = buckets.getOrElseUpdate(currentLongInterval, new Bucket(currentLongInterval)) - bucket.synchronized { - val newBlockId = inputName + "-" + currentLongInterval.toFormattedString + "-" + currentShortInterval.toFormattedString - val newBlock = new Block(newBlockId, currentShortInterval, newBuffer) - if (syncOnLastShortInterval) { - bucket += newBlock - } - logDebug("Created " + newBlock + " with " + newBuffer.remaining + " bytes, creation delay is " + (System.currentTimeMillis - currentShortInterval.endTime.milliseconds) / 1000.0 + " s" ) - blockPushingExecutor.execute(new Runnable() { def run() { pushAndNotifyBlock(newBlock) } }) - } - } - - val newShortInterval = Interval.currentInterval(shortIntervalDuration) - val newLongInterval = getLongInterval(newShortInterval) - - if (newLongInterval != currentLongInterval) { - buckets.get(currentLongInterval) match { - case Some(bucket) => { - bucket.synchronized { - bucket.filled = true - if (bucket.ready) { - bucket.notifyAll() - } - } - } - case None => - } - buckets += ((newLongInterval, new Bucket(newLongInterval))) - } - - currentShortInterval = newShortInterval - } - - def pushBlock(block: Block) { - try{ - if (blockManager != null) { - val startTime = System.currentTimeMillis - logInfo(block + " put start delay is " + (startTime - block.shortInterval.endTime.milliseconds) + " ms") - /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY)*/ - /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_2)*/ - blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.MEMORY_ONLY_2) - /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.MEMORY_ONLY)*/ - /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_DESER)*/ - /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_DESER_2)*/ - val finishTime = System.currentTimeMillis - logInfo(block + " put delay is " + (finishTime - startTime) + " ms") - } else { - logWarning(block + " not put as block manager is null") - } - } catch { - case e: Exception => logError("Exception writing " + block + " to blockmanager" , e) - } - } - - def getBucket(longInterval: Interval): Option[Bucket] = { - buckets.get(longInterval) - } - - def clearBucket(longInterval: Interval) { - buckets.remove(longInterval) - } - - def keepProcessingBuffers() { - logInfo("Thread to process buffers started") - while(true) { - processBuffer() - val currentTimeMillis = System.currentTimeMillis - val sleepTimeMillis = (currentTimeMillis / shortIntervalDurationMillis + 1) * - shortIntervalDurationMillis - currentTimeMillis + 1 - Thread.sleep(sleepTimeMillis) - } - } - - def pushAndNotifyBlock(block: Block) { - pushBlock(block) - block.pushed = true - val bucket = if (syncOnLastShortInterval) { - buckets(block.longInterval) - } else { - var longInterval = block.longInterval - while(!buckets.contains(longInterval)) { - logWarning("Skipping bucket of " + longInterval + " for " + block) - longInterval = longInterval.next - } - val chosenBucket = buckets(longInterval) - logDebug("Choosing bucket of " + longInterval + " for " + block) - chosenBucket += block - chosenBucket - } - - bucket.synchronized { - if (bucket.ready) { - bucket.notifyAll() - } - } - - } - } - - - class ReceivingConnectionHandler(host: String, port: Int, dataHandler: DataHandler) - extends ConnectionHandler(host, port, false) { - - override def ready(key: SelectionKey) { - changeInterest(key, SelectionKey.OP_READ) - } - - override def read(key: SelectionKey) { - try { - val channel = key.channel.asInstanceOf[SocketChannel] - val bytesRead = dataHandler.readDataToBuffer(channel.read) - if (bytesRead < 0) { - close(key) - } - } catch { - case e: IOException => { - logError("Error reading", e) - close(key) - } - } - } - } - - initLogging() - - val masterHost = System.getProperty("spark.master.host", "localhost") - val masterPort = System.getProperty("spark.master.port", "7078").toInt - - val akkaPath = "akka://spark@%s:%s/user/".format(masterHost, masterPort) - val sparkstreamScheduler = actorSystem.actorFor(akkaPath + "/SparkStreamScheduler") - val testStreamCoordinator = actorSystem.actorFor(akkaPath + "/TestStreamCoordinator") - - logInfo("Getting stream details from master " + masterHost + ":" + masterPort) - - val streamDetails = askActor[GotStreamDetails](testStreamCoordinator, GetStreamDetails) match { - case Some(details) => details - case None => throw new Exception("Could not get stream details") - } - logInfo("Stream details received: " + streamDetails) - - val inputName = streamDetails.name - val intervalDurationMillis = streamDetails.duration - val intervalDuration = Milliseconds(intervalDurationMillis) - val shortIntervalDuration = Milliseconds(System.getProperty("spark.stream.shortinterval", "500").toInt) - - val dataHandler = new DataHandler(inputName, intervalDuration, shortIntervalDuration, blockManager) - val connectionHandler = new ReceivingConnectionHandler("localhost", 9999, dataHandler) - - val timeout = 100 millis - - // Send a message to an actor and return an option with its reply, or None if this times out - def askActor[T](actor: ActorRef, message: Any): Option[T] = { - try { - val future = actor.ask(message)(timeout) - return Some(Await.result(future, timeout).asInstanceOf[T]) - } catch { - case e: Exception => - logInfo("Error communicating with " + actor, e) - return None - } - } - - override def run() { - connectionHandler.start() - dataHandler.start() - - var interval = Interval.currentInterval(intervalDuration) - var dataStarted = false - - - while(true) { - waitFor(interval.endTime) - /*logInfo("Woken up at " + System.currentTimeMillis + " for " + interval)*/ - dataHandler.getBucket(interval) match { - case Some(bucket) => { - logDebug("Found " + bucket + " for " + interval) - bucket.synchronized { - if (!bucket.ready) { - logDebug("Waiting for " + bucket) - bucket.wait() - logDebug("Wait over for " + bucket) - } - if (dataStarted || !bucket.empty) { - logDebug("Notifying " + bucket) - notifyScheduler(interval, bucket.blockIds) - dataStarted = true - } - bucket.blocks.clear() - dataHandler.clearBucket(interval) - } - } - case None => { - logDebug("Found none for " + interval) - if (dataStarted) { - logDebug("Notifying none") - notifyScheduler(interval, Array[String]()) - } - } - } - interval = interval.next - } - } - - def waitFor(time: Time) { - val currentTimeMillis = System.currentTimeMillis - val targetTimeMillis = time.milliseconds - if (currentTimeMillis < targetTimeMillis) { - val sleepTime = (targetTimeMillis - currentTimeMillis) - Thread.sleep(sleepTime + 1) - } - } - - def notifyScheduler(interval: Interval, blockIds: Array[String]) { - try { - sparkstreamScheduler ! InputGenerated(inputName, interval, blockIds.toArray) - val time = interval.endTime - val delay = (System.currentTimeMillis - time.milliseconds) - logInfo("Notification delay for " + time + " is " + delay + " ms") - } catch { - case e: Exception => logError("Exception notifying scheduler at interval " + interval + ": " + e) - } - } -} - - -object TestStreamReceiver4 { - def main(args: Array[String]) { - val details = Array(("Sentences", 2000L)) - val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localHostName, 7078) - actorSystem.actorOf(Props(new TestStreamCoordinator(details)), name = "TestStreamCoordinator") - new TestStreamReceiver4(actorSystem, null).start() - } -} diff --git a/streaming/src/main/scala/spark/streaming/util/ConnectionHandler.scala b/streaming/src/main/scala/spark/streaming/util/ConnectionHandler.scala new file mode 100644 index 0000000000..cde868a0c9 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/ConnectionHandler.scala @@ -0,0 +1,157 @@ +package spark.streaming.util + +import spark.Logging + +import scala.collection.mutable.{ArrayBuffer, SynchronizedQueue} + +import java.net._ +import java.io._ +import java.nio._ +import java.nio.charset._ +import java.nio.channels._ +import java.nio.channels.spi._ + +abstract class ConnectionHandler(host: String, port: Int, connect: Boolean) +extends Thread with Logging { + + val selector = SelectorProvider.provider.openSelector() + val interestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] + + initLogging() + + override def run() { + try { + if (connect) { + connect() + } else { + listen() + } + + var interrupted = false + while(!interrupted) { + + preSelect() + + while(!interestChangeRequests.isEmpty) { + val (key, ops) = interestChangeRequests.dequeue + val lastOps = key.interestOps() + key.interestOps(ops) + + def intToOpStr(op: Int): String = { + val opStrs = new ArrayBuffer[String]() + if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" + if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" + if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" + if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" + if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " + } + + logTrace("Changed ops from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") + } + + selector.select() + interrupted = Thread.currentThread.isInterrupted + + val selectedKeys = selector.selectedKeys().iterator() + while (selectedKeys.hasNext) { + val key = selectedKeys.next.asInstanceOf[SelectionKey] + selectedKeys.remove() + if (key.isValid) { + if (key.isAcceptable) { + accept(key) + } else if (key.isConnectable) { + finishConnect(key) + } else if (key.isReadable) { + read(key) + } else if (key.isWritable) { + write(key) + } + } + } + } + } catch { + case e: Exception => { + logError("Error in select loop", e) + } + } + } + + def connect() { + val socketAddress = new InetSocketAddress(host, port) + val channel = SocketChannel.open() + channel.configureBlocking(false) + channel.socket.setReuseAddress(true) + channel.socket.setTcpNoDelay(true) + channel.connect(socketAddress) + channel.register(selector, SelectionKey.OP_CONNECT) + logInfo("Initiating connection to [" + socketAddress + "]") + } + + def listen() { + val channel = ServerSocketChannel.open() + channel.configureBlocking(false) + channel.socket.setReuseAddress(true) + channel.socket.setReceiveBufferSize(256 * 1024) + channel.socket.bind(new InetSocketAddress(port)) + channel.register(selector, SelectionKey.OP_ACCEPT) + logInfo("Listening on port " + port) + } + + def finishConnect(key: SelectionKey) { + try { + val channel = key.channel.asInstanceOf[SocketChannel] + val address = channel.socket.getRemoteSocketAddress + channel.finishConnect() + logInfo("Connected to [" + host + ":" + port + "]") + ready(key) + } catch { + case e: IOException => { + logError("Error finishing connect to " + host + ":" + port) + close(key) + } + } + } + + def accept(key: SelectionKey) { + try { + val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] + val channel = serverChannel.accept() + val address = channel.socket.getRemoteSocketAddress + channel.configureBlocking(false) + logInfo("Accepted connection from [" + address + "]") + ready(channel.register(selector, 0)) + } catch { + case e: IOException => { + logError("Error accepting connection", e) + } + } + } + + def changeInterest(key: SelectionKey, ops: Int) { + logTrace("Added request to change ops to " + ops) + interestChangeRequests += ((key, ops)) + } + + def ready(key: SelectionKey) + + def preSelect() { + } + + def read(key: SelectionKey) { + throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString) + } + + def write(key: SelectionKey) { + throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString) + } + + def close(key: SelectionKey) { + try { + key.channel.close() + key.cancel() + Thread.currentThread.interrupt + } catch { + case e: Exception => logError("Error closing connection", e) + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/util/TestGenerator.scala b/streaming/src/main/scala/spark/streaming/util/TestGenerator.scala new file mode 100644 index 0000000000..23e9235c60 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/TestGenerator.scala @@ -0,0 +1,107 @@ +package spark.streaming.util + +import scala.util.Random +import scala.io.Source +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ +import scala.actors.remote.RemoteActor._ + +import java.net.InetSocketAddress + + +object TestGenerator { + + def printUsage { + println ("Usage: SentenceGenerator []") + System.exit(0) + } + /* + def generateRandomSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { + val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 + val random = new Random () + + try { + var lastPrintTime = System.currentTimeMillis() + var count = 0 + while(true) { + streamReceiver ! lines(random.nextInt(lines.length)) + count += 1 + if (System.currentTimeMillis - lastPrintTime >= 1000) { + println (count + " sentences sent last second") + count = 0 + lastPrintTime = System.currentTimeMillis + } + Thread.sleep(sleepBetweenSentences.toLong) + } + } catch { + case e: Exception => + } + }*/ + + def generateSameSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { + try { + val numSentences = if (sentencesPerSecond <= 0) { + lines.length + } else { + sentencesPerSecond + } + val sentences = lines.take(numSentences).toArray + + var nextSendingTime = System.currentTimeMillis() + val sendAsArray = true + while(true) { + if (sendAsArray) { + println("Sending as array") + streamReceiver !? sentences + } else { + println("Sending individually") + sentences.foreach(sentence => { + streamReceiver !? sentence + }) + } + println ("Sent " + numSentences + " sentences in " + (System.currentTimeMillis - nextSendingTime) + " ms") + nextSendingTime += 1000 + val sleepTime = nextSendingTime - System.currentTimeMillis + if (sleepTime > 0) { + println ("Sleeping for " + sleepTime + " ms") + Thread.sleep(sleepTime) + } + } + } catch { + case e: Exception => + } + } + + def main(args: Array[String]) { + if (args.length < 3) { + printUsage + } + + val generateRandomly = false + + val streamReceiverIP = args(0) + val streamReceiverPort = args(1).toInt + val sentenceFile = args(2) + val sentencesPerSecond = if (args.length > 3) args(3).toInt else 10 + val sentenceInputName = if (args.length > 4) args(4) else "Sentences" + + println("Sending " + sentencesPerSecond + " sentences per second to " + + streamReceiverIP + ":" + streamReceiverPort + "/NetworkStreamReceiver-" + sentenceInputName) + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n") + source.close () + + val streamReceiver = select( + Node(streamReceiverIP, streamReceiverPort), + Symbol("NetworkStreamReceiver-" + sentenceInputName)) + if (generateRandomly) { + /*generateRandomSentences(lines, sentencesPerSecond, streamReceiver)*/ + } else { + generateSameSentences(lines, sentencesPerSecond, streamReceiver) + } + } +} + + + diff --git a/streaming/src/main/scala/spark/streaming/util/TestGenerator2.scala b/streaming/src/main/scala/spark/streaming/util/TestGenerator2.scala new file mode 100644 index 0000000000..ff840d084f --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/TestGenerator2.scala @@ -0,0 +1,119 @@ +package spark.streaming.util + +import scala.util.Random +import scala.io.Source +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ +import scala.actors.remote.RemoteActor._ + +import java.io.{DataOutputStream, ByteArrayOutputStream, DataInputStream} +import java.net.Socket + +object TestGenerator2 { + + def printUsage { + println ("Usage: SentenceGenerator []") + System.exit(0) + } + + def sendSentences(streamReceiverHost: String, streamReceiverPort: Int, numSentences: Int, bytes: Array[Byte], intervalTime: Long){ + try { + println("Connecting to " + streamReceiverHost + ":" + streamReceiverPort) + val socket = new Socket(streamReceiverHost, streamReceiverPort) + + println("Sending " + numSentences+ " sentences / " + (bytes.length / 1024.0 / 1024.0) + " MB per " + intervalTime + " ms to " + streamReceiverHost + ":" + streamReceiverPort ) + val currentTime = System.currentTimeMillis + var targetTime = (currentTime / intervalTime + 1).toLong * intervalTime + Thread.sleep(targetTime - currentTime) + + while(true) { + val startTime = System.currentTimeMillis() + println("Sending at " + startTime + " ms with delay of " + (startTime - targetTime) + " ms") + val socketOutputStream = socket.getOutputStream + val parts = 10 + (0 until parts).foreach(i => { + val partStartTime = System.currentTimeMillis + + val offset = (i * bytes.length / parts).toInt + val len = math.min(((i + 1) * bytes.length / parts).toInt - offset, bytes.length) + socketOutputStream.write(bytes, offset, len) + socketOutputStream.flush() + val partFinishTime = System.currentTimeMillis + println("Sending part " + i + " of " + len + " bytes took " + (partFinishTime - partStartTime) + " ms") + val sleepTime = math.max(0, 1000 / parts - (partFinishTime - partStartTime) - 1) + Thread.sleep(sleepTime) + }) + + socketOutputStream.flush() + /*val socketInputStream = new DataInputStream(socket.getInputStream)*/ + /*val reply = socketInputStream.readUTF()*/ + val finishTime = System.currentTimeMillis() + println ("Sent " + bytes.length + " bytes in " + (finishTime - startTime) + " ms for interval [" + targetTime + ", " + (targetTime + intervalTime) + "]") + /*println("Received = " + reply)*/ + targetTime = targetTime + intervalTime + val sleepTime = (targetTime - finishTime) + 10 + if (sleepTime > 0) { + println("Sleeping for " + sleepTime + " ms") + Thread.sleep(sleepTime) + } else { + println("############################") + println("###### Skipping sleep ######") + println("############################") + } + } + } catch { + case e: Exception => println(e) + } + println("Stopped sending") + } + + def main(args: Array[String]) { + if (args.length < 4) { + printUsage + } + + val streamReceiverHost = args(0) + val streamReceiverPort = args(1).toInt + val sentenceFile = args(2) + val intervalTime = args(3).toLong + val sentencesPerInterval = if (args.length > 4) args(4).toInt else 0 + + println("Reading the file " + sentenceFile) + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n") + source.close() + + val numSentences = if (sentencesPerInterval <= 0) { + lines.length + } else { + sentencesPerInterval + } + + println("Generating sentences") + val sentences: Array[String] = if (numSentences <= lines.length) { + lines.take(numSentences).toArray + } else { + (0 until numSentences).map(i => lines(i % lines.length)).toArray + } + + println("Converting to byte array") + val byteStream = new ByteArrayOutputStream() + val stringDataStream = new DataOutputStream(byteStream) + /*stringDataStream.writeInt(sentences.size)*/ + sentences.foreach(stringDataStream.writeUTF) + val bytes = byteStream.toByteArray() + stringDataStream.close() + println("Generated array of " + bytes.length + " bytes") + + /*while(true) { */ + sendSentences(streamReceiverHost, streamReceiverPort, numSentences, bytes, intervalTime) + /*println("Sleeping for 5 seconds")*/ + /*Thread.sleep(5000)*/ + /*System.gc()*/ + /*}*/ + } +} + + + diff --git a/streaming/src/main/scala/spark/streaming/util/TestGenerator4.scala b/streaming/src/main/scala/spark/streaming/util/TestGenerator4.scala new file mode 100644 index 0000000000..9c39ef3e12 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/TestGenerator4.scala @@ -0,0 +1,244 @@ +package spark.streaming.util + +import spark.Logging + +import scala.util.Random +import scala.io.Source +import scala.collection.mutable.{ArrayBuffer, Queue} + +import java.net._ +import java.io._ +import java.nio._ +import java.nio.charset._ +import java.nio.channels._ + +import it.unimi.dsi.fastutil.io._ + +class TestGenerator4(targetHost: String, targetPort: Int, sentenceFile: String, intervalDuration: Long, sentencesPerInterval: Int) +extends Logging { + + class SendingConnectionHandler(host: String, port: Int, generator: TestGenerator4) + extends ConnectionHandler(host, port, true) { + + val buffers = new ArrayBuffer[ByteBuffer] + val newBuffers = new Queue[ByteBuffer] + var activeKey: SelectionKey = null + + def send(buffer: ByteBuffer) { + logDebug("Sending: " + buffer) + newBuffers.synchronized { + newBuffers.enqueue(buffer) + } + selector.wakeup() + buffer.synchronized { + buffer.wait() + } + } + + override def ready(key: SelectionKey) { + logDebug("Ready") + activeKey = key + val channel = key.channel.asInstanceOf[SocketChannel] + channel.register(selector, SelectionKey.OP_WRITE) + generator.startSending() + } + + override def preSelect() { + newBuffers.synchronized { + while(!newBuffers.isEmpty) { + val buffer = newBuffers.dequeue + buffers += buffer + logDebug("Added: " + buffer) + changeInterest(activeKey, SelectionKey.OP_WRITE) + } + } + } + + override def write(key: SelectionKey) { + try { + /*while(true) {*/ + val channel = key.channel.asInstanceOf[SocketChannel] + if (buffers.size > 0) { + val buffer = buffers(0) + val newBuffer = buffer.slice() + newBuffer.limit(math.min(newBuffer.remaining, 32768)) + val bytesWritten = channel.write(newBuffer) + buffer.position(buffer.position + bytesWritten) + if (bytesWritten == 0) return + if (buffer.remaining == 0) { + buffers -= buffer + buffer.synchronized { + buffer.notify() + } + } + /*changeInterest(key, SelectionKey.OP_WRITE)*/ + } else { + changeInterest(key, 0) + } + /*}*/ + } catch { + case e: IOException => { + if (e.toString.contains("pipe") || e.toString.contains("reset")) { + logError("Connection broken") + } else { + logError("Connection error", e) + } + close(key) + } + } + } + + override def close(key: SelectionKey) { + buffers.clear() + super.close(key) + } + } + + initLogging() + + val connectionHandler = new SendingConnectionHandler(targetHost, targetPort, this) + var sendingThread: Thread = null + var sendCount = 0 + val sendBatches = 5 + + def run() { + logInfo("Connection handler started") + connectionHandler.start() + connectionHandler.join() + if (sendingThread != null && !sendingThread.isInterrupted) { + sendingThread.interrupt + } + logInfo("Connection handler stopped") + } + + def startSending() { + sendingThread = new Thread() { + override def run() { + logInfo("STARTING TO SEND") + sendSentences() + logInfo("SENDING STOPPED AFTER " + sendCount) + connectionHandler.interrupt() + } + } + sendingThread.start() + } + + def stopSending() { + sendingThread.interrupt() + } + + def sendSentences() { + logInfo("Reading the file " + sentenceFile) + val source = Source.fromFile(sentenceFile) + val lines = source.mkString.split ("\n") + source.close() + + val numSentences = if (sentencesPerInterval <= 0) { + lines.length + } else { + sentencesPerInterval + } + + logInfo("Generating sentence buffer") + val sentences: Array[String] = if (numSentences <= lines.length) { + lines.take(numSentences).toArray + } else { + (0 until numSentences).map(i => lines(i % lines.length)).toArray + } + + /* + val sentences: Array[String] = if (numSentences <= lines.length) { + lines.take((numSentences / sendBatches).toInt).toArray + } else { + (0 until (numSentences/sendBatches)).map(i => lines(i % lines.length)).toArray + }*/ + + + val serializer = new spark.KryoSerializer().newInstance() + val byteStream = new FastByteArrayOutputStream(100 * 1024 * 1024) + serializer.serializeStream(byteStream).writeAll(sentences.toIterator.asInstanceOf[Iterator[Any]]).close() + byteStream.trim() + val sentenceBuffer = ByteBuffer.wrap(byteStream.array) + + logInfo("Sending " + numSentences+ " sentences / " + sentenceBuffer.limit + " bytes per " + intervalDuration + " ms to " + targetHost + ":" + targetPort ) + val currentTime = System.currentTimeMillis + var targetTime = (currentTime / intervalDuration + 1).toLong * intervalDuration + Thread.sleep(targetTime - currentTime) + + val totalBytes = sentenceBuffer.limit + + while(true) { + val batchesInCurrentInterval = sendBatches // if (sendCount < 10) 1 else sendBatches + + val startTime = System.currentTimeMillis() + logDebug("Sending # " + sendCount + " at " + startTime + " ms with delay of " + (startTime - targetTime) + " ms") + + (0 until batchesInCurrentInterval).foreach(i => { + try { + val position = (i * totalBytes / sendBatches).toInt + val limit = if (i == sendBatches - 1) { + totalBytes + } else { + ((i + 1) * totalBytes / sendBatches).toInt - 1 + } + + val partStartTime = System.currentTimeMillis + sentenceBuffer.limit(limit) + connectionHandler.send(sentenceBuffer) + val partFinishTime = System.currentTimeMillis + val sleepTime = math.max(0, intervalDuration / sendBatches - (partFinishTime - partStartTime) - 1) + Thread.sleep(sleepTime) + + } catch { + case ie: InterruptedException => return + case e: Exception => e.printStackTrace() + } + }) + sentenceBuffer.rewind() + + val finishTime = System.currentTimeMillis() + /*logInfo ("Sent " + sentenceBuffer.limit + " bytes in " + (finishTime - startTime) + " ms")*/ + targetTime = targetTime + intervalDuration //+ (if (sendCount < 3) 1000 else 0) + + val sleepTime = (targetTime - finishTime) + 20 + if (sleepTime > 0) { + logInfo("Sleeping for " + sleepTime + " ms") + Thread.sleep(sleepTime) + } else { + logInfo("###### Skipping sleep ######") + } + if (Thread.currentThread.isInterrupted) { + return + } + sendCount += 1 + } + } +} + +object TestGenerator4 { + def printUsage { + println("Usage: TestGenerator4 []") + System.exit(0) + } + + def main(args: Array[String]) { + println("GENERATOR STARTED") + if (args.length < 4) { + printUsage + } + + + val streamReceiverHost = args(0) + val streamReceiverPort = args(1).toInt + val sentenceFile = args(2) + val intervalDuration = args(3).toLong + val sentencesPerInterval = if (args.length > 4) args(4).toInt else 0 + + while(true) { + val generator = new TestGenerator4(streamReceiverHost, streamReceiverPort, sentenceFile, intervalDuration, sentencesPerInterval) + generator.run() + Thread.sleep(2000) + } + println("GENERATOR STOPPED") + } +} diff --git a/streaming/src/main/scala/spark/streaming/util/TestStreamCoordinator.scala b/streaming/src/main/scala/spark/streaming/util/TestStreamCoordinator.scala new file mode 100644 index 0000000000..f584f772bb --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/TestStreamCoordinator.scala @@ -0,0 +1,39 @@ +package spark.streaming.util + +import spark.streaming._ +import spark.Logging + +import akka.actor._ +import akka.actor.Actor +import akka.actor.Actor._ + +sealed trait TestStreamCoordinatorMessage +case class GetStreamDetails extends TestStreamCoordinatorMessage +case class GotStreamDetails(name: String, duration: Long) extends TestStreamCoordinatorMessage +case class TestStarted extends TestStreamCoordinatorMessage + +class TestStreamCoordinator(streamDetails: Array[(String, Long)]) extends Actor with Logging { + + var index = 0 + + initLogging() + + logInfo("Created") + + def receive = { + case TestStarted => { + sender ! "OK" + } + + case GetStreamDetails => { + val streamDetail = if (index >= streamDetails.length) null else streamDetails(index) + sender ! GotStreamDetails(streamDetail._1, streamDetail._2) + index += 1 + if (streamDetail != null) { + logInfo("Allocated " + streamDetail._1 + " (" + index + "/" + streamDetails.length + ")" ) + } + } + } + +} + diff --git a/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver3.scala b/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver3.scala new file mode 100644 index 0000000000..d00ae9cbca --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver3.scala @@ -0,0 +1,421 @@ +package spark.streaming.util + +import spark._ +import spark.storage._ +import spark.util.AkkaUtils +import spark.streaming._ + +import scala.math._ +import scala.collection.mutable.{Queue, HashMap, ArrayBuffer, SynchronizedMap} + +import akka.actor._ +import akka.actor.Actor +import akka.dispatch._ +import akka.pattern.ask +import akka.util.duration._ + +import java.io.DataInputStream +import java.io.BufferedInputStream +import java.net.Socket +import java.net.ServerSocket +import java.util.LinkedHashMap + +import org.apache.hadoop.fs._ +import org.apache.hadoop.conf._ +import org.apache.hadoop.io._ +import org.apache.hadoop.mapred._ +import org.apache.hadoop.util._ + +import spark.Utils + + +class TestStreamReceiver3(actorSystem: ActorSystem, blockManager: BlockManager) +extends Thread with Logging { + + + class DataHandler( + inputName: String, + longIntervalDuration: Time, + shortIntervalDuration: Time, + blockManager: BlockManager + ) + extends Logging { + + class Block(var id: String, var shortInterval: Interval) { + val data = ArrayBuffer[String]() + var pushed = false + def longInterval = getLongInterval(shortInterval) + def empty() = (data.size == 0) + def += (str: String) = (data += str) + override def toString() = "Block " + id + } + + class Bucket(val longInterval: Interval) { + val blocks = new ArrayBuffer[Block]() + var filled = false + def += (block: Block) = blocks += block + def empty() = (blocks.size == 0) + def ready() = (filled && !blocks.exists(! _.pushed)) + def blockIds() = blocks.map(_.id).toArray + override def toString() = "Bucket [" + longInterval + ", " + blocks.size + " blocks]" + } + + initLogging() + + val shortIntervalDurationMillis = shortIntervalDuration.toLong + val longIntervalDurationMillis = longIntervalDuration.toLong + + var currentBlock: Block = null + var currentBucket: Bucket = null + + val blocksForPushing = new Queue[Block]() + val buckets = new HashMap[Interval, Bucket]() with SynchronizedMap[Interval, Bucket] + + val blockUpdatingThread = new Thread() { override def run() { keepUpdatingCurrentBlock() } } + val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } + + def start() { + blockUpdatingThread.start() + blockPushingThread.start() + } + + def += (data: String) = addData(data) + + def addData(data: String) { + if (currentBlock == null) { + updateCurrentBlock() + } + currentBlock.synchronized { + currentBlock += data + } + } + + def getShortInterval(time: Time): Interval = { + val intervalBegin = time.floor(shortIntervalDuration) + Interval(intervalBegin, intervalBegin + shortIntervalDuration) + } + + def getLongInterval(shortInterval: Interval): Interval = { + val intervalBegin = shortInterval.beginTime.floor(longIntervalDuration) + Interval(intervalBegin, intervalBegin + longIntervalDuration) + } + + def updateCurrentBlock() { + /*logInfo("Updating current block")*/ + val currentTime = Time(System.currentTimeMillis) + val shortInterval = getShortInterval(currentTime) + val longInterval = getLongInterval(shortInterval) + + def createBlock(reuseCurrentBlock: Boolean = false) { + val newBlockId = inputName + "-" + longInterval.toFormattedString + "-" + currentBucket.blocks.size + if (!reuseCurrentBlock) { + val newBlock = new Block(newBlockId, shortInterval) + /*logInfo("Created " + currentBlock)*/ + currentBlock = newBlock + } else { + currentBlock.shortInterval = shortInterval + currentBlock.id = newBlockId + } + } + + def createBucket() { + val newBucket = new Bucket(longInterval) + buckets += ((longInterval, newBucket)) + currentBucket = newBucket + /*logInfo("Created " + currentBucket + ", " + buckets.size + " buckets")*/ + } + + if (currentBlock == null || currentBucket == null) { + createBucket() + currentBucket.synchronized { + createBlock() + } + return + } + + currentBlock.synchronized { + var reuseCurrentBlock = false + + if (shortInterval != currentBlock.shortInterval) { + if (!currentBlock.empty) { + blocksForPushing.synchronized { + blocksForPushing += currentBlock + blocksForPushing.notifyAll() + } + } + + currentBucket.synchronized { + if (currentBlock.empty) { + reuseCurrentBlock = true + } else { + currentBucket += currentBlock + } + + if (longInterval != currentBucket.longInterval) { + currentBucket.filled = true + if (currentBucket.ready) { + currentBucket.notifyAll() + } + createBucket() + } + } + + createBlock(reuseCurrentBlock) + } + } + } + + def pushBlock(block: Block) { + try{ + if (blockManager != null) { + logInfo("Pushing block") + val startTime = System.currentTimeMillis + + val bytes = blockManager.dataSerialize(block.data.toIterator) + val finishTime = System.currentTimeMillis + logInfo(block + " serialization delay is " + (finishTime - startTime) / 1000.0 + " s") + + blockManager.putBytes(block.id.toString, bytes, StorageLevel.DISK_AND_MEMORY_2) + /*blockManager.putBytes(block.id.toString, bytes, StorageLevel.DISK_AND_MEMORY_DESER_2)*/ + /*blockManager.put(block.id.toString, block.data.toIterator, StorageLevel.DISK_AND_MEMORY_DESER)*/ + /*blockManager.put(block.id.toString, block.data.toIterator, StorageLevel.DISK_AND_MEMORY)*/ + val finishTime1 = System.currentTimeMillis + logInfo(block + " put delay is " + (finishTime1 - startTime) / 1000.0 + " s") + } else { + logWarning(block + " not put as block manager is null") + } + } catch { + case e: Exception => logError("Exception writing " + block + " to blockmanager" , e) + } + } + + def getBucket(longInterval: Interval): Option[Bucket] = { + buckets.get(longInterval) + } + + def clearBucket(longInterval: Interval) { + buckets.remove(longInterval) + } + + def keepUpdatingCurrentBlock() { + logInfo("Thread to update current block started") + while(true) { + updateCurrentBlock() + val currentTimeMillis = System.currentTimeMillis + val sleepTimeMillis = (currentTimeMillis / shortIntervalDurationMillis + 1) * + shortIntervalDurationMillis - currentTimeMillis + 1 + Thread.sleep(sleepTimeMillis) + } + } + + def keepPushingBlocks() { + var loop = true + logInfo("Thread to push blocks started") + while(loop) { + val block = blocksForPushing.synchronized { + if (blocksForPushing.size == 0) { + blocksForPushing.wait() + } + blocksForPushing.dequeue + } + pushBlock(block) + block.pushed = true + block.data.clear() + + val bucket = buckets(block.longInterval) + bucket.synchronized { + if (bucket.ready) { + bucket.notifyAll() + } + } + } + } + } + + + class ConnectionListener(port: Int, dataHandler: DataHandler) + extends Thread with Logging { + initLogging() + override def run { + try { + val listener = new ServerSocket(port) + logInfo("Listening on port " + port) + while (true) { + new ConnectionHandler(listener.accept(), dataHandler).start(); + } + listener.close() + } catch { + case e: Exception => logError("", e); + } + } + } + + class ConnectionHandler(socket: Socket, dataHandler: DataHandler) extends Thread with Logging { + initLogging() + override def run { + logInfo("New connection from " + socket.getInetAddress() + ":" + socket.getPort) + val bytes = new Array[Byte](100 * 1024 * 1024) + try { + + val inputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream, 1024 * 1024)) + /*val inputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream))*/ + var str: String = null + str = inputStream.readUTF + while(str != null) { + dataHandler += str + str = inputStream.readUTF() + } + + /* + var loop = true + while(loop) { + val numRead = inputStream.read(bytes) + if (numRead < 0) { + loop = false + } + inbox += ((LongTime(SystemTime.currentTimeMillis), "test")) + }*/ + + inputStream.close() + } catch { + case e => logError("Error receiving data", e) + } + socket.close() + } + } + + initLogging() + + val masterHost = System.getProperty("spark.master.host") + val masterPort = System.getProperty("spark.master.port").toInt + + val akkaPath = "akka://spark@%s:%s/user/".format(masterHost, masterPort) + val sparkstreamScheduler = actorSystem.actorFor(akkaPath + "/SparkStreamScheduler") + val testStreamCoordinator = actorSystem.actorFor(akkaPath + "/TestStreamCoordinator") + + logInfo("Getting stream details from master " + masterHost + ":" + masterPort) + + val timeout = 50 millis + + var started = false + while (!started) { + askActor[String](testStreamCoordinator, TestStarted) match { + case Some(str) => { + started = true + logInfo("TestStreamCoordinator started") + } + case None => { + logInfo("TestStreamCoordinator not started yet") + Thread.sleep(200) + } + } + } + + val streamDetails = askActor[GotStreamDetails](testStreamCoordinator, GetStreamDetails) match { + case Some(details) => details + case None => throw new Exception("Could not get stream details") + } + logInfo("Stream details received: " + streamDetails) + + val inputName = streamDetails.name + val intervalDurationMillis = streamDetails.duration + val intervalDuration = Time(intervalDurationMillis) + + val dataHandler = new DataHandler( + inputName, + intervalDuration, + Time(TestStreamReceiver3.SHORT_INTERVAL_MILLIS), + blockManager) + + val connListener = new ConnectionListener(TestStreamReceiver3.PORT, dataHandler) + + // Send a message to an actor and return an option with its reply, or None if this times out + def askActor[T](actor: ActorRef, message: Any): Option[T] = { + try { + val future = actor.ask(message)(timeout) + return Some(Await.result(future, timeout).asInstanceOf[T]) + } catch { + case e: Exception => + logInfo("Error communicating with " + actor, e) + return None + } + } + + override def run() { + connListener.start() + dataHandler.start() + + var interval = Interval.currentInterval(intervalDuration) + var dataStarted = false + + while(true) { + waitFor(interval.endTime) + logInfo("Woken up at " + System.currentTimeMillis + " for " + interval) + dataHandler.getBucket(interval) match { + case Some(bucket) => { + logInfo("Found " + bucket + " for " + interval) + bucket.synchronized { + if (!bucket.ready) { + logInfo("Waiting for " + bucket) + bucket.wait() + logInfo("Wait over for " + bucket) + } + if (dataStarted || !bucket.empty) { + logInfo("Notifying " + bucket) + notifyScheduler(interval, bucket.blockIds) + dataStarted = true + } + bucket.blocks.clear() + dataHandler.clearBucket(interval) + } + } + case None => { + logInfo("Found none for " + interval) + if (dataStarted) { + logInfo("Notifying none") + notifyScheduler(interval, Array[String]()) + } + } + } + interval = interval.next + } + } + + def waitFor(time: Time) { + val currentTimeMillis = System.currentTimeMillis + val targetTimeMillis = time.milliseconds + if (currentTimeMillis < targetTimeMillis) { + val sleepTime = (targetTimeMillis - currentTimeMillis) + Thread.sleep(sleepTime + 1) + } + } + + def notifyScheduler(interval: Interval, blockIds: Array[String]) { + try { + sparkstreamScheduler ! InputGenerated(inputName, interval, blockIds.toArray) + val time = interval.endTime + val delay = (System.currentTimeMillis - time.milliseconds) / 1000.0 + logInfo("Pushing delay for " + time + " is " + delay + " s") + } catch { + case _ => logError("Exception notifying scheduler at interval " + interval) + } + } +} + +object TestStreamReceiver3 { + + val PORT = 9999 + val SHORT_INTERVAL_MILLIS = 100 + + def main(args: Array[String]) { + System.setProperty("spark.master.host", Utils.localHostName) + System.setProperty("spark.master.port", "7078") + val details = Array(("Sentences", 2000L)) + val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localHostName, 7078) + actorSystem.actorOf(Props(new TestStreamCoordinator(details)), name = "TestStreamCoordinator") + new TestStreamReceiver3(actorSystem, null).start() + } +} + + + diff --git a/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver4.scala b/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver4.scala new file mode 100644 index 0000000000..31754870dd --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver4.scala @@ -0,0 +1,374 @@ +package spark.streaming.util + +import spark.streaming._ +import spark._ +import spark.storage._ +import spark.util.AkkaUtils + +import scala.math._ +import scala.collection.mutable.{Queue, HashMap, ArrayBuffer, SynchronizedMap} + +import java.io._ +import java.nio._ +import java.nio.charset._ +import java.nio.channels._ +import java.util.concurrent.Executors + +import akka.actor._ +import akka.actor.Actor +import akka.dispatch._ +import akka.pattern.ask +import akka.util.duration._ + +class TestStreamReceiver4(actorSystem: ActorSystem, blockManager: BlockManager) +extends Thread with Logging { + + class DataHandler( + inputName: String, + longIntervalDuration: Time, + shortIntervalDuration: Time, + blockManager: BlockManager + ) + extends Logging { + + class Block(val id: String, val shortInterval: Interval, val buffer: ByteBuffer) { + var pushed = false + def longInterval = getLongInterval(shortInterval) + override def toString() = "Block " + id + } + + class Bucket(val longInterval: Interval) { + val blocks = new ArrayBuffer[Block]() + var filled = false + def += (block: Block) = blocks += block + def empty() = (blocks.size == 0) + def ready() = (filled && !blocks.exists(! _.pushed)) + def blockIds() = blocks.map(_.id).toArray + override def toString() = "Bucket [" + longInterval + ", " + blocks.size + " blocks]" + } + + initLogging() + + val syncOnLastShortInterval = true + + val shortIntervalDurationMillis = shortIntervalDuration.milliseconds + val longIntervalDurationMillis = longIntervalDuration.milliseconds + + val buffer = ByteBuffer.allocateDirect(100 * 1024 * 1024) + var currentShortInterval = Interval.currentInterval(shortIntervalDuration) + + val blocksForPushing = new Queue[Block]() + val buckets = new HashMap[Interval, Bucket]() with SynchronizedMap[Interval, Bucket] + + val bufferProcessingThread = new Thread() { override def run() { keepProcessingBuffers() } } + val blockPushingExecutor = Executors.newFixedThreadPool(5) + + + def start() { + buffer.clear() + if (buffer.remaining == 0) { + throw new Exception("Buffer initialization error") + } + bufferProcessingThread.start() + } + + def readDataToBuffer(func: ByteBuffer => Int): Int = { + buffer.synchronized { + if (buffer.remaining == 0) { + logInfo("Received first data for interval " + currentShortInterval) + } + func(buffer) + } + } + + def getLongInterval(shortInterval: Interval): Interval = { + val intervalBegin = shortInterval.beginTime.floor(longIntervalDuration) + Interval(intervalBegin, intervalBegin + longIntervalDuration) + } + + def processBuffer() { + + def readInt(buffer: ByteBuffer): Int = { + var offset = 0 + var result = 0 + while (offset < 32) { + val b = buffer.get() + result |= ((b & 0x7F) << offset) + if ((b & 0x80) == 0) { + return result + } + offset += 7 + } + throw new Exception("Malformed zigzag-encoded integer") + } + + val currentLongInterval = getLongInterval(currentShortInterval) + val startTime = System.currentTimeMillis + val newBuffer: ByteBuffer = buffer.synchronized { + buffer.flip() + if (buffer.remaining == 0) { + buffer.clear() + null + } else { + logDebug("Processing interval " + currentShortInterval + " with delay of " + (System.currentTimeMillis - startTime) + " ms") + val startTime1 = System.currentTimeMillis + var loop = true + var count = 0 + while(loop) { + buffer.mark() + try { + val len = readInt(buffer) + buffer.position(buffer.position + len) + count += 1 + } catch { + case e: Exception => { + buffer.reset() + loop = false + } + } + } + val bytesToCopy = buffer.position + val newBuf = ByteBuffer.allocate(bytesToCopy) + buffer.position(0) + newBuf.put(buffer.slice().limit(bytesToCopy).asInstanceOf[ByteBuffer]) + newBuf.flip() + buffer.position(bytesToCopy) + buffer.compact() + newBuf + } + } + + if (newBuffer != null) { + val bucket = buckets.getOrElseUpdate(currentLongInterval, new Bucket(currentLongInterval)) + bucket.synchronized { + val newBlockId = inputName + "-" + currentLongInterval.toFormattedString + "-" + currentShortInterval.toFormattedString + val newBlock = new Block(newBlockId, currentShortInterval, newBuffer) + if (syncOnLastShortInterval) { + bucket += newBlock + } + logDebug("Created " + newBlock + " with " + newBuffer.remaining + " bytes, creation delay is " + (System.currentTimeMillis - currentShortInterval.endTime.milliseconds) / 1000.0 + " s" ) + blockPushingExecutor.execute(new Runnable() { def run() { pushAndNotifyBlock(newBlock) } }) + } + } + + val newShortInterval = Interval.currentInterval(shortIntervalDuration) + val newLongInterval = getLongInterval(newShortInterval) + + if (newLongInterval != currentLongInterval) { + buckets.get(currentLongInterval) match { + case Some(bucket) => { + bucket.synchronized { + bucket.filled = true + if (bucket.ready) { + bucket.notifyAll() + } + } + } + case None => + } + buckets += ((newLongInterval, new Bucket(newLongInterval))) + } + + currentShortInterval = newShortInterval + } + + def pushBlock(block: Block) { + try{ + if (blockManager != null) { + val startTime = System.currentTimeMillis + logInfo(block + " put start delay is " + (startTime - block.shortInterval.endTime.milliseconds) + " ms") + /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY)*/ + /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_2)*/ + blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.MEMORY_ONLY_2) + /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.MEMORY_ONLY)*/ + /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_DESER)*/ + /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_DESER_2)*/ + val finishTime = System.currentTimeMillis + logInfo(block + " put delay is " + (finishTime - startTime) + " ms") + } else { + logWarning(block + " not put as block manager is null") + } + } catch { + case e: Exception => logError("Exception writing " + block + " to blockmanager" , e) + } + } + + def getBucket(longInterval: Interval): Option[Bucket] = { + buckets.get(longInterval) + } + + def clearBucket(longInterval: Interval) { + buckets.remove(longInterval) + } + + def keepProcessingBuffers() { + logInfo("Thread to process buffers started") + while(true) { + processBuffer() + val currentTimeMillis = System.currentTimeMillis + val sleepTimeMillis = (currentTimeMillis / shortIntervalDurationMillis + 1) * + shortIntervalDurationMillis - currentTimeMillis + 1 + Thread.sleep(sleepTimeMillis) + } + } + + def pushAndNotifyBlock(block: Block) { + pushBlock(block) + block.pushed = true + val bucket = if (syncOnLastShortInterval) { + buckets(block.longInterval) + } else { + var longInterval = block.longInterval + while(!buckets.contains(longInterval)) { + logWarning("Skipping bucket of " + longInterval + " for " + block) + longInterval = longInterval.next + } + val chosenBucket = buckets(longInterval) + logDebug("Choosing bucket of " + longInterval + " for " + block) + chosenBucket += block + chosenBucket + } + + bucket.synchronized { + if (bucket.ready) { + bucket.notifyAll() + } + } + + } + } + + + class ReceivingConnectionHandler(host: String, port: Int, dataHandler: DataHandler) + extends ConnectionHandler(host, port, false) { + + override def ready(key: SelectionKey) { + changeInterest(key, SelectionKey.OP_READ) + } + + override def read(key: SelectionKey) { + try { + val channel = key.channel.asInstanceOf[SocketChannel] + val bytesRead = dataHandler.readDataToBuffer(channel.read) + if (bytesRead < 0) { + close(key) + } + } catch { + case e: IOException => { + logError("Error reading", e) + close(key) + } + } + } + } + + initLogging() + + val masterHost = System.getProperty("spark.master.host", "localhost") + val masterPort = System.getProperty("spark.master.port", "7078").toInt + + val akkaPath = "akka://spark@%s:%s/user/".format(masterHost, masterPort) + val sparkstreamScheduler = actorSystem.actorFor(akkaPath + "/SparkStreamScheduler") + val testStreamCoordinator = actorSystem.actorFor(akkaPath + "/TestStreamCoordinator") + + logInfo("Getting stream details from master " + masterHost + ":" + masterPort) + + val streamDetails = askActor[GotStreamDetails](testStreamCoordinator, GetStreamDetails) match { + case Some(details) => details + case None => throw new Exception("Could not get stream details") + } + logInfo("Stream details received: " + streamDetails) + + val inputName = streamDetails.name + val intervalDurationMillis = streamDetails.duration + val intervalDuration = Milliseconds(intervalDurationMillis) + val shortIntervalDuration = Milliseconds(System.getProperty("spark.stream.shortinterval", "500").toInt) + + val dataHandler = new DataHandler(inputName, intervalDuration, shortIntervalDuration, blockManager) + val connectionHandler = new ReceivingConnectionHandler("localhost", 9999, dataHandler) + + val timeout = 100 millis + + // Send a message to an actor and return an option with its reply, or None if this times out + def askActor[T](actor: ActorRef, message: Any): Option[T] = { + try { + val future = actor.ask(message)(timeout) + return Some(Await.result(future, timeout).asInstanceOf[T]) + } catch { + case e: Exception => + logInfo("Error communicating with " + actor, e) + return None + } + } + + override def run() { + connectionHandler.start() + dataHandler.start() + + var interval = Interval.currentInterval(intervalDuration) + var dataStarted = false + + + while(true) { + waitFor(interval.endTime) + /*logInfo("Woken up at " + System.currentTimeMillis + " for " + interval)*/ + dataHandler.getBucket(interval) match { + case Some(bucket) => { + logDebug("Found " + bucket + " for " + interval) + bucket.synchronized { + if (!bucket.ready) { + logDebug("Waiting for " + bucket) + bucket.wait() + logDebug("Wait over for " + bucket) + } + if (dataStarted || !bucket.empty) { + logDebug("Notifying " + bucket) + notifyScheduler(interval, bucket.blockIds) + dataStarted = true + } + bucket.blocks.clear() + dataHandler.clearBucket(interval) + } + } + case None => { + logDebug("Found none for " + interval) + if (dataStarted) { + logDebug("Notifying none") + notifyScheduler(interval, Array[String]()) + } + } + } + interval = interval.next + } + } + + def waitFor(time: Time) { + val currentTimeMillis = System.currentTimeMillis + val targetTimeMillis = time.milliseconds + if (currentTimeMillis < targetTimeMillis) { + val sleepTime = (targetTimeMillis - currentTimeMillis) + Thread.sleep(sleepTime + 1) + } + } + + def notifyScheduler(interval: Interval, blockIds: Array[String]) { + try { + sparkstreamScheduler ! InputGenerated(inputName, interval, blockIds.toArray) + val time = interval.endTime + val delay = (System.currentTimeMillis - time.milliseconds) + logInfo("Notification delay for " + time + " is " + delay + " ms") + } catch { + case e: Exception => logError("Exception notifying scheduler at interval " + interval + ": " + e) + } + } +} + + +object TestStreamReceiver4 { + def main(args: Array[String]) { + val details = Array(("Sentences", 2000L)) + val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localHostName, 7078) + actorSystem.actorOf(Props(new TestStreamCoordinator(details)), name = "TestStreamCoordinator") + new TestStreamReceiver4(actorSystem, null).start() + } +} diff --git a/streaming/src/main/scala/spark/streaming/util/Utils.scala b/streaming/src/main/scala/spark/streaming/util/Utils.scala deleted file mode 100644 index 86a729fb49..0000000000 --- a/streaming/src/main/scala/spark/streaming/util/Utils.scala +++ /dev/null @@ -1,9 +0,0 @@ -package spark.streaming.util - -object Utils { - def time(func: => Unit): Long = { - val t = System.currentTimeMillis - func - (System.currentTimeMillis - t) - } -} \ No newline at end of file -- cgit v1.2.3 From 650d11817eb15c1c2a8dc322b72c753df88bf8d3 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 2 Aug 2012 11:09:43 -0400 Subject: Added a WordCount for external data and fixed bugs in file streams --- .../main/scala/spark/streaming/FileInputRDS.scala | 13 +++++------ .../main/scala/spark/streaming/JobManager.scala | 2 +- .../scala/spark/streaming/examples/WordCount.scala | 25 ++++++++++++++++++++++ 3 files changed, 33 insertions(+), 7 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount.scala diff --git a/streaming/src/main/scala/spark/streaming/FileInputRDS.scala b/streaming/src/main/scala/spark/streaming/FileInputRDS.scala index dde80cd27a..ebd246823d 100644 --- a/streaming/src/main/scala/spark/streaming/FileInputRDS.scala +++ b/streaming/src/main/scala/spark/streaming/FileInputRDS.scala @@ -43,12 +43,11 @@ class FileInputRDS[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] var latestModTime = 0L def accept(path: Path): Boolean = { - if (!filter.accept(path)) { return false } else { val modTime = fs.getFileStatus(path).getModificationTime() - if (modTime < lastModTime) { + if (modTime <= lastModTime) { return false } if (modTime > latestModTime) { @@ -60,10 +59,12 @@ class FileInputRDS[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] } val newFiles = fs.listStatus(directory, newFilter) - lastModTime = newFilter.latestModTime - val newRDD = new UnionRDD(ssc.sc, newFiles.map(file => - ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString)) - ) + logInfo("New files: " + newFiles.map(_.getPath).mkString(", ")) + if (newFiles.length > 0) { + lastModTime = newFilter.latestModTime + } + val newRDD = new UnionRDD(ssc.sc, newFiles.map( + file => ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString))) Some(newRDD) } } diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index 43d167f7db..c37fe1e9ad 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -13,7 +13,7 @@ class JobManager(ssc: SparkStreamContext, numThreads: Int = 1) extends Logging { try { val timeTaken = job.run() logInfo( - "Runnning " + job + " took " + timeTaken + " ms, " + + "Running " + job + " took " + timeTaken + " ms, " + "total delay was " + (System.currentTimeMillis - job.time) + " ms" ) } catch { diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount.scala new file mode 100644 index 0000000000..a155630151 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount.scala @@ -0,0 +1,25 @@ +package spark.streaming.examples + +import spark.streaming.{Seconds, SparkStreamContext} +import spark.streaming.SparkStreamContext._ + +object WordCount { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: WordCount ") + System.exit(1) + } + + // Create the context and set the batch size + val ssc = new SparkStreamContext(args(0), "ExampleTwo") + ssc.setBatchDuration(Seconds(2)) + + // Create the FileInputRDS on the directory and use the + // stream to count words in new files created + val inputRDS = ssc.createTextFileStream(args(1)) + val wordsRDS = inputRDS.flatMap(_.split(" ")) + val wordCountsRDS = wordsRDS.map(x => (x, 1)).reduceByKey(_ + _) + wordCountsRDS.print() + ssc.start() + } +} -- cgit v1.2.3 From 29bf44473c9d76622628f2511588f7846e9b1f3c Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 2 Aug 2012 11:43:04 -0400 Subject: Added an RDS that repeatedly returns the same input --- .../main/scala/spark/streaming/ConstantInputRDS.scala | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 streaming/src/main/scala/spark/streaming/ConstantInputRDS.scala diff --git a/streaming/src/main/scala/spark/streaming/ConstantInputRDS.scala b/streaming/src/main/scala/spark/streaming/ConstantInputRDS.scala new file mode 100644 index 0000000000..bf2e6f7e16 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/ConstantInputRDS.scala @@ -0,0 +1,18 @@ +package spark.streaming + +import spark.RDD + +/** + * An input stream that always returns the same RDD on each timestep. Useful for testing. + */ +class ConstantInputRDS[T: ClassManifest](ssc: SparkStreamContext, rdd: RDD[T]) + extends InputRDS[T](ssc) { + + override def start() {} + + override def stop() {} + + override def compute(validTime: Time): Option[RDD[T]] = { + Some(rdd) + } +} \ No newline at end of file -- cgit v1.2.3 From 43b81eb2719c4666b7869d7d0290f2ee83daeafa Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 2 Aug 2012 14:05:51 -0400 Subject: Renamed RDS to DStream, plus minor style fixes --- .../spark/streaming/ConstantInputDStream.scala | 18 + .../scala/spark/streaming/ConstantInputRDS.scala | 18 - .../src/main/scala/spark/streaming/DStream.scala | 486 +++++++++++++++++++++ .../scala/spark/streaming/FileInputDStream.scala | 164 +++++++ .../main/scala/spark/streaming/FileInputRDS.scala | 164 ------- .../src/main/scala/spark/streaming/Interval.scala | 7 +- .../spark/streaming/PairDStreamFunctions.scala | 72 +++ .../scala/spark/streaming/PairRDSFunctions.scala | 72 --- .../scala/spark/streaming/QueueInputDStream.scala | 36 ++ .../main/scala/spark/streaming/QueueInputRDS.scala | 36 -- streaming/src/main/scala/spark/streaming/RDS.scala | 484 -------------------- .../spark/streaming/ReducedWindowedDStream.scala | 218 +++++++++ .../scala/spark/streaming/ReducedWindowedRDS.scala | 218 --------- .../src/main/scala/spark/streaming/Scheduler.scala | 15 +- .../scala/spark/streaming/SparkStreamContext.scala | 58 +-- .../src/main/scala/spark/streaming/Time.scala | 6 +- .../scala/spark/streaming/WindowedDStream.scala | 68 +++ .../main/scala/spark/streaming/WindowedRDS.scala | 68 --- .../spark/streaming/examples/ExampleOne.scala | 4 +- .../spark/streaming/examples/ExampleTwo.scala | 10 +- .../scala/spark/streaming/examples/WordCount.scala | 10 +- .../spark/streaming/util/SenderReceiverTest.scala | 2 +- .../test/scala/spark/streaming/DStreamSuite.scala | 65 +++ .../src/test/scala/spark/streaming/RDSSuite.scala | 65 --- 24 files changed, 1181 insertions(+), 1183 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/ConstantInputRDS.scala create mode 100644 streaming/src/main/scala/spark/streaming/DStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/FileInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/FileInputRDS.scala create mode 100644 streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala delete mode 100644 streaming/src/main/scala/spark/streaming/PairRDSFunctions.scala create mode 100644 streaming/src/main/scala/spark/streaming/QueueInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/QueueInputRDS.scala delete mode 100644 streaming/src/main/scala/spark/streaming/RDS.scala create mode 100644 streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/ReducedWindowedRDS.scala create mode 100644 streaming/src/main/scala/spark/streaming/WindowedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/WindowedRDS.scala create mode 100644 streaming/src/test/scala/spark/streaming/DStreamSuite.scala delete mode 100644 streaming/src/test/scala/spark/streaming/RDSSuite.scala diff --git a/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala b/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala new file mode 100644 index 0000000000..6a2be34633 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala @@ -0,0 +1,18 @@ +package spark.streaming + +import spark.RDD + +/** + * An input stream that always returns the same RDD on each timestep. Useful for testing. + */ +class ConstantInputDStream[T: ClassManifest](ssc: SparkStreamContext, rdd: RDD[T]) + extends InputDStream[T](ssc) { + + override def start() {} + + override def stop() {} + + override def compute(validTime: Time): Option[RDD[T]] = { + Some(rdd) + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/ConstantInputRDS.scala b/streaming/src/main/scala/spark/streaming/ConstantInputRDS.scala deleted file mode 100644 index bf2e6f7e16..0000000000 --- a/streaming/src/main/scala/spark/streaming/ConstantInputRDS.scala +++ /dev/null @@ -1,18 +0,0 @@ -package spark.streaming - -import spark.RDD - -/** - * An input stream that always returns the same RDD on each timestep. Useful for testing. - */ -class ConstantInputRDS[T: ClassManifest](ssc: SparkStreamContext, rdd: RDD[T]) - extends InputRDS[T](ssc) { - - override def start() {} - - override def stop() {} - - override def compute(validTime: Time): Option[RDD[T]] = { - Some(rdd) - } -} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala new file mode 100644 index 0000000000..e19d2ecef5 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -0,0 +1,486 @@ +package spark.streaming + +import spark.streaming.SparkStreamContext._ + +import spark.RDD +import spark.BlockRDD +import spark.UnionRDD +import spark.Logging +import spark.SparkContext +import spark.SparkContext._ +import spark.storage.StorageLevel + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap + +import java.util.concurrent.ArrayBlockingQueue + +abstract class DStream[T: ClassManifest] (@transient val ssc: SparkStreamContext) +extends Logging with Serializable { + + initLogging() + + /** + * ---------------------------------------------- + * Methods that must be implemented by subclasses + * ---------------------------------------------- + */ + + // Time by which the window slides in this DStream + def slideTime: Time + + // List of parent DStreams on which this DStream depends on + def dependencies: List[DStream[_]] + + // Key method that computes RDD for a valid time + def compute (validTime: Time): Option[RDD[T]] + + /** + * --------------------------------------- + * Other general fields and methods of DStream + * --------------------------------------- + */ + + // Variable to store the RDDs generated earlier in time + @transient private val generatedRDDs = new HashMap[Time, RDD[T]] () + + // Variable to be set to the first time seen by the DStream (effective time zero) + private[streaming] var zeroTime: Time = null + + // Variable to specify storage level + private var storageLevel: StorageLevel = StorageLevel.NONE + + // Checkpoint level and checkpoint interval + private var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint + private var checkpointInterval: Time = null + + // Change this RDD's storage level + def persist( + storageLevel: StorageLevel, + checkpointLevel: StorageLevel, + checkpointInterval: Time): DStream[T] = { + if (this.storageLevel != StorageLevel.NONE && this.storageLevel != storageLevel) { + // TODO: not sure this is necessary for DStreams + throw new UnsupportedOperationException( + "Cannot change storage level of an DStream after it was already assigned a level") + } + this.storageLevel = storageLevel + this.checkpointLevel = checkpointLevel + this.checkpointInterval = checkpointInterval + this + } + + // Set caching level for the RDDs created by this DStream + def persist(newLevel: StorageLevel): DStream[T] = persist(newLevel, StorageLevel.NONE, null) + + def persist(): DStream[T] = persist(StorageLevel.MEMORY_ONLY_DESER) + + // Turn on the default caching level for this RDD + def cache(): DStream[T] = persist() + + def isInitialized = (zeroTime != null) + + /** + * This method initializes the DStream by setting the "zero" time, based on which + * the validity of future times is calculated. This method also recursively initializes + * its parent DStreams. + */ + def initialize(time: Time) { + if (zeroTime == null) { + zeroTime = time + } + logInfo(this + " initialized") + dependencies.foreach(_.initialize(zeroTime)) + } + + /** This method checks whether the 'time' is valid wrt slideTime for generating RDD */ + private def isTimeValid (time: Time): Boolean = { + if (!isInitialized) + throw new Exception (this.toString + " has not been initialized") + if ((time - zeroTime).isMultipleOf(slideTime)) { + true + } else { + false + } + } + + /** + * This method either retrieves a precomputed RDD of this DStream, + * or computes the RDD (if the time is valid) + */ + def getOrCompute(time: Time): Option[RDD[T]] = { + // If this DStream was not initialized (i.e., zeroTime not set), then do it + // If RDD was already generated, then retrieve it from HashMap + generatedRDDs.get(time) match { + + // If an RDD was already generated and is being reused, then + // probably all RDDs in this DStream will be reused and hence should be cached + case Some(oldRDD) => Some(oldRDD) + + // if RDD was not generated, and if the time is valid + // (based on sliding time of this DStream), then generate the RDD + case None => + if (isTimeValid(time)) { + compute(time) match { + case Some(newRDD) => + if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { + newRDD.persist(checkpointLevel) + logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) + } else if (storageLevel != StorageLevel.NONE) { + newRDD.persist(storageLevel) + logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) + } + generatedRDDs.put(time.copy(), newRDD) + Some(newRDD) + case None => + None + } + } else { + None + } + } + } + + /** + * This method generates a SparkStream job for the given time + * and may require to be overriden by subclasses + */ + def generateJob(time: Time): Option[Job] = { + getOrCompute(time) match { + case Some(rdd) => { + val jobFunc = () => { + val emptyFunc = { (iterator: Iterator[T]) => {} } + ssc.sc.runJob(rdd, emptyFunc) + } + Some(new Job(time, jobFunc)) + } + case None => None + } + } + + /** + * -------------- + * DStream operations + * -------------- + */ + + def map[U: ClassManifest](mapFunc: T => U) = new MappedDStream(this, ssc.sc.clean(mapFunc)) + + def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]) = + new FlatMappedDStream(this, ssc.sc.clean(flatMapFunc)) + + def filter(filterFunc: T => Boolean) = new FilteredDStream(this, filterFunc) + + def glom() = new GlommedDStream(this) + + def mapPartitions[U: ClassManifest](mapPartFunc: Iterator[T] => Iterator[U]) = + new MapPartitionedDStream(this, ssc.sc.clean(mapPartFunc)) + + def reduce(reduceFunc: (T, T) => T) = this.map(x => (1, x)).reduceByKey(reduceFunc, 1).map(_._2) + + def count() = this.map(_ => 1).reduce(_ + _) + + def collect() = this.map(x => (1, x)).groupByKey(1).map(_._2) + + def foreach(foreachFunc: T => Unit) = { + val newStream = new PerElementForEachDStream(this, ssc.sc.clean(foreachFunc)) + ssc.registerOutputStream(newStream) + newStream + } + + def foreachRDD(foreachFunc: RDD[T] => Unit) = { + val newStream = new PerRDDForEachDStream(this, ssc.sc.clean(foreachFunc)) + ssc.registerOutputStream(newStream) + newStream + } + + private[streaming] def toQueue = { + val queue = new ArrayBlockingQueue[RDD[T]](10000) + this.foreachRDD(rdd => { + println("Added RDD " + rdd.id) + queue.add(rdd) + }) + queue + } + + def print() = { + def foreachFunc = (rdd: RDD[T], time: Time) => { + val first11 = rdd.take(11) + println ("-------------------------------------------") + println ("Time: " + time) + println ("-------------------------------------------") + first11.take(10).foreach(println) + if (first11.size > 10) println("...") + println() + } + val newStream = new PerRDDForEachDStream(this, ssc.sc.clean(foreachFunc)) + ssc.registerOutputStream(newStream) + newStream + } + + def window(windowTime: Time, slideTime: Time) = new WindowedDStream(this, windowTime, slideTime) + + def batch(batchTime: Time) = window(batchTime, batchTime) + + def reduceByWindow(reduceFunc: (T, T) => T, windowTime: Time, slideTime: Time) = + this.window(windowTime, slideTime).reduce(reduceFunc) + + def reduceByWindow( + reduceFunc: (T, T) => T, + invReduceFunc: (T, T) => T, + windowTime: Time, + slideTime: Time) = { + this.map(x => (1, x)) + .reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowTime, slideTime, 1) + .map(_._2) + } + + def countByWindow(windowTime: Time, slideTime: Time) = { + def add(v1: Int, v2: Int) = (v1 + v2) + def subtract(v1: Int, v2: Int) = (v1 - v2) + this.map(_ => 1).reduceByWindow(add _, subtract _, windowTime, slideTime) + } + + def union(that: DStream[T]) = new UnifiedDStream(Array(this, that)) + + def register() { + ssc.registerOutputStream(this) + } +} + + +abstract class InputDStream[T: ClassManifest] ( + ssc: SparkStreamContext) +extends DStream[T](ssc) { + + override def dependencies = List() + + override def slideTime = ssc.batchDuration + + def start() + + def stop() +} + + +/** + * TODO + */ + +class MappedDStream[T: ClassManifest, U: ClassManifest] ( + parent: DStream[T], + mapFunc: T => U) +extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.map[U](mapFunc)) + } +} + + +/** + * TODO + */ + +class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( + parent: DStream[T], + flatMapFunc: T => Traversable[U]) +extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.flatMap(flatMapFunc)) + } +} + + +/** + * TODO + */ + +class FilteredDStream[T: ClassManifest](parent: DStream[T], filterFunc: T => Boolean) +extends DStream[T](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[T]] = { + parent.getOrCompute(validTime).map(_.filter(filterFunc)) + } +} + + +/** + * TODO + */ + +class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( + parent: DStream[T], + mapPartFunc: Iterator[T] => Iterator[U]) +extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc)) + } +} + + +/** + * TODO + */ + +class GlommedDStream[T: ClassManifest](parent: DStream[T]) extends DStream[Array[T]](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[Array[T]]] = { + parent.getOrCompute(validTime).map(_.glom()) + } +} + + +/** + * TODO + */ + +class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( + parent: DStream[(K,V)], + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiner: (C, C) => C, + numPartitions: Int) + extends DStream [(K,C)] (parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[(K,C)]] = { + parent.getOrCompute(validTime) match { + case Some(rdd) => + val newrdd = { + if (numPartitions > 0) { + rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner, numPartitions) + } else { + rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner) + } + } + Some(newrdd) + case None => None + } + } +} + + +/** + * TODO + */ + +class UnifiedDStream[T: ClassManifest](parents: Array[DStream[T]]) +extends DStream[T](parents(0).ssc) { + + if (parents.length == 0) { + throw new IllegalArgumentException("Empty array of parents") + } + + if (parents.map(_.ssc).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different SparkStreamContexts") + } + + if (parents.map(_.slideTime).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different slide times") + } + + override def dependencies = parents.toList + + override def slideTime: Time = parents(0).slideTime + + override def compute(validTime: Time): Option[RDD[T]] = { + val rdds = new ArrayBuffer[RDD[T]]() + parents.map(_.getOrCompute(validTime)).foreach(_ match { + case Some(rdd) => rdds += rdd + case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime) + }) + if (rdds.size > 0) { + Some(new UnionRDD(ssc.sc, rdds)) + } else { + None + } + } +} + + +/** + * TODO + */ + +class PerElementForEachDStream[T: ClassManifest] ( + parent: DStream[T], + foreachFunc: T => Unit) +extends DStream[Unit](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[Unit]] = None + + override def generateJob(time: Time): Option[Job] = { + parent.getOrCompute(time) match { + case Some(rdd) => + val jobFunc = () => { + val sparkJobFunc = { + (iterator: Iterator[T]) => iterator.foreach(foreachFunc) + } + ssc.sc.runJob(rdd, sparkJobFunc) + } + Some(new Job(time, jobFunc)) + case None => None + } + } +} + + +/** + * TODO + */ + +class PerRDDForEachDStream[T: ClassManifest] ( + parent: DStream[T], + foreachFunc: (RDD[T], Time) => Unit) +extends DStream[Unit](parent.ssc) { + + def this(parent: DStream[T], altForeachFunc: (RDD[T]) => Unit) = + this(parent, (rdd: RDD[T], time: Time) => altForeachFunc(rdd)) + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[Unit]] = None + + override def generateJob(time: Time): Option[Job] = { + parent.getOrCompute(time) match { + case Some(rdd) => + val jobFunc = () => { + foreachFunc(rdd, time) + } + Some(new Job(time, jobFunc)) + case None => None + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala new file mode 100644 index 0000000000..88aa375289 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala @@ -0,0 +1,164 @@ +package spark.streaming + +import spark.SparkContext +import spark.RDD +import spark.BlockRDD +import spark.UnionRDD +import spark.storage.StorageLevel +import spark.streaming._ + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap + +import java.net.InetSocketAddress + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.PathFilter +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} + + +class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( + ssc: SparkStreamContext, + directory: Path, + filter: PathFilter = FileInputDStream.defaultPathFilter, + newFilesOnly: Boolean = true) + extends InputDStream[(K, V)](ssc) { + + val fs = directory.getFileSystem(new Configuration()) + var lastModTime: Long = 0 + + override def start() { + if (newFilesOnly) { + lastModTime = System.currentTimeMillis() + } else { + lastModTime = 0 + } + } + + override def stop() { } + + override def compute(validTime: Time): Option[RDD[(K, V)]] = { + val newFilter = new PathFilter() { + var latestModTime = 0L + + def accept(path: Path): Boolean = { + if (!filter.accept(path)) { + return false + } else { + val modTime = fs.getFileStatus(path).getModificationTime() + if (modTime <= lastModTime) { + return false + } + if (modTime > latestModTime) { + latestModTime = modTime + } + return true + } + } + } + + val newFiles = fs.listStatus(directory, newFilter) + logInfo("New files: " + newFiles.map(_.getPath).mkString(", ")) + if (newFiles.length > 0) { + lastModTime = newFilter.latestModTime + } + val newRDD = new UnionRDD(ssc.sc, newFiles.map( + file => ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString))) + Some(newRDD) + } +} + +object FileInputDStream { + val defaultPathFilter = new PathFilter { + def accept(path: Path): Boolean = { + val file = path.getName() + if (file.startsWith(".") || file.endsWith("_tmp")) { + return false + } else { + return true + } + } + } +} + +/* +class NetworkInputDStream[T: ClassManifest]( + val networkInputName: String, + val addresses: Array[InetSocketAddress], + batchDuration: Time, + ssc: SparkStreamContext) +extends InputDStream[T](networkInputName, batchDuration, ssc) { + + + // TODO(Haoyuan): This is for the performance test. + @transient var rdd: RDD[T] = null + + if (System.getProperty("spark.fake", "false") == "true") { + logInfo("Running initial count to cache fake RDD") + rdd = ssc.sc.textFile(SparkContext.inputFile, + SparkContext.idealPartitions).asInstanceOf[RDD[T]] + val fakeCacheLevel = System.getProperty("spark.fake.cache", "") + if (fakeCacheLevel == "MEMORY_ONLY_2") { + rdd.persist(StorageLevel.MEMORY_ONLY_2) + } else if (fakeCacheLevel == "MEMORY_ONLY_DESER_2") { + rdd.persist(StorageLevel.MEMORY_ONLY_2) + } else if (fakeCacheLevel != "") { + logError("Invalid fake cache level: " + fakeCacheLevel) + System.exit(1) + } + rdd.count() + } + + @transient val references = new HashMap[Time,String] + + override def compute(validTime: Time): Option[RDD[T]] = { + if (System.getProperty("spark.fake", "false") == "true") { + logInfo("Returning fake RDD at " + validTime) + return Some(rdd) + } + references.get(validTime) match { + case Some(reference) => + if (reference.startsWith("file") || reference.startsWith("hdfs")) { + logInfo("Reading from file " + reference + " for time " + validTime) + Some(ssc.sc.textFile(reference).asInstanceOf[RDD[T]]) + } else { + logInfo("Getting from BlockManager " + reference + " for time " + validTime) + Some(new BlockRDD(ssc.sc, Array(reference))) + } + case None => + throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") + None + } + } + + def setReference(time: Time, reference: AnyRef) { + references += ((time, reference.toString)) + logInfo("Reference added for time " + time + " - " + reference.toString) + } +} + + +class TestInputDStream( + val testInputName: String, + batchDuration: Time, + ssc: SparkStreamContext) +extends InputDStream[String](testInputName, batchDuration, ssc) { + + @transient val references = new HashMap[Time,Array[String]] + + override def compute(validTime: Time): Option[RDD[String]] = { + references.get(validTime) match { + case Some(reference) => + Some(new BlockRDD[String](ssc.sc, reference)) + case None => + throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") + None + } + } + + def setReference(time: Time, reference: AnyRef) { + references += ((time, reference.asInstanceOf[Array[String]])) + } +} +*/ diff --git a/streaming/src/main/scala/spark/streaming/FileInputRDS.scala b/streaming/src/main/scala/spark/streaming/FileInputRDS.scala deleted file mode 100644 index ebd246823d..0000000000 --- a/streaming/src/main/scala/spark/streaming/FileInputRDS.scala +++ /dev/null @@ -1,164 +0,0 @@ -package spark.streaming - -import spark.SparkContext -import spark.RDD -import spark.BlockRDD -import spark.UnionRDD -import spark.storage.StorageLevel -import spark.streaming._ - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap - -import java.net.InetSocketAddress - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.fs.PathFilter -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} - - -class FileInputRDS[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( - ssc: SparkStreamContext, - directory: Path, - filter: PathFilter = FileInputRDS.defaultPathFilter, - newFilesOnly: Boolean = true) - extends InputRDS[(K, V)](ssc) { - - val fs = directory.getFileSystem(new Configuration()) - var lastModTime: Long = 0 - - override def start() { - if (newFilesOnly) { - lastModTime = System.currentTimeMillis() - } else { - lastModTime = 0 - } - } - - override def stop() { } - - override def compute(validTime: Time): Option[RDD[(K, V)]] = { - val newFilter = new PathFilter() { - var latestModTime = 0L - - def accept(path: Path): Boolean = { - if (!filter.accept(path)) { - return false - } else { - val modTime = fs.getFileStatus(path).getModificationTime() - if (modTime <= lastModTime) { - return false - } - if (modTime > latestModTime) { - latestModTime = modTime - } - return true - } - } - } - - val newFiles = fs.listStatus(directory, newFilter) - logInfo("New files: " + newFiles.map(_.getPath).mkString(", ")) - if (newFiles.length > 0) { - lastModTime = newFilter.latestModTime - } - val newRDD = new UnionRDD(ssc.sc, newFiles.map( - file => ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString))) - Some(newRDD) - } -} - -object FileInputRDS { - val defaultPathFilter = new PathFilter { - def accept(path: Path): Boolean = { - val file = path.getName() - if (file.startsWith(".") || file.endsWith("_tmp")) { - return false - } else { - return true - } - } - } -} - -/* -class NetworkInputRDS[T: ClassManifest]( - val networkInputName: String, - val addresses: Array[InetSocketAddress], - batchDuration: Time, - ssc: SparkStreamContext) -extends InputRDS[T](networkInputName, batchDuration, ssc) { - - - // TODO(Haoyuan): This is for the performance test. - @transient var rdd: RDD[T] = null - - if (System.getProperty("spark.fake", "false") == "true") { - logInfo("Running initial count to cache fake RDD") - rdd = ssc.sc.textFile(SparkContext.inputFile, - SparkContext.idealPartitions).asInstanceOf[RDD[T]] - val fakeCacheLevel = System.getProperty("spark.fake.cache", "") - if (fakeCacheLevel == "MEMORY_ONLY_2") { - rdd.persist(StorageLevel.MEMORY_ONLY_2) - } else if (fakeCacheLevel == "MEMORY_ONLY_DESER_2") { - rdd.persist(StorageLevel.MEMORY_ONLY_2) - } else if (fakeCacheLevel != "") { - logError("Invalid fake cache level: " + fakeCacheLevel) - System.exit(1) - } - rdd.count() - } - - @transient val references = new HashMap[Time,String] - - override def compute(validTime: Time): Option[RDD[T]] = { - if (System.getProperty("spark.fake", "false") == "true") { - logInfo("Returning fake RDD at " + validTime) - return Some(rdd) - } - references.get(validTime) match { - case Some(reference) => - if (reference.startsWith("file") || reference.startsWith("hdfs")) { - logInfo("Reading from file " + reference + " for time " + validTime) - Some(ssc.sc.textFile(reference).asInstanceOf[RDD[T]]) - } else { - logInfo("Getting from BlockManager " + reference + " for time " + validTime) - Some(new BlockRDD(ssc.sc, Array(reference))) - } - case None => - throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") - None - } - } - - def setReference(time: Time, reference: AnyRef) { - references += ((time, reference.toString)) - logInfo("Reference added for time " + time + " - " + reference.toString) - } -} - - -class TestInputRDS( - val testInputName: String, - batchDuration: Time, - ssc: SparkStreamContext) -extends InputRDS[String](testInputName, batchDuration, ssc) { - - @transient val references = new HashMap[Time,Array[String]] - - override def compute(validTime: Time): Option[RDD[String]] = { - references.get(validTime) match { - case Some(reference) => - Some(new BlockRDD[String](ssc.sc, reference)) - case None => - throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") - None - } - } - - def setReference(time: Time, reference: AnyRef) { - references += ((time, reference.asInstanceOf[Array[String]])) - } -} -*/ diff --git a/streaming/src/main/scala/spark/streaming/Interval.scala b/streaming/src/main/scala/spark/streaming/Interval.scala index 1960097216..088cbe4376 100644 --- a/streaming/src/main/scala/spark/streaming/Interval.scala +++ b/streaming/src/main/scala/spark/streaming/Interval.scala @@ -1,7 +1,6 @@ package spark.streaming -case class Interval (val beginTime: Time, val endTime: Time) { - +case class Interval (beginTime: Time, endTime: Time) { def this(beginMs: Long, endMs: Long) = this(Time(beginMs), new Time(endMs)) def duration(): Time = endTime - beginTime @@ -33,7 +32,7 @@ case class Interval (val beginTime: Time, val endTime: Time) { this + (endTime - beginTime) } - def isZero() = (beginTime.isZero && endTime.isZero) + def isZero = (beginTime.isZero && endTime.isZero) def toFormattedString = beginTime.toFormattedString + "-" + endTime.toFormattedString @@ -41,7 +40,6 @@ case class Interval (val beginTime: Time, val endTime: Time) { } object Interval { - def zero() = new Interval (Time.zero, Time.zero) def currentInterval(intervalDuration: Time): Interval = { @@ -49,7 +47,6 @@ object Interval { val intervalBegin = time.floor(intervalDuration) Interval(intervalBegin, intervalBegin + intervalDuration) } - } diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala new file mode 100644 index 0000000000..0cf296f21a --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -0,0 +1,72 @@ +package spark.streaming + +import scala.collection.mutable.ArrayBuffer +import spark.streaming.SparkStreamContext._ + +class PairDStreamFunctions[K: ClassManifest, V: ClassManifest](stream: DStream[(K,V)]) +extends Serializable { + + def ssc = stream.ssc + + /* ---------------------------------- */ + /* DStream operations for key-value pairs */ + /* ---------------------------------- */ + + def groupByKey(numPartitions: Int = 0): ShuffledDStream[K, V, ArrayBuffer[V]] = { + def createCombiner(v: V) = ArrayBuffer[V](v) + def mergeValue(c: ArrayBuffer[V], v: V) = (c += v) + def mergeCombiner(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = (c1 ++ c2) + combineByKey[ArrayBuffer[V]](createCombiner, mergeValue, mergeCombiner, numPartitions) + } + + def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int = 0): ShuffledDStream[K, V, V] = { + val cleanedReduceFunc = ssc.sc.clean(reduceFunc) + combineByKey[V]((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, numPartitions) + } + + private def combineByKey[C: ClassManifest]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiner: (C, C) => C, + numPartitions: Int) : ShuffledDStream[K, V, C] = { + new ShuffledDStream[K, V, C](stream, createCombiner, mergeValue, mergeCombiner, numPartitions) + } + + def groupByKeyAndWindow( + windowTime: Time, + slideTime: Time, + numPartitions: Int = 0): ShuffledDStream[K, V, ArrayBuffer[V]] = { + stream.window(windowTime, slideTime).groupByKey(numPartitions) + } + + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time, + numPartitions: Int = 0): ShuffledDStream[K, V, V] = { + stream.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), numPartitions) + } + + // This method is the efficient sliding window reduce operation, + // which requires the specification of an inverse reduce function, + // so that new elements introduced in the window can be "added" using + // reduceFunc to the previous window's result and old elements can be + // "subtracted using invReduceFunc. + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time, + numPartitions: Int): ReducedWindowedDStream[K, V] = { + + new ReducedWindowedDStream[K, V]( + stream, + ssc.sc.clean(reduceFunc), + ssc.sc.clean(invReduceFunc), + windowTime, + slideTime, + numPartitions) + } +} + + diff --git a/streaming/src/main/scala/spark/streaming/PairRDSFunctions.scala b/streaming/src/main/scala/spark/streaming/PairRDSFunctions.scala deleted file mode 100644 index 403ae233a5..0000000000 --- a/streaming/src/main/scala/spark/streaming/PairRDSFunctions.scala +++ /dev/null @@ -1,72 +0,0 @@ -package spark.streaming - -import scala.collection.mutable.ArrayBuffer -import spark.streaming.SparkStreamContext._ - -class PairRDSFunctions[K: ClassManifest, V: ClassManifest](rds: RDS[(K,V)]) -extends Serializable { - - def ssc = rds.ssc - - /* ---------------------------------- */ - /* RDS operations for key-value pairs */ - /* ---------------------------------- */ - - def groupByKey(numPartitions: Int = 0): ShuffledRDS[K, V, ArrayBuffer[V]] = { - def createCombiner(v: V) = ArrayBuffer[V](v) - def mergeValue(c: ArrayBuffer[V], v: V) = (c += v) - def mergeCombiner(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = (c1 ++ c2) - combineByKey[ArrayBuffer[V]](createCombiner, mergeValue, mergeCombiner, numPartitions) - } - - def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int = 0): ShuffledRDS[K, V, V] = { - val cleanedReduceFunc = ssc.sc.clean(reduceFunc) - combineByKey[V]((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, numPartitions) - } - - private def combineByKey[C: ClassManifest]( - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiner: (C, C) => C, - numPartitions: Int) : ShuffledRDS[K, V, C] = { - new ShuffledRDS[K, V, C](rds, createCombiner, mergeValue, mergeCombiner, numPartitions) - } - - def groupByKeyAndWindow( - windowTime: Time, - slideTime: Time, - numPartitions: Int = 0): ShuffledRDS[K, V, ArrayBuffer[V]] = { - rds.window(windowTime, slideTime).groupByKey(numPartitions) - } - - def reduceByKeyAndWindow( - reduceFunc: (V, V) => V, - windowTime: Time, - slideTime: Time, - numPartitions: Int = 0): ShuffledRDS[K, V, V] = { - rds.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), numPartitions) - } - - // This method is the efficient sliding window reduce operation, - // which requires the specification of an inverse reduce function, - // so that new elements introduced in the window can be "added" using - // reduceFunc to the previous window's result and old elements can be - // "subtracted using invReduceFunc. - def reduceByKeyAndWindow( - reduceFunc: (V, V) => V, - invReduceFunc: (V, V) => V, - windowTime: Time, - slideTime: Time, - numPartitions: Int): ReducedWindowedRDS[K, V] = { - - new ReducedWindowedRDS[K, V]( - rds, - ssc.sc.clean(reduceFunc), - ssc.sc.clean(invReduceFunc), - windowTime, - slideTime, - numPartitions) - } -} - - diff --git a/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala b/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala new file mode 100644 index 0000000000..c78abd1a87 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala @@ -0,0 +1,36 @@ +package spark.streaming + +import spark.RDD +import spark.UnionRDD + +import scala.collection.mutable.Queue +import scala.collection.mutable.ArrayBuffer + +class QueueInputDStream[T: ClassManifest]( + ssc: SparkStreamContext, + val queue: Queue[RDD[T]], + oneAtATime: Boolean, + defaultRDD: RDD[T] + ) extends InputDStream[T](ssc) { + + override def start() { } + + override def stop() { } + + override def compute(validTime: Time): Option[RDD[T]] = { + val buffer = new ArrayBuffer[RDD[T]]() + if (oneAtATime && queue.size > 0) { + buffer += queue.dequeue() + } else { + buffer ++= queue + } + if (buffer.size > 0) { + Some(new UnionRDD(ssc.sc, buffer.toSeq)) + } else if (defaultRDD != null) { + Some(defaultRDD) + } else { + None + } + } + +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/QueueInputRDS.scala b/streaming/src/main/scala/spark/streaming/QueueInputRDS.scala deleted file mode 100644 index 31e6a64e21..0000000000 --- a/streaming/src/main/scala/spark/streaming/QueueInputRDS.scala +++ /dev/null @@ -1,36 +0,0 @@ -package spark.streaming - -import spark.RDD -import spark.UnionRDD - -import scala.collection.mutable.Queue -import scala.collection.mutable.ArrayBuffer - -class QueueInputRDS[T: ClassManifest]( - ssc: SparkStreamContext, - val queue: Queue[RDD[T]], - oneAtATime: Boolean, - defaultRDD: RDD[T] - ) extends InputRDS[T](ssc) { - - override def start() { } - - override def stop() { } - - override def compute(validTime: Time): Option[RDD[T]] = { - val buffer = new ArrayBuffer[RDD[T]]() - if (oneAtATime && queue.size > 0) { - buffer += queue.dequeue() - } else { - buffer ++= queue - } - if (buffer.size > 0) { - Some(new UnionRDD(ssc.sc, buffer.toSeq)) - } else if (defaultRDD != null) { - Some(defaultRDD) - } else { - None - } - } - -} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/RDS.scala b/streaming/src/main/scala/spark/streaming/RDS.scala deleted file mode 100644 index fd923929e7..0000000000 --- a/streaming/src/main/scala/spark/streaming/RDS.scala +++ /dev/null @@ -1,484 +0,0 @@ -package spark.streaming - -import spark.streaming.SparkStreamContext._ - -import spark.RDD -import spark.BlockRDD -import spark.UnionRDD -import spark.Logging -import spark.SparkContext -import spark.SparkContext._ -import spark.storage.StorageLevel - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap - -import java.util.concurrent.ArrayBlockingQueue - -abstract class RDS[T: ClassManifest] (@transient val ssc: SparkStreamContext) -extends Logging with Serializable { - - initLogging() - - /** - * ---------------------------------------------- - * Methods that must be implemented by subclasses - * ---------------------------------------------- - */ - - // Time by which the window slides in this RDS - def slideTime: Time - - // List of parent RDSs on which this RDS depends on - def dependencies: List[RDS[_]] - - // Key method that computes RDD for a valid time - def compute (validTime: Time): Option[RDD[T]] - - /** - * --------------------------------------- - * Other general fields and methods of RDS - * --------------------------------------- - */ - - // Variable to store the RDDs generated earlier in time - @transient private val generatedRDDs = new HashMap[Time, RDD[T]] () - - // Variable to be set to the first time seen by the RDS (effective time zero) - private[streaming] var zeroTime: Time = null - - // Variable to specify storage level - private var storageLevel: StorageLevel = StorageLevel.NONE - - // Checkpoint level and checkpoint interval - private var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint - private var checkpointInterval: Time = null - - // Change this RDD's storage level - def persist( - storageLevel: StorageLevel, - checkpointLevel: StorageLevel, - checkpointInterval: Time): RDS[T] = { - if (this.storageLevel != StorageLevel.NONE && this.storageLevel != storageLevel) { - // TODO: not sure this is necessary for RDSes - throw new UnsupportedOperationException( - "Cannot change storage level of an RDS after it was already assigned a level") - } - this.storageLevel = storageLevel - this.checkpointLevel = checkpointLevel - this.checkpointInterval = checkpointInterval - this - } - - // Set caching level for the RDDs created by this RDS - def persist(newLevel: StorageLevel): RDS[T] = persist(newLevel, StorageLevel.NONE, null) - - def persist(): RDS[T] = persist(StorageLevel.MEMORY_ONLY_DESER) - - // Turn on the default caching level for this RDD - def cache(): RDS[T] = persist() - - def isInitialized = (zeroTime != null) - - /** - * This method initializes the RDS by setting the "zero" time, based on which - * the validity of future times is calculated. This method also recursively initializes - * its parent RDSs. - */ - def initialize(time: Time) { - if (zeroTime == null) { - zeroTime = time - } - logInfo(this + " initialized") - dependencies.foreach(_.initialize(zeroTime)) - } - - /** This method checks whether the 'time' is valid wrt slideTime for generating RDD */ - private def isTimeValid (time: Time): Boolean = { - if (!isInitialized) - throw new Exception (this.toString + " has not been initialized") - if ((time - zeroTime).isMultipleOf(slideTime)) { - true - } else { - false - } - } - - /** - * This method either retrieves a precomputed RDD of this RDS, - * or computes the RDD (if the time is valid) - */ - def getOrCompute(time: Time): Option[RDD[T]] = { - // If this RDS was not initialized (i.e., zeroTime not set), then do it - // If RDD was already generated, then retrieve it from HashMap - generatedRDDs.get(time) match { - - // If an RDD was already generated and is being reused, then - // probably all RDDs in this RDS will be reused and hence should be cached - case Some(oldRDD) => Some(oldRDD) - - // if RDD was not generated, and if the time is valid - // (based on sliding time of this RDS), then generate the RDD - case None => - if (isTimeValid(time)) { - compute(time) match { - case Some(newRDD) => - if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { - newRDD.persist(checkpointLevel) - logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) - } else if (storageLevel != StorageLevel.NONE) { - newRDD.persist(storageLevel) - logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) - } - generatedRDDs.put(time.copy(), newRDD) - Some(newRDD) - case None => - None - } - } else { - None - } - } - } - - /** - * This method generates a SparkStream job for the given time - * and may require to be overriden by subclasses - */ - def generateJob(time: Time): Option[Job] = { - getOrCompute(time) match { - case Some(rdd) => { - val jobFunc = () => { - val emptyFunc = { (iterator: Iterator[T]) => {} } - ssc.sc.runJob(rdd, emptyFunc) - } - Some(new Job(time, jobFunc)) - } - case None => None - } - } - - /** - * -------------- - * RDS operations - * -------------- - */ - - def map[U: ClassManifest](mapFunc: T => U) = new MappedRDS(this, ssc.sc.clean(mapFunc)) - - def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]) = - new FlatMappedRDS(this, ssc.sc.clean(flatMapFunc)) - - def filter(filterFunc: T => Boolean) = new FilteredRDS(this, filterFunc) - - def glom() = new GlommedRDS(this) - - def mapPartitions[U: ClassManifest](mapPartFunc: Iterator[T] => Iterator[U]) = - new MapPartitionedRDS(this, ssc.sc.clean(mapPartFunc)) - - def reduce(reduceFunc: (T, T) => T) = this.map(x => (1, x)).reduceByKey(reduceFunc, 1).map(_._2) - - def count() = this.map(_ => 1).reduce(_ + _) - - def collect() = this.map(x => (1, x)).groupByKey(1).map(_._2) - - def foreach(foreachFunc: T => Unit) = { - val newrds = new PerElementForEachRDS(this, ssc.sc.clean(foreachFunc)) - ssc.registerOutputStream(newrds) - newrds - } - - def foreachRDD(foreachFunc: RDD[T] => Unit) = { - val newrds = new PerRDDForEachRDS(this, ssc.sc.clean(foreachFunc)) - ssc.registerOutputStream(newrds) - newrds - } - - private[streaming] def toQueue() = { - val queue = new ArrayBlockingQueue[RDD[T]](10000) - this.foreachRDD(rdd => { - println("Added RDD " + rdd.id) - queue.add(rdd) - }) - queue - } - - def print() = { - def foreachFunc = (rdd: RDD[T], time: Time) => { - val first11 = rdd.take(11) - println ("-------------------------------------------") - println ("Time: " + time) - println ("-------------------------------------------") - first11.take(10).foreach(println) - if (first11.size > 10) println("...") - println() - } - val newrds = new PerRDDForEachRDS(this, ssc.sc.clean(foreachFunc)) - ssc.registerOutputStream(newrds) - newrds - } - - def window(windowTime: Time, slideTime: Time) = new WindowedRDS(this, windowTime, slideTime) - - def batch(batchTime: Time) = window(batchTime, batchTime) - - def reduceByWindow(reduceFunc: (T, T) => T, windowTime: Time, slideTime: Time) = - this.window(windowTime, slideTime).reduce(reduceFunc) - - def reduceByWindow( - reduceFunc: (T, T) => T, - invReduceFunc: (T, T) => T, - windowTime: Time, - slideTime: Time) = { - this.map(x => (1, x)) - .reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowTime, slideTime, 1) - .map(_._2) - } - - def countByWindow(windowTime: Time, slideTime: Time) = { - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - this.map(_ => 1).reduceByWindow(add _, subtract _, windowTime, slideTime) - } - - def union(that: RDS[T]) = new UnifiedRDS(Array(this, that)) - - def register() = ssc.registerOutputStream(this) -} - - -abstract class InputRDS[T: ClassManifest] ( - ssc: SparkStreamContext) -extends RDS[T](ssc) { - - override def dependencies = List() - - override def slideTime = ssc.batchDuration - - def start() - - def stop() -} - - -/** - * TODO - */ - -class MappedRDS[T: ClassManifest, U: ClassManifest] ( - parent: RDS[T], - mapFunc: T => U) -extends RDS[U](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(_.map[U](mapFunc)) - } -} - - -/** - * TODO - */ - -class FlatMappedRDS[T: ClassManifest, U: ClassManifest]( - parent: RDS[T], - flatMapFunc: T => Traversable[U]) -extends RDS[U](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(_.flatMap(flatMapFunc)) - } -} - - -/** - * TODO - */ - -class FilteredRDS[T: ClassManifest](parent: RDS[T], filterFunc: T => Boolean) -extends RDS[T](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[T]] = { - parent.getOrCompute(validTime).map(_.filter(filterFunc)) - } -} - - -/** - * TODO - */ - -class MapPartitionedRDS[T: ClassManifest, U: ClassManifest]( - parent: RDS[T], - mapPartFunc: Iterator[T] => Iterator[U]) -extends RDS[U](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc)) - } -} - - -/** - * TODO - */ - -class GlommedRDS[T: ClassManifest](parent: RDS[T]) extends RDS[Array[T]](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[Array[T]]] = { - parent.getOrCompute(validTime).map(_.glom()) - } -} - - -/** - * TODO - */ - -class ShuffledRDS[K: ClassManifest, V: ClassManifest, C: ClassManifest]( - parent: RDS[(K,V)], - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiner: (C, C) => C, - numPartitions: Int) - extends RDS [(K,C)] (parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[(K,C)]] = { - parent.getOrCompute(validTime) match { - case Some(rdd) => - val newrdd = { - if (numPartitions > 0) { - rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner, numPartitions) - } else { - rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner) - } - } - Some(newrdd) - case None => None - } - } -} - - -/** - * TODO - */ - -class UnifiedRDS[T: ClassManifest](parents: Array[RDS[T]]) -extends RDS[T](parents(0).ssc) { - - if (parents.length == 0) { - throw new IllegalArgumentException("Empty array of parents") - } - - if (parents.map(_.ssc).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different SparkStreamContexts") - } - - if (parents.map(_.slideTime).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different slide times") - } - - override def dependencies = parents.toList - - override def slideTime: Time = parents(0).slideTime - - override def compute(validTime: Time): Option[RDD[T]] = { - val rdds = new ArrayBuffer[RDD[T]]() - parents.map(_.getOrCompute(validTime)).foreach(_ match { - case Some(rdd) => rdds += rdd - case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime) - }) - if (rdds.size > 0) { - Some(new UnionRDD(ssc.sc, rdds)) - } else { - None - } - } -} - - -/** - * TODO - */ - -class PerElementForEachRDS[T: ClassManifest] ( - parent: RDS[T], - foreachFunc: T => Unit) -extends RDS[Unit](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[Unit]] = None - - override def generateJob(time: Time): Option[Job] = { - parent.getOrCompute(time) match { - case Some(rdd) => - val jobFunc = () => { - val sparkJobFunc = { - (iterator: Iterator[T]) => iterator.foreach(foreachFunc) - } - ssc.sc.runJob(rdd, sparkJobFunc) - } - Some(new Job(time, jobFunc)) - case None => None - } - } -} - - -/** - * TODO - */ - -class PerRDDForEachRDS[T: ClassManifest] ( - parent: RDS[T], - foreachFunc: (RDD[T], Time) => Unit) -extends RDS[Unit](parent.ssc) { - - def this(parent: RDS[T], altForeachFunc: (RDD[T]) => Unit) = - this(parent, (rdd: RDD[T], time: Time) => altForeachFunc(rdd)) - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[Unit]] = None - - override def generateJob(time: Time): Option[Job] = { - parent.getOrCompute(time) match { - case Some(rdd) => - val jobFunc = () => { - foreachFunc(rdd, time) - } - Some(new Job(time, jobFunc)) - case None => None - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala new file mode 100644 index 0000000000..11fa4e5443 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -0,0 +1,218 @@ +package spark.streaming + +import spark.streaming.SparkStreamContext._ + +import spark.RDD +import spark.UnionRDD +import spark.CoGroupedRDD +import spark.HashPartitioner +import spark.SparkContext._ +import spark.storage.StorageLevel + +import scala.collection.mutable.ArrayBuffer + +class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( + parent: DStream[(K, V)], + reduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + _windowTime: Time, + _slideTime: Time, + numPartitions: Int) +extends DStream[(K,V)](parent.ssc) { + + if (!_windowTime.isMultipleOf(parent.slideTime)) + throw new Exception("The window duration of ReducedWindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + + if (!_slideTime.isMultipleOf(parent.slideTime)) + throw new Exception("The slide duration of ReducedWindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + + val reducedStream = parent.reduceByKey(reduceFunc, numPartitions) + val allowPartialWindows = true + //reducedStream.persist(StorageLevel.MEMORY_ONLY_DESER_2) + + override def dependencies = List(reducedStream) + + def windowTime: Time = _windowTime + + override def slideTime: Time = _slideTime + + override def persist( + storageLevel: StorageLevel, + checkpointLevel: StorageLevel, + checkpointInterval: Time): DStream[(K,V)] = { + super.persist(storageLevel, checkpointLevel, checkpointInterval) + reducedStream.persist(storageLevel, checkpointLevel, checkpointInterval) + } + + override def compute(validTime: Time): Option[RDD[(K, V)]] = { + + + // Notation: + // _____________________________ + // | previous window _________|___________________ + // |___________________| current window | --------------> Time + // |_____________________________| + // + // |________ _________| |________ _________| + // | | + // V V + // old time steps new time steps + // + def getAdjustedWindow(endTime: Time, windowTime: Time): Interval = { + val beginTime = + if (allowPartialWindows && endTime - windowTime < parent.zeroTime) { + parent.zeroTime + } else { + endTime - windowTime + } + Interval(beginTime, endTime) + } + + val currentTime = validTime.copy + val currentWindow = getAdjustedWindow(currentTime, windowTime) + val previousWindow = getAdjustedWindow(currentTime - slideTime, windowTime) + + logInfo("Current window = " + currentWindow) + logInfo("Previous window = " + previousWindow) + logInfo("Parent.zeroTime = " + parent.zeroTime) + + if (allowPartialWindows) { + if (currentTime - slideTime == parent.zeroTime) { + reducedStream.getOrCompute(currentTime) match { + case Some(rdd) => return Some(rdd) + case None => throw new Exception("Could not get first reduced RDD for time " + currentTime) + } + } + } else { + if (previousWindow.beginTime < parent.zeroTime) { + if (currentWindow.beginTime < parent.zeroTime) { + return None + } else { + // If this is the first feasible window, then generate reduced value in the naive manner + val reducedRDDs = new ArrayBuffer[RDD[(K, V)]]() + var t = currentWindow.endTime + while (t > currentWindow.beginTime) { + reducedStream.getOrCompute(t) match { + case Some(rdd) => reducedRDDs += rdd + case None => throw new Exception("Could not get reduced RDD for time " + t) + } + t -= reducedStream.slideTime + } + if (reducedRDDs.size == 0) { + throw new Exception("Could not generate the first RDD for time " + validTime) + } + return Some(new UnionRDD(ssc.sc, reducedRDDs).reduceByKey(reduceFunc, numPartitions)) + } + } + } + + // Get the RDD of the reduced value of the previous window + val previousWindowRDD = getOrCompute(previousWindow.endTime) match { + case Some(rdd) => rdd.asInstanceOf[RDD[(_, _)]] + case None => throw new Exception("Could not get previous RDD for time " + previousWindow.endTime) + } + + val oldRDDs = new ArrayBuffer[RDD[(_, _)]]() + val newRDDs = new ArrayBuffer[RDD[(_, _)]]() + + // Get the RDDs of the reduced values in "old time steps" + var t = currentWindow.beginTime + while (t > previousWindow.beginTime) { + reducedStream.getOrCompute(t) match { + case Some(rdd) => oldRDDs += rdd.asInstanceOf[RDD[(_, _)]] + case None => throw new Exception("Could not get old reduced RDD for time " + t) + } + t -= reducedStream.slideTime + } + + // Get the RDDs of the reduced values in "new time steps" + t = currentWindow.endTime + while (t > previousWindow.endTime) { + reducedStream.getOrCompute(t) match { + case Some(rdd) => newRDDs += rdd.asInstanceOf[RDD[(_, _)]] + case None => throw new Exception("Could not get new reduced RDD for time " + t) + } + t -= reducedStream.slideTime + } + + val partitioner = new HashPartitioner(numPartitions) + val allRDDs = new ArrayBuffer[RDD[(_, _)]]() + allRDDs += previousWindowRDD + allRDDs ++= oldRDDs + allRDDs ++= newRDDs + + + val numOldRDDs = oldRDDs.size + val numNewRDDs = newRDDs.size + logInfo("Generated numOldRDDs = " + numOldRDDs + ", numNewRDDs = " + numNewRDDs) + logInfo("Generating CoGroupedRDD with " + allRDDs.size + " RDDs") + val newRDD = new CoGroupedRDD[K](allRDDs.toSeq, partitioner).asInstanceOf[RDD[(K,Seq[Seq[V]])]].map(x => { + val (key, value) = x + logDebug("value.size = " + value.size + ", numOldRDDs = " + numOldRDDs + ", numNewRDDs = " + numNewRDDs) + if (value.size != 1 + numOldRDDs + numNewRDDs) { + throw new Exception("Number of groups not odd!") + } + + // old values = reduced values of the "old time steps" that are eliminated from current window + // new values = reduced values of the "new time steps" that are introduced to the current window + // previous value = reduced value of the previous window + + /*val numOldValues = (value.size - 1) / 2*/ + // Getting reduced values "old time steps" + val oldValues = + (0 until numOldRDDs).map(i => value(1 + i)).filter(_.size > 0).map(x => x(0)) + // Getting reduced values "new time steps" + val newValues = + (0 until numNewRDDs).map(i => value(1 + numOldRDDs + i)).filter(_.size > 0).map(x => x(0)) + + // If reduced value for the key does not exist in previous window, it should not exist in "old time steps" + if (value(0).size == 0 && oldValues.size != 0) { + throw new Exception("Unexpected: Key exists in old reduced values but not in previous reduced values") + } + + // For the key, at least one of "old time steps", "new time steps" and previous window should have reduced values + if (value(0).size == 0 && oldValues.size == 0 && newValues.size == 0) { + throw new Exception("Unexpected: Key does not exist in any of old, new, or previour reduced values") + } + + // Logic to generate the final reduced value for current window: + // + // If previous window did not have reduced value for the key + // Then, return reduced value of "new time steps" as the final value + // Else, reduced value exists in previous window + // If "old" time steps did not have reduced value for the key + // Then, reduce previous window's reduced value with that of "new time steps" for final value + // Else, reduced values exists in "old time steps" + // If "new values" did not have reduced value for the key + // Then, inverse-reduce "old values" from previous window's reduced value for final value + // Else, all 3 values exist, combine all of them together + // + logDebug("# old values = " + oldValues.size + ", # new values = " + newValues) + val finalValue = { + if (value(0).size == 0) { + newValues.reduce(reduceFunc) + } else { + val prevValue = value(0)(0) + logDebug("prev value = " + prevValue) + if (oldValues.size == 0) { + // assuming newValue.size > 0 (all 3 cannot be zero, as checked earlier) + val temp = newValues.reduce(reduceFunc) + reduceFunc(prevValue, temp) + } else if (newValues.size == 0) { + invReduceFunc(prevValue, oldValues.reduce(reduceFunc)) + } else { + val tempValue = invReduceFunc(prevValue, oldValues.reduce(reduceFunc)) + reduceFunc(tempValue, newValues.reduce(reduceFunc)) + } + } + } + (key, finalValue) + }) + //newRDD.persist(StorageLevel.MEMORY_ONLY_DESER_2) + Some(newRDD) + } +} + + diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedRDS.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedRDS.scala deleted file mode 100644 index dd1f474657..0000000000 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedRDS.scala +++ /dev/null @@ -1,218 +0,0 @@ -package spark.streaming - -import spark.streaming.SparkStreamContext._ - -import spark.RDD -import spark.UnionRDD -import spark.CoGroupedRDD -import spark.HashPartitioner -import spark.SparkContext._ -import spark.storage.StorageLevel - -import scala.collection.mutable.ArrayBuffer - -class ReducedWindowedRDS[K: ClassManifest, V: ClassManifest]( - parent: RDS[(K, V)], - reduceFunc: (V, V) => V, - invReduceFunc: (V, V) => V, - _windowTime: Time, - _slideTime: Time, - numPartitions: Int) -extends RDS[(K,V)](parent.ssc) { - - if (!_windowTime.isMultipleOf(parent.slideTime)) - throw new Exception("The window duration of ReducedWindowedRDS (" + _slideTime + ") " + - "must be multiple of the slide duration of parent RDS (" + parent.slideTime + ")") - - if (!_slideTime.isMultipleOf(parent.slideTime)) - throw new Exception("The slide duration of ReducedWindowedRDS (" + _slideTime + ") " + - "must be multiple of the slide duration of parent RDS (" + parent.slideTime + ")") - - val reducedRDS = parent.reduceByKey(reduceFunc, numPartitions) - val allowPartialWindows = true - //reducedRDS.persist(StorageLevel.MEMORY_ONLY_DESER_2) - - override def dependencies = List(reducedRDS) - - def windowTime: Time = _windowTime - - override def slideTime: Time = _slideTime - - override def persist( - storageLevel: StorageLevel, - checkpointLevel: StorageLevel, - checkpointInterval: Time): RDS[(K,V)] = { - super.persist(storageLevel, checkpointLevel, checkpointInterval) - reducedRDS.persist(storageLevel, checkpointLevel, checkpointInterval) - } - - override def compute(validTime: Time): Option[RDD[(K, V)]] = { - - - // Notation: - // _____________________________ - // | previous window _________|___________________ - // |___________________| current window | --------------> Time - // |_____________________________| - // - // |________ _________| |________ _________| - // | | - // V V - // old time steps new time steps - // - def getAdjustedWindow(endTime: Time, windowTime: Time): Interval = { - val beginTime = - if (allowPartialWindows && endTime - windowTime < parent.zeroTime) { - parent.zeroTime - } else { - endTime - windowTime - } - Interval(beginTime, endTime) - } - - val currentTime = validTime.copy - val currentWindow = getAdjustedWindow(currentTime, windowTime) - val previousWindow = getAdjustedWindow(currentTime - slideTime, windowTime) - - logInfo("Current window = " + currentWindow) - logInfo("Previous window = " + previousWindow) - logInfo("Parent.zeroTime = " + parent.zeroTime) - - if (allowPartialWindows) { - if (currentTime - slideTime == parent.zeroTime) { - reducedRDS.getOrCompute(currentTime) match { - case Some(rdd) => return Some(rdd) - case None => throw new Exception("Could not get first reduced RDD for time " + currentTime) - } - } - } else { - if (previousWindow.beginTime < parent.zeroTime) { - if (currentWindow.beginTime < parent.zeroTime) { - return None - } else { - // If this is the first feasible window, then generate reduced value in the naive manner - val reducedRDDs = new ArrayBuffer[RDD[(K, V)]]() - var t = currentWindow.endTime - while (t > currentWindow.beginTime) { - reducedRDS.getOrCompute(t) match { - case Some(rdd) => reducedRDDs += rdd - case None => throw new Exception("Could not get reduced RDD for time " + t) - } - t -= reducedRDS.slideTime - } - if (reducedRDDs.size == 0) { - throw new Exception("Could not generate the first RDD for time " + validTime) - } - return Some(new UnionRDD(ssc.sc, reducedRDDs).reduceByKey(reduceFunc, numPartitions)) - } - } - } - - // Get the RDD of the reduced value of the previous window - val previousWindowRDD = getOrCompute(previousWindow.endTime) match { - case Some(rdd) => rdd.asInstanceOf[RDD[(_, _)]] - case None => throw new Exception("Could not get previous RDD for time " + previousWindow.endTime) - } - - val oldRDDs = new ArrayBuffer[RDD[(_, _)]]() - val newRDDs = new ArrayBuffer[RDD[(_, _)]]() - - // Get the RDDs of the reduced values in "old time steps" - var t = currentWindow.beginTime - while (t > previousWindow.beginTime) { - reducedRDS.getOrCompute(t) match { - case Some(rdd) => oldRDDs += rdd.asInstanceOf[RDD[(_, _)]] - case None => throw new Exception("Could not get old reduced RDD for time " + t) - } - t -= reducedRDS.slideTime - } - - // Get the RDDs of the reduced values in "new time steps" - t = currentWindow.endTime - while (t > previousWindow.endTime) { - reducedRDS.getOrCompute(t) match { - case Some(rdd) => newRDDs += rdd.asInstanceOf[RDD[(_, _)]] - case None => throw new Exception("Could not get new reduced RDD for time " + t) - } - t -= reducedRDS.slideTime - } - - val partitioner = new HashPartitioner(numPartitions) - val allRDDs = new ArrayBuffer[RDD[(_, _)]]() - allRDDs += previousWindowRDD - allRDDs ++= oldRDDs - allRDDs ++= newRDDs - - - val numOldRDDs = oldRDDs.size - val numNewRDDs = newRDDs.size - logInfo("Generated numOldRDDs = " + numOldRDDs + ", numNewRDDs = " + numNewRDDs) - logInfo("Generating CoGroupedRDD with " + allRDDs.size + " RDDs") - val newRDD = new CoGroupedRDD[K](allRDDs.toSeq, partitioner).asInstanceOf[RDD[(K,Seq[Seq[V]])]].map(x => { - val (key, value) = x - logDebug("value.size = " + value.size + ", numOldRDDs = " + numOldRDDs + ", numNewRDDs = " + numNewRDDs) - if (value.size != 1 + numOldRDDs + numNewRDDs) { - throw new Exception("Number of groups not odd!") - } - - // old values = reduced values of the "old time steps" that are eliminated from current window - // new values = reduced values of the "new time steps" that are introduced to the current window - // previous value = reduced value of the previous window - - /*val numOldValues = (value.size - 1) / 2*/ - // Getting reduced values "old time steps" - val oldValues = - (0 until numOldRDDs).map(i => value(1 + i)).filter(_.size > 0).map(x => x(0)) - // Getting reduced values "new time steps" - val newValues = - (0 until numNewRDDs).map(i => value(1 + numOldRDDs + i)).filter(_.size > 0).map(x => x(0)) - - // If reduced value for the key does not exist in previous window, it should not exist in "old time steps" - if (value(0).size == 0 && oldValues.size != 0) { - throw new Exception("Unexpected: Key exists in old reduced values but not in previous reduced values") - } - - // For the key, at least one of "old time steps", "new time steps" and previous window should have reduced values - if (value(0).size == 0 && oldValues.size == 0 && newValues.size == 0) { - throw new Exception("Unexpected: Key does not exist in any of old, new, or previour reduced values") - } - - // Logic to generate the final reduced value for current window: - // - // If previous window did not have reduced value for the key - // Then, return reduced value of "new time steps" as the final value - // Else, reduced value exists in previous window - // If "old" time steps did not have reduced value for the key - // Then, reduce previous window's reduced value with that of "new time steps" for final value - // Else, reduced values exists in "old time steps" - // If "new values" did not have reduced value for the key - // Then, inverse-reduce "old values" from previous window's reduced value for final value - // Else, all 3 values exist, combine all of them together - // - logDebug("# old values = " + oldValues.size + ", # new values = " + newValues) - val finalValue = { - if (value(0).size == 0) { - newValues.reduce(reduceFunc) - } else { - val prevValue = value(0)(0) - logDebug("prev value = " + prevValue) - if (oldValues.size == 0) { - // assuming newValue.size > 0 (all 3 cannot be zero, as checked earlier) - val temp = newValues.reduce(reduceFunc) - reduceFunc(prevValue, temp) - } else if (newValues.size == 0) { - invReduceFunc(prevValue, oldValues.reduce(reduceFunc)) - } else { - val tempValue = invReduceFunc(prevValue, oldValues.reduce(reduceFunc)) - reduceFunc(tempValue, newValues.reduce(reduceFunc)) - } - } - } - (key, finalValue) - }) - //newRDD.persist(StorageLevel.MEMORY_ONLY_DESER_2) - Some(newRDD) - } -} - - diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index 83f874e550..fff4924b4c 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -9,12 +9,11 @@ import scala.collection.mutable.HashMap sealed trait SchedulerMessage case class InputGenerated(inputName: String, interval: Interval, reference: AnyRef = null) extends SchedulerMessage -case class Test extends SchedulerMessage class Scheduler( ssc: SparkStreamContext, - inputRDSs: Array[InputRDS[_]], - outputRDSs: Array[RDS[_]]) + inputStreams: Array[InputDStream[_]], + outputStreams: Array[DStream[_]]) extends Logging { initLogging() @@ -26,21 +25,21 @@ extends Logging { def start() { val zeroTime = Time(timer.start()) - outputRDSs.foreach(_.initialize(zeroTime)) - inputRDSs.par.foreach(_.start()) + outputStreams.foreach(_.initialize(zeroTime)) + inputStreams.par.foreach(_.start()) logInfo("Scheduler started") } def stop() { timer.stop() - inputRDSs.par.foreach(_.stop()) + inputStreams.par.foreach(_.stop()) logInfo("Scheduler stopped") } def generateRDDs (time: Time) { logInfo("Generating RDDs for time " + time) - outputRDSs.foreach(outputRDS => { - outputRDS.generateJob(time) match { + outputStreams.foreach(outputStream => { + outputStream.generateJob(time) match { case Some(job) => submitJob(job) case None => } diff --git a/streaming/src/main/scala/spark/streaming/SparkStreamContext.scala b/streaming/src/main/scala/spark/streaming/SparkStreamContext.scala index d32f6d588c..2bec1091c0 100644 --- a/streaming/src/main/scala/spark/streaming/SparkStreamContext.scala +++ b/streaming/src/main/scala/spark/streaming/SparkStreamContext.scala @@ -31,8 +31,8 @@ class SparkStreamContext ( val sc = new SparkContext(master, frameworkName, sparkHome, jars) val env = SparkEnv.get - val inputRDSs = new ArrayBuffer[InputRDS[_]]() - val outputRDSs = new ArrayBuffer[RDS[_]]() + val inputStreams = new ArrayBuffer[InputDStream[_]]() + val outputStreams = new ArrayBuffer[DStream[_]]() var batchDuration: Time = null var scheduler: Scheduler = null @@ -48,17 +48,17 @@ class SparkStreamContext ( def createNetworkStream[T: ClassManifest]( name: String, addresses: Array[InetSocketAddress], - batchDuration: Time): RDS[T] = { + batchDuration: Time): DStream[T] = { - val inputRDS = new NetworkInputRDS[T](this, addresses) - inputRDSs += inputRDS - inputRDS + val inputStream = new NetworkinputStream[T](this, addresses) + inputStreams += inputStream + inputStream } def createNetworkStream[T: ClassManifest]( name: String, addresses: Array[String], - batchDuration: Long): RDS[T] = { + batchDuration: Long): DStream[T] = { def stringToInetSocketAddress (str: String): InetSocketAddress = { val parts = str.split(":") @@ -83,13 +83,13 @@ class SparkStreamContext ( K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K, V]: ClassManifest - ](directory: String): RDS[(K, V)] = { - val inputRDS = new FileInputRDS[K, V, F](this, new Path(directory)) - inputRDSs += inputRDS - inputRDS + ](directory: String): DStream[(K, V)] = { + val inputStream = new FileInputDStream[K, V, F](this, new Path(directory)) + inputStreams += inputStream + inputStream } - def createTextFileStream(directory: String): RDS[String] = { + def createTextFileStream(directory: String): DStream[String] = { createFileStream[LongWritable, Text, TextInputFormat](directory).map(_._2.toString) } @@ -101,26 +101,26 @@ class SparkStreamContext ( queue: Queue[RDD[T]], oneAtATime: Boolean = true, defaultRDD: RDD[T] = null - ): RDS[T] = { - val inputRDS = new QueueInputRDS(this, queue, oneAtATime, defaultRDD) - inputRDSs += inputRDS - inputRDS + ): DStream[T] = { + val inputStream = new QueueInputDStream(this, queue, oneAtATime, defaultRDD) + inputStreams += inputStream + inputStream } - def createQueueStream[T: ClassManifest](iterator: Iterator[RDD[T]]): RDS[T] = { + def createQueueStream[T: ClassManifest](iterator: Iterator[RDD[T]]): DStream[T] = { val queue = new Queue[RDD[T]] - val inputRDS = createQueueStream(queue, true, null) + val inputStream = createQueueStream(queue, true, null) queue ++= iterator - inputRDS + inputStream } /** - * This function registers a RDS as an output stream that will be + * This function registers a DStream as an output stream that will be * computed every interval. */ - def registerOutputStream (outputRDS: RDS[_]) { - outputRDSs += outputRDS + def registerOutputStream (outputStream: DStream[_]) { + outputStreams += outputStream } /** @@ -133,11 +133,11 @@ class SparkStreamContext ( if (batchDuration < Milliseconds(100)) { logWarning("Batch duration of " + batchDuration + " is very low") } - if (inputRDSs.size == 0) { - throw new Exception("No input RDSes created, so nothing to take input from") + if (inputStreams.size == 0) { + throw new Exception("No input streams created, so nothing to take input from") } - if (outputRDSs.size == 0) { - throw new Exception("No output RDSes registered, so nothing to execute") + if (outputStreams.size == 0) { + throw new Exception("No output streams registered, so nothing to execute") } } @@ -147,7 +147,7 @@ class SparkStreamContext ( */ def start() { verify() - scheduler = new Scheduler(this, inputRDSs.toArray, outputRDSs.toArray) + scheduler = new Scheduler(this, inputStreams.toArray, outputStreams.toArray) scheduler.start() } @@ -168,6 +168,6 @@ class SparkStreamContext ( object SparkStreamContext { - implicit def rdsToPairRdsFunctions [K: ClassManifest, V: ClassManifest] (rds: RDS[(K,V)]) = - new PairRDSFunctions (rds) + implicit def toPairDStreamFunctions[K: ClassManifest, V: ClassManifest](stream: DStream[(K,V)]) = + new PairDStreamFunctions(stream) } diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index c4573137ae..5c476f02c3 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -50,11 +50,11 @@ class Time(private var millis: Long) { def isZero = (this.millis == 0) - override def toString() = (millis.toString + " ms") + override def toString = (millis.toString + " ms") - def toFormattedString() = millis.toString + def toFormattedString = millis.toString - def milliseconds() = millis + def milliseconds = millis } object Time { diff --git a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala new file mode 100644 index 0000000000..9a6617a1ee --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala @@ -0,0 +1,68 @@ +package spark.streaming + +import spark.streaming.SparkStreamContext._ + +import spark.RDD +import spark.UnionRDD +import spark.SparkContext._ + +import scala.collection.mutable.ArrayBuffer + +class WindowedDStream[T: ClassManifest]( + parent: DStream[T], + _windowTime: Time, + _slideTime: Time) + extends DStream[T](parent.ssc) { + + if (!_windowTime.isMultipleOf(parent.slideTime)) + throw new Exception("The window duration of WindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + + if (!_slideTime.isMultipleOf(parent.slideTime)) + throw new Exception("The slide duration of WindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + + val allowPartialWindows = true + + override def dependencies = List(parent) + + def windowTime: Time = _windowTime + + override def slideTime: Time = _slideTime + + override def compute(validTime: Time): Option[RDD[T]] = { + val parentRDDs = new ArrayBuffer[RDD[T]]() + val windowEndTime = validTime.copy() + val windowStartTime = if (allowPartialWindows && windowEndTime - windowTime < parent.zeroTime) { + parent.zeroTime + } else { + windowEndTime - windowTime + } + + logInfo("Window = " + windowStartTime + " - " + windowEndTime) + logInfo("Parent.zeroTime = " + parent.zeroTime) + + if (windowStartTime >= parent.zeroTime) { + // Walk back through time, from the 'windowEndTime' to 'windowStartTime' + // and get all parent RDDs from the parent DStream + var t = windowEndTime + while (t > windowStartTime) { + parent.getOrCompute(t) match { + case Some(rdd) => parentRDDs += rdd + case None => throw new Exception("Could not generate parent RDD for time " + t) + } + t -= parent.slideTime + } + } + + // Do a union of all parent RDDs to generate the new RDD + if (parentRDDs.size > 0) { + Some(new UnionRDD(ssc.sc, parentRDDs)) + } else { + None + } + } +} + + + diff --git a/streaming/src/main/scala/spark/streaming/WindowedRDS.scala b/streaming/src/main/scala/spark/streaming/WindowedRDS.scala deleted file mode 100644 index 812a982301..0000000000 --- a/streaming/src/main/scala/spark/streaming/WindowedRDS.scala +++ /dev/null @@ -1,68 +0,0 @@ -package spark.streaming - -import spark.streaming.SparkStreamContext._ - -import spark.RDD -import spark.UnionRDD -import spark.SparkContext._ - -import scala.collection.mutable.ArrayBuffer - -class WindowedRDS[T: ClassManifest]( - parent: RDS[T], - _windowTime: Time, - _slideTime: Time) - extends RDS[T](parent.ssc) { - - if (!_windowTime.isMultipleOf(parent.slideTime)) - throw new Exception("The window duration of WindowedRDS (" + _slideTime + ") " + - "must be multiple of the slide duration of parent RDS (" + parent.slideTime + ")") - - if (!_slideTime.isMultipleOf(parent.slideTime)) - throw new Exception("The slide duration of WindowedRDS (" + _slideTime + ") " + - "must be multiple of the slide duration of parent RDS (" + parent.slideTime + ")") - - val allowPartialWindows = true - - override def dependencies = List(parent) - - def windowTime: Time = _windowTime - - override def slideTime: Time = _slideTime - - override def compute(validTime: Time): Option[RDD[T]] = { - val parentRDDs = new ArrayBuffer[RDD[T]]() - val windowEndTime = validTime.copy() - val windowStartTime = if (allowPartialWindows && windowEndTime - windowTime < parent.zeroTime) { - parent.zeroTime - } else { - windowEndTime - windowTime - } - - logInfo("Window = " + windowStartTime + " - " + windowEndTime) - logInfo("Parent.zeroTime = " + parent.zeroTime) - - if (windowStartTime >= parent.zeroTime) { - // Walk back through time, from the 'windowEndTime' to 'windowStartTime' - // and get all parent RDDs from the parent RDS - var t = windowEndTime - while (t > windowStartTime) { - parent.getOrCompute(t) match { - case Some(rdd) => parentRDDs += rdd - case None => throw new Exception("Could not generate parent RDD for time " + t) - } - t -= parent.slideTime - } - } - - // Do a union of all parent RDDs to generate the new RDD - if (parentRDDs.size > 0) { - Some(new UnionRDD(ssc.sc, parentRDDs)) - } else { - None - } - } -} - - - diff --git a/streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala b/streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala index d56fdcdf29..669f575240 100644 --- a/streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala +++ b/streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala @@ -20,10 +20,10 @@ object ExampleOne { ssc.setBatchDuration(Seconds(1)) // Create the queue through which RDDs can be pushed to - // a QueueInputRDS + // a QueueInputDStream val rddQueue = new SynchronizedQueue[RDD[Int]]() - // Create the QueueInputRDs and use it do some processing + // Create the QueueInputDStream and use it do some processing val inputStream = ssc.createQueueStream(rddQueue) val mappedStream = inputStream.map(x => (x % 10, 1)) val reducedStream = mappedStream.reduceByKey(_ + _) diff --git a/streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala b/streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala index 4b8f6d609d..be47e47a5a 100644 --- a/streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala +++ b/streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala @@ -24,12 +24,12 @@ object ExampleTwo { if (fs.exists(directory)) throw new Exception("This directory already exists") fs.mkdirs(directory) - // Create the FileInputRDS on the directory and use the + // Create the FileInputDStream on the directory and use the // stream to count words in new files created - val inputRDS = ssc.createTextFileStream(directory.toString) - val wordsRDS = inputRDS.flatMap(_.split(" ")) - val wordCountsRDS = wordsRDS.map(x => (x, 1)).reduceByKey(_ + _) - wordCountsRDS.print + val inputStream = ssc.createTextFileStream(directory.toString) + val words = inputStream.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() ssc.start() // Creating new files in the directory diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount.scala index a155630151..ba7bc63d6a 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount.scala @@ -14,12 +14,12 @@ object WordCount { val ssc = new SparkStreamContext(args(0), "ExampleTwo") ssc.setBatchDuration(Seconds(2)) - // Create the FileInputRDS on the directory and use the + // Create the FileInputDStream on the directory and use the // stream to count words in new files created - val inputRDS = ssc.createTextFileStream(args(1)) - val wordsRDS = inputRDS.flatMap(_.split(" ")) - val wordCountsRDS = wordsRDS.map(x => (x, 1)).reduceByKey(_ + _) - wordCountsRDS.print() + val lines = ssc.createTextFileStream(args(1)) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() ssc.start() } } diff --git a/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala b/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala index 9925b1d07c..9fb1924798 100644 --- a/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala +++ b/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala @@ -23,7 +23,7 @@ object Receiver { count += 28 } } catch { - case e: Exception => e.printStackTrace + case e: Exception => e.printStackTrace() } val timeTaken = System.currentTimeMillis - time val tput = (count / 1024.0) / (timeTaken / 1000.0) diff --git a/streaming/src/test/scala/spark/streaming/DStreamSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamSuite.scala new file mode 100644 index 0000000000..ce7c3d2e2b --- /dev/null +++ b/streaming/src/test/scala/spark/streaming/DStreamSuite.scala @@ -0,0 +1,65 @@ +package spark.streaming + +import spark.{Logging, RDD} + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.SynchronizedQueue + +class DStreamSuite extends FunSuite with BeforeAndAfter with Logging { + + var ssc: SparkStreamContext = null + val batchDurationMillis = 1000 + + def testOp[U: ClassManifest, V: ClassManifest]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]]) { + try { + ssc = new SparkStreamContext("local", "test") + ssc.setBatchDuration(Milliseconds(batchDurationMillis)) + + val inputStream = ssc.createQueueStream(input.map(ssc.sc.makeRDD(_, 2)).toIterator) + val outputStream = operation(inputStream) + val outputQueue = outputStream.toQueue + + ssc.start() + Thread.sleep(batchDurationMillis * input.size) + + val output = new ArrayBuffer[Seq[V]]() + while(outputQueue.size > 0) { + val rdd = outputQueue.take() + logInfo("Collecting RDD " + rdd.id + ", " + rdd.getClass.getSimpleName + ", " + rdd.splits.size) + output += (rdd.collect()) + } + assert(output.size === expectedOutput.size) + for (i <- 0 until output.size) { + assert(output(i).toList === expectedOutput(i).toList) + } + } finally { + ssc.stop() + } + } + + test("basic operations") { + val inputData = Array(1 to 4, 5 to 8, 9 to 12) + + // map + testOp(inputData, (r: DStream[Int]) => r.map(_.toString), inputData.map(_.map(_.toString))) + + // flatMap + testOp(inputData, (r: DStream[Int]) => r.flatMap(x => Array(x, x * 2)), + inputData.map(_.flatMap(x => Array(x, x * 2))) + ) + } +} + +object DStreamSuite { + def main(args: Array[String]) { + val r = new DStreamSuite() + val inputData = Array(1 to 4, 5 to 8, 9 to 12) + r.testOp(inputData, (r: DStream[Int]) => r.map(_.toString), inputData.map(_.map(_.toString))) + } +} \ No newline at end of file diff --git a/streaming/src/test/scala/spark/streaming/RDSSuite.scala b/streaming/src/test/scala/spark/streaming/RDSSuite.scala deleted file mode 100644 index f51ea50a5d..0000000000 --- a/streaming/src/test/scala/spark/streaming/RDSSuite.scala +++ /dev/null @@ -1,65 +0,0 @@ -package spark.streaming - -import spark.RDD - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.SynchronizedQueue - -class RDSSuite extends FunSuite with BeforeAndAfter { - - var ssc: SparkStreamContext = null - val batchDurationMillis = 1000 - - def testOp[U: ClassManifest, V: ClassManifest]( - input: Seq[Seq[U]], - operation: RDS[U] => RDS[V], - expectedOutput: Seq[Seq[V]]) = { - try { - ssc = new SparkStreamContext("local", "test") - ssc.setBatchDuration(Milliseconds(batchDurationMillis)) - - val inputStream = ssc.createQueueStream(input.map(ssc.sc.makeRDD(_, 2)).toIterator) - val outputStream = operation(inputStream) - val outputQueue = outputStream.toQueue - - ssc.start() - Thread.sleep(batchDurationMillis * input.size) - - val output = new ArrayBuffer[Seq[V]]() - while(outputQueue.size > 0) { - val rdd = outputQueue.take() - println("Collecting RDD " + rdd.id + ", " + rdd.getClass().getSimpleName() + ", " + rdd.splits.size) - output += (rdd.collect()) - } - assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { - assert(output(i).toList === expectedOutput(i).toList) - } - } finally { - ssc.stop() - } - } - - test("basic operations") { - val inputData = Array(1 to 4, 5 to 8, 9 to 12) - - // map - testOp(inputData, (r: RDS[Int]) => r.map(_.toString), inputData.map(_.map(_.toString))) - - // flatMap - testOp(inputData, (r: RDS[Int]) => r.flatMap(x => Array(x, x * 2)), - inputData.map(_.flatMap(x => Array(x, x * 2))) - ) - } -} - -object RDSSuite { - def main(args: Array[String]) { - val r = new RDSSuite() - val inputData = Array(1 to 4, 5 to 8, 9 to 12) - r.testOp(inputData, (r: RDS[Int]) => r.map(_.toString), inputData.map(_.map(_.toString))) - } -} \ No newline at end of file -- cgit v1.2.3 From cae894ee7aefa4cf9b1952038a48be81e1d2a856 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 6 Aug 2012 14:52:46 -0700 Subject: Added new Clock interface that is used by RecurringTimer to scheduler events on system time or manually-configured time. --- .../src/main/scala/spark/streaming/DStream.scala | 1 - streaming/src/main/scala/spark/streaming/Job.scala | 9 ++- .../src/main/scala/spark/streaming/Scheduler.scala | 7 +- .../main/scala/spark/streaming/util/Clock.scala | 77 ++++++++++++++++++++++ .../spark/streaming/util/RecurringTimer.scala | 38 +++++++---- .../test/scala/spark/streaming/DStreamSuite.scala | 28 +++++--- 6 files changed, 130 insertions(+), 30 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/util/Clock.scala diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index e19d2ecef5..c63c043415 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -197,7 +197,6 @@ extends Logging with Serializable { private[streaming] def toQueue = { val queue = new ArrayBlockingQueue[RDD[T]](10000) this.foreachRDD(rdd => { - println("Added RDD " + rdd.id) queue.add(rdd) }) queue diff --git a/streaming/src/main/scala/spark/streaming/Job.scala b/streaming/src/main/scala/spark/streaming/Job.scala index 2481a9a3ef..0bd8343b9a 100644 --- a/streaming/src/main/scala/spark/streaming/Job.scala +++ b/streaming/src/main/scala/spark/streaming/Job.scala @@ -1,5 +1,7 @@ package spark.streaming +import java.util.concurrent.atomic.AtomicLong + class Job(val time: Time, func: () => _) { val id = Job.getNewId() def run(): Long = { @@ -13,11 +15,8 @@ class Job(val time: Time, func: () => _) { } object Job { - var lastId = 1 + val id = new AtomicLong(0) - def getNewId() = synchronized { - lastId += 1 - lastId - } + def getNewId() = id.getAndIncrement() } diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index fff4924b4c..309bd95525 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -1,6 +1,7 @@ package spark.streaming import spark.streaming.util.RecurringTimer +import spark.streaming.util.Clock import spark.SparkEnv import spark.Logging @@ -20,8 +21,10 @@ extends Logging { val concurrentJobs = System.getProperty("spark.stream.concurrentJobs", "1").toInt val jobManager = new JobManager(ssc, concurrentJobs) - val timer = new RecurringTimer(ssc.batchDuration, generateRDDs(_)) - + val clockClass = System.getProperty("spark.streaming.clock", "spark.streaming.util.SystemClock") + val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock] + val timer = new RecurringTimer(clock, ssc.batchDuration, generateRDDs(_)) + def start() { val zeroTime = Time(timer.start()) diff --git a/streaming/src/main/scala/spark/streaming/util/Clock.scala b/streaming/src/main/scala/spark/streaming/util/Clock.scala new file mode 100644 index 0000000000..72e786e0c3 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/Clock.scala @@ -0,0 +1,77 @@ +package spark.streaming.util + +import spark.streaming._ + +trait Clock { + def currentTime(): Long + def waitTillTime(targetTime: Long): Long +} + + +class SystemClock() extends Clock { + + val minPollTime = 25L + + def currentTime(): Long = { + System.currentTimeMillis() + } + + def waitTillTime(targetTime: Long): Long = { + var currentTime = 0L + currentTime = System.currentTimeMillis() + + var waitTime = targetTime - currentTime + if (waitTime <= 0) { + return currentTime + } + + val pollTime = { + if (waitTime / 10.0 > minPollTime) { + (waitTime / 10.0).toLong + } else { + minPollTime + } + } + + + while (true) { + currentTime = System.currentTimeMillis() + waitTime = targetTime - currentTime + + if (waitTime <= 0) { + + return currentTime + } + val sleepTime = + if (waitTime < pollTime) { + waitTime + } else { + pollTime + } + Thread.sleep(sleepTime) + } + return -1 + } +} + +class ManualClock() extends Clock { + + var time = 0L + + def currentTime() = time + + def addToTime(timeToAdd: Long) = { + this.synchronized { + time += timeToAdd + this.notifyAll() + } + } + def waitTillTime(targetTime: Long): Long = { + this.synchronized { + while (time < targetTime) { + this.wait(100) + } + } + return currentTime() + } +} diff --git a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala index 6125bb82eb..5da9fa6ecc 100644 --- a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala @@ -1,6 +1,6 @@ package spark.streaming.util -class RecurringTimer(period: Long, callback: (Long) => Unit) { +class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => Unit) { val minPollTime = 25L @@ -19,7 +19,7 @@ class RecurringTimer(period: Long, callback: (Long) => Unit) { var nextTime = 0L def start(): Long = { - nextTime = (math.floor(System.currentTimeMillis() / period) + 1).toLong * period + nextTime = (math.floor(clock.currentTime / period) + 1).toLong * period thread.start() nextTime } @@ -31,22 +31,32 @@ class RecurringTimer(period: Long, callback: (Long) => Unit) { def loop() { try { while (true) { - val beforeSleepTime = System.currentTimeMillis() - while (beforeSleepTime >= nextTime) { - callback(nextTime) - nextTime += period - } - val sleepTime = if (nextTime - beforeSleepTime < 2 * pollTime) { - nextTime - beforeSleepTime - } else { - pollTime - } - Thread.sleep(sleepTime) - val afterSleepTime = System.currentTimeMillis() + clock.waitTillTime(nextTime) + callback(nextTime) + nextTime += period } + } catch { case e: InterruptedException => } } } +object RecurringTimer { + + def main(args: Array[String]) { + var lastRecurTime = 0L + val period = 1000 + + def onRecur(time: Long) { + val currentTime = System.currentTimeMillis() + println("" + currentTime + ": " + (currentTime - lastRecurTime)) + lastRecurTime = currentTime + } + val timer = new RecurringTimer(new SystemClock(), period, onRecur) + timer.start() + Thread.sleep(30 * 1000) + timer.stop() + } +} + diff --git a/streaming/src/test/scala/spark/streaming/DStreamSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamSuite.scala index ce7c3d2e2b..2c10a03e6d 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamSuite.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamSuite.scala @@ -1,6 +1,8 @@ package spark.streaming -import spark.{Logging, RDD} +import spark.Logging +import spark.RDD +import spark.streaming.util.ManualClock import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter @@ -13,11 +15,13 @@ class DStreamSuite extends FunSuite with BeforeAndAfter with Logging { var ssc: SparkStreamContext = null val batchDurationMillis = 1000 + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + def testOp[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], operation: DStream[U] => DStream[V], expectedOutput: Seq[Seq[V]]) { - try { + try { ssc = new SparkStreamContext("local", "test") ssc.setBatchDuration(Milliseconds(batchDurationMillis)) @@ -26,12 +30,14 @@ class DStreamSuite extends FunSuite with BeforeAndAfter with Logging { val outputQueue = outputStream.toQueue ssc.start() - Thread.sleep(batchDurationMillis * input.size) + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.addToTime(input.size * batchDurationMillis) + + Thread.sleep(100) val output = new ArrayBuffer[Seq[V]]() while(outputQueue.size > 0) { - val rdd = outputQueue.take() - logInfo("Collecting RDD " + rdd.id + ", " + rdd.getClass.getSimpleName + ", " + rdd.splits.size) + val rdd = outputQueue.take() output += (rdd.collect()) } assert(output.size === expectedOutput.size) @@ -58,8 +64,14 @@ class DStreamSuite extends FunSuite with BeforeAndAfter with Logging { object DStreamSuite { def main(args: Array[String]) { - val r = new DStreamSuite() - val inputData = Array(1 to 4, 5 to 8, 9 to 12) - r.testOp(inputData, (r: DStream[Int]) => r.map(_.toString), inputData.map(_.map(_.toString))) + try { + val r = new DStreamSuite() + val inputData = Array(1 to 4, 5 to 8, 9 to 12) + r.testOp(inputData, (r: DStream[Int]) => r.map(_.toString), inputData.map(_.map(_.toString))) + + } catch { + case e: Exception => e.printStackTrace() + } + System.exit(0) } } \ No newline at end of file -- cgit v1.2.3 From 886b39de557b4d5f54f5ca11559fca9799534280 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 10 Aug 2012 01:10:02 -0700 Subject: Add Python API. --- .../main/scala/spark/api/python/PythonRDD.scala | 147 ++++++ pyspark/pyspark/__init__.py | 0 pyspark/pyspark/context.py | 69 +++ pyspark/pyspark/examples/__init__.py | 0 pyspark/pyspark/examples/kmeans.py | 56 +++ pyspark/pyspark/examples/pi.py | 20 + pyspark/pyspark/examples/tc.py | 49 ++ pyspark/pyspark/java_gateway.py | 20 + pyspark/pyspark/join.py | 104 +++++ pyspark/pyspark/rdd.py | 517 +++++++++++++++++++++ pyspark/pyspark/serializers.py | 229 +++++++++ pyspark/pyspark/worker.py | 97 ++++ pyspark/requirements.txt | 9 + python/tc.py | 22 + 14 files changed, 1339 insertions(+) create mode 100644 core/src/main/scala/spark/api/python/PythonRDD.scala create mode 100644 pyspark/pyspark/__init__.py create mode 100644 pyspark/pyspark/context.py create mode 100644 pyspark/pyspark/examples/__init__.py create mode 100644 pyspark/pyspark/examples/kmeans.py create mode 100644 pyspark/pyspark/examples/pi.py create mode 100644 pyspark/pyspark/examples/tc.py create mode 100644 pyspark/pyspark/java_gateway.py create mode 100644 pyspark/pyspark/join.py create mode 100644 pyspark/pyspark/rdd.py create mode 100644 pyspark/pyspark/serializers.py create mode 100644 pyspark/pyspark/worker.py create mode 100644 pyspark/requirements.txt create mode 100644 python/tc.py diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala new file mode 100644 index 0000000000..660ad48afe --- /dev/null +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -0,0 +1,147 @@ +package spark.api.python + +import java.io.PrintWriter + +import scala.collection.Map +import scala.collection.JavaConversions._ +import scala.io.Source +import spark._ +import api.java.{JavaPairRDD, JavaRDD} +import scala.Some + +trait PythonRDDBase { + def compute[T](split: Split, envVars: Map[String, String], + command: Seq[String], parent: RDD[T], pythonExec: String): Iterator[String]= { + val currentEnvVars = new ProcessBuilder().environment() + val SPARK_HOME = currentEnvVars.get("SPARK_HOME") + + val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py")) + // Add the environmental variables to the process. + envVars.foreach { + case (variable, value) => currentEnvVars.put(variable, value) + } + + val proc = pb.start() + val env = SparkEnv.get + + // Start a thread to print the process's stderr to ours + new Thread("stderr reader for " + command) { + override def run() { + for (line <- Source.fromInputStream(proc.getErrorStream).getLines) { + System.err.println(line) + } + } + }.start() + + // Start a thread to feed the process input from our parent's iterator + new Thread("stdin writer for " + command) { + override def run() { + SparkEnv.set(env) + val out = new PrintWriter(proc.getOutputStream) + for (elem <- command) { + out.println(elem) + } + for (elem <- parent.iterator(split)) { + out.println(PythonRDD.pythonDump(elem)) + } + out.close() + } + }.start() + + // Return an iterator that read lines from the process's stdout + val lines: Iterator[String] = Source.fromInputStream(proc.getInputStream).getLines + wrapIterator(lines, proc) + } + + def wrapIterator[T](iter: Iterator[T], proc: Process): Iterator[T] = { + return new Iterator[T] { + def next() = iter.next() + + def hasNext = { + if (iter.hasNext) { + true + } else { + val exitStatus = proc.waitFor() + if (exitStatus != 0) { + throw new Exception("Subprocess exited with status " + exitStatus) + } + false + } + } + } + } +} + +class PythonRDD[T: ClassManifest]( + parent: RDD[T], command: Seq[String], envVars: Map[String, String], + preservePartitoning: Boolean, pythonExec: String) + extends RDD[String](parent.context) with PythonRDDBase { + + def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) = + this(parent, command, Map(), preservePartitoning, pythonExec) + + // Similar to Runtime.exec(), if we are given a single string, split it into words + // using a standard StringTokenizer (i.e. by spaces) + def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String) = + this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec) + + override def splits = parent.splits + + override val dependencies = List(new OneToOneDependency(parent)) + + override val partitioner = if (preservePartitoning) parent.partitioner else None + + override def compute(split: Split): Iterator[String] = + compute(split, envVars, command, parent, pythonExec) + + val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this) +} + +class PythonPairRDD[T: ClassManifest] ( + parent: RDD[T], command: Seq[String], envVars: Map[String, String], + preservePartitoning: Boolean, pythonExec: String) + extends RDD[(String, String)](parent.context) with PythonRDDBase { + + def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) = + this(parent, command, Map(), preservePartitoning, pythonExec) + + // Similar to Runtime.exec(), if we are given a single string, split it into words + // using a standard StringTokenizer (i.e. by spaces) + def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String) = + this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec) + + override def splits = parent.splits + + override val dependencies = List(new OneToOneDependency(parent)) + + override val partitioner = if (preservePartitoning) parent.partitioner else None + + override def compute(split: Split): Iterator[(String, String)] = { + compute(split, envVars, command, parent, pythonExec).grouped(2).map { + case Seq(a, b) => (a, b) + case x => throw new Exception("Unexpected value: " + x) + } + } + + val asJavaPairRDD : JavaPairRDD[String, String] = JavaPairRDD.fromRDD(this) +} + +object PythonRDD { + def pythonDump[T](x: T): String = { + if (x.isInstanceOf[scala.Option[_]]) { + val t = x.asInstanceOf[scala.Option[_]] + t match { + case None => "*" + case Some(z) => pythonDump(z) + } + } else if (x.isInstanceOf[scala.Tuple2[_, _]]) { + val t = x.asInstanceOf[scala.Tuple2[_, _]] + "(" + pythonDump(t._1) + "," + pythonDump(t._2) + ")" + } else if (x.isInstanceOf[java.util.List[_]]) { + val objs = asScalaBuffer(x.asInstanceOf[java.util.List[_]]).map(pythonDump) + "[" + objs.mkString("|") + "]" + } else { + x.toString + } + } +} diff --git a/pyspark/pyspark/__init__.py b/pyspark/pyspark/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py new file mode 100644 index 0000000000..587ab12b5f --- /dev/null +++ b/pyspark/pyspark/context.py @@ -0,0 +1,69 @@ +import os +import atexit +from tempfile import NamedTemporaryFile + +from pyspark.java_gateway import launch_gateway +from pyspark.serializers import JSONSerializer, NopSerializer +from pyspark.rdd import RDD, PairRDD + + +class SparkContext(object): + + gateway = launch_gateway() + jvm = gateway.jvm + python_dump = jvm.spark.api.python.PythonRDD.pythonDump + + def __init__(self, master, name, defaultSerializer=JSONSerializer, + defaultParallelism=None, pythonExec='python'): + self.master = master + self.name = name + self._jsc = self.jvm.JavaSparkContext(master, name) + self.defaultSerializer = defaultSerializer + self.defaultParallelism = \ + defaultParallelism or self._jsc.sc().defaultParallelism() + self.pythonExec = pythonExec + + def __del__(self): + if self._jsc: + self._jsc.stop() + + def stop(self): + self._jsc.stop() + self._jsc = None + + def parallelize(self, c, numSlices=None, serializer=None): + serializer = serializer or self.defaultSerializer + numSlices = numSlices or self.defaultParallelism + # Calling the Java parallelize() method with an ArrayList is too slow, + # because it sends O(n) Py4J commands. As an alternative, serialized + # objects are written to a file and loaded through textFile(). + tempFile = NamedTemporaryFile(delete=False) + tempFile.writelines(serializer.dumps(x) + '\n' for x in c) + tempFile.close() + atexit.register(lambda: os.unlink(tempFile.name)) + return self.textFile(tempFile.name, numSlices, serializer) + + def parallelizePairs(self, c, numSlices=None, keySerializer=None, + valSerializer=None): + """ + >>> sc = SparkContext("local", "test") + >>> rdd = sc.parallelizePairs([(1, 2), (3, 4)]) + >>> rdd.collect() + [(1, 2), (3, 4)] + """ + keySerializer = keySerializer or self.defaultSerializer + valSerializer = valSerializer or self.defaultSerializer + numSlices = numSlices or self.defaultParallelism + tempFile = NamedTemporaryFile(delete=False) + for (k, v) in c: + tempFile.write(keySerializer.dumps(k).rstrip('\r\n') + '\n') + tempFile.write(valSerializer.dumps(v).rstrip('\r\n') + '\n') + tempFile.close() + atexit.register(lambda: os.unlink(tempFile.name)) + jrdd = self.textFile(tempFile.name, numSlices)._pipePairs([], "echo") + return PairRDD(jrdd, self, keySerializer, valSerializer) + + def textFile(self, name, numSlices=None, serializer=NopSerializer): + numSlices = numSlices or self.defaultParallelism + jrdd = self._jsc.textFile(name, numSlices) + return RDD(jrdd, self, serializer) diff --git a/pyspark/pyspark/examples/__init__.py b/pyspark/pyspark/examples/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pyspark/pyspark/examples/kmeans.py b/pyspark/pyspark/examples/kmeans.py new file mode 100644 index 0000000000..0761d6e395 --- /dev/null +++ b/pyspark/pyspark/examples/kmeans.py @@ -0,0 +1,56 @@ +import sys + +from pyspark.context import SparkContext + + +def parseVector(line): + return [float(x) for x in line.split(' ')] + + +def addVec(x, y): + return [a + b for (a, b) in zip(x, y)] + + +def squaredDist(x, y): + return sum((a - b) ** 2 for (a, b) in zip(x, y)) + + +def closestPoint(p, centers): + bestIndex = 0 + closest = float("+inf") + for i in range(len(centers)): + tempDist = squaredDist(p, centers[i]) + if tempDist < closest: + closest = tempDist + bestIndex = i + return bestIndex + + +if __name__ == "__main__": + if len(sys.argv) < 5: + print >> sys.stderr, \ + "Usage: PythonKMeans " + exit(-1) + sc = SparkContext(sys.argv[1], "PythonKMeans") + lines = sc.textFile(sys.argv[2]) + data = lines.map(parseVector).cache() + K = int(sys.argv[3]) + convergeDist = float(sys.argv[4]) + + kPoints = data.takeSample(False, K, 34) + tempDist = 1.0 + + while tempDist > convergeDist: + closest = data.mapPairs( + lambda p : (closestPoint(p, kPoints), (p, 1))) + pointStats = closest.reduceByKey( + lambda (x1, y1), (x2, y2): (addVec(x1, x2), y1 + y2)) + newPoints = pointStats.mapPairs( + lambda (x, (y, z)): (x, [a / z for a in y])).collect() + + tempDist = sum(squaredDist(kPoints[x], y) for (x, y) in newPoints) + + for (x, y) in newPoints: + kPoints[x] = y + + print "Final centers: " + str(kPoints) diff --git a/pyspark/pyspark/examples/pi.py b/pyspark/pyspark/examples/pi.py new file mode 100644 index 0000000000..ad77694c41 --- /dev/null +++ b/pyspark/pyspark/examples/pi.py @@ -0,0 +1,20 @@ +import sys +from random import random +from operator import add +from pyspark.context import SparkContext + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonPi []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonKMeans") + slices = sys.argv[2] if len(sys.argv) > 2 else 2 + n = 100000 * slices + def f(_): + x = random() * 2 - 1 + y = random() * 2 - 1 + return 1 if x ** 2 + y ** 2 < 1 else 0 + count = sc.parallelize(xrange(1, n+1), slices).map(f).reduce(add) + print "Pi is roughly %f" % (4.0 * count / n) diff --git a/pyspark/pyspark/examples/tc.py b/pyspark/pyspark/examples/tc.py new file mode 100644 index 0000000000..2796fdc6ad --- /dev/null +++ b/pyspark/pyspark/examples/tc.py @@ -0,0 +1,49 @@ +import sys +from random import Random +from pyspark.context import SparkContext + +numEdges = 200 +numVertices = 100 +rand = Random(42) + + +def generateGraph(): + edges = set() + while len(edges) < numEdges: + src = rand.randrange(0, numEdges) + dst = rand.randrange(0, numEdges) + if src != dst: + edges.add((src, dst)) + return edges + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonTC []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonKMeans") + slices = sys.argv[2] if len(sys.argv) > 2 else 2 + tc = sc.parallelizePairs(generateGraph(), slices).cache() + + # Linear transitive closure: each round grows paths by one edge, + # by joining the graph's edges with the already-discovered paths. + # e.g. join the path (y, z) from the TC with the edge (x, y) from + # the graph to obtain the path (x, z). + + # Because join() joins on keys, the edges are stored in reversed order. + edges = tc.mapPairs(lambda (x, y): (y, x)) + + oldCount = 0L + nextCount = tc.count() + while True: + oldCount = nextCount + # Perform the join, obtaining an RDD of (y, (z, x)) pairs, + # then project the result to obtain the new (x, z) paths. + new_edges = tc.join(edges).mapPairs(lambda (_, (a, b)): (b, a)) + tc = tc.union(new_edges).distinct().cache() + nextCount = tc.count() + if nextCount == oldCount: + break + + print "TC has %i edges" % tc.count() diff --git a/pyspark/pyspark/java_gateway.py b/pyspark/pyspark/java_gateway.py new file mode 100644 index 0000000000..2df80aee85 --- /dev/null +++ b/pyspark/pyspark/java_gateway.py @@ -0,0 +1,20 @@ +import glob +import os +from py4j.java_gateway import java_import, JavaGateway + + +SPARK_HOME = os.environ["SPARK_HOME"] + + +assembly_jar = glob.glob(os.path.join(SPARK_HOME, "core/target") + \ + "/spark-core-assembly-*-SNAPSHOT.jar")[0] + + +def launch_gateway(): + gateway = JavaGateway.launch_gateway(classpath=assembly_jar, + javaopts=["-Xmx256m"], die_on_exit=True) + java_import(gateway.jvm, "spark.api.java.*") + java_import(gateway.jvm, "spark.api.python.*") + java_import(gateway.jvm, "scala.Tuple2") + java_import(gateway.jvm, "spark.api.python.PythonRDD.pythonDump") + return gateway diff --git a/pyspark/pyspark/join.py b/pyspark/pyspark/join.py new file mode 100644 index 0000000000..c67520fce8 --- /dev/null +++ b/pyspark/pyspark/join.py @@ -0,0 +1,104 @@ +""" +Copyright (c) 2011, Douban Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + + * Neither the name of the Douban Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" +from pyspark.serializers import PairSerializer, OptionSerializer, \ + ArraySerializer + + +def _do_python_join(rdd, other, numSplits, dispatch, valSerializer): + vs = rdd.mapPairs(lambda (k, v): (k, (1, v))) + ws = other.mapPairs(lambda (k, v): (k, (2, v))) + return vs.union(ws).groupByKey(numSplits) \ + .flatMapValues(dispatch, valSerializer) + + +def python_join(rdd, other, numSplits): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + return [(v, w) for v in vbuf for w in wbuf] + valSerializer = PairSerializer(rdd.valSerializer, other.valSerializer) + return _do_python_join(rdd, other, numSplits, dispatch, valSerializer) + + +def python_right_outer_join(rdd, other, numSplits): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + if not vbuf: + vbuf.append(None) + return [(v, w) for v in vbuf for w in wbuf] + valSerializer = PairSerializer(OptionSerializer(rdd.valSerializer), + other.valSerializer) + return _do_python_join(rdd, other, numSplits, dispatch, valSerializer) + + +def python_left_outer_join(rdd, other, numSplits): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + if not wbuf: + wbuf.append(None) + return [(v, w) for v in vbuf for w in wbuf] + valSerializer = PairSerializer(rdd.valSerializer, + OptionSerializer(other.valSerializer)) + return _do_python_join(rdd, other, numSplits, dispatch, valSerializer) + + +def python_cogroup(rdd, other, numSplits): + resultValSerializer = PairSerializer( + ArraySerializer(rdd.valSerializer), + ArraySerializer(other.valSerializer)) + vs = rdd.mapPairs(lambda (k, v): (k, (1, v))) + ws = other.mapPairs(lambda (k, v): (k, (2, v))) + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + return (vbuf, wbuf) + return vs.union(ws).groupByKey(numSplits) \ + .mapValues(dispatch, resultValSerializer) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py new file mode 100644 index 0000000000..c892e86b93 --- /dev/null +++ b/pyspark/pyspark/rdd.py @@ -0,0 +1,517 @@ +from base64 import standard_b64encode as b64enc +from cloud.serialization import cloudpickle +from itertools import chain + +from pyspark.serializers import PairSerializer, NopSerializer, \ + OptionSerializer, ArraySerializer +from pyspark.join import python_join, python_left_outer_join, \ + python_right_outer_join, python_cogroup + + +class RDD(object): + + def __init__(self, jrdd, ctx, serializer=None): + self._jrdd = jrdd + self.is_cached = False + self.ctx = ctx + self.serializer = serializer or ctx.defaultSerializer + + def _builder(self, jrdd, ctx): + return RDD(jrdd, ctx, self.serializer) + + @property + def id(self): + return self._jrdd.id() + + @property + def splits(self): + return self._jrdd.splits() + + @classmethod + def _get_pipe_command(cls, command, functions): + if functions and not isinstance(functions, (list, tuple)): + functions = [functions] + worker_args = [command] + for f in functions: + worker_args.append(b64enc(cloudpickle.dumps(f))) + return " ".join(worker_args) + + def cache(self): + self.is_cached = True + self._jrdd.cache() + return self + + def map(self, f, serializer=None, preservesPartitioning=False): + return MappedRDD(self, f, serializer, preservesPartitioning) + + def mapPairs(self, f, keySerializer=None, valSerializer=None, + preservesPartitioning=False): + return PairMappedRDD(self, f, keySerializer, valSerializer, + preservesPartitioning) + + def flatMap(self, f, serializer=None): + """ + >>> rdd = sc.parallelize([2, 3, 4]) + >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect()) + [1, 1, 1, 2, 2, 3] + """ + serializer = serializer or self.ctx.defaultSerializer + dumps = serializer.dumps + loads = self.serializer.loads + def func(x): + pickled_elems = (dumps(y) for y in f(loads(x))) + return "\n".join(pickled_elems) or None + pipe_command = RDD._get_pipe_command("map", [func]) + class_manifest = self._jrdd.classManifest() + jrdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command, + False, self.ctx.pythonExec, + class_manifest).asJavaRDD() + return RDD(jrdd, self.ctx, serializer) + + def flatMapPairs(self, f, keySerializer=None, valSerializer=None, + preservesPartitioning=False): + """ + >>> rdd = sc.parallelize([2, 3, 4]) + >>> sorted(rdd.flatMapPairs(lambda x: [(x, x), (x, x)]).collect()) + [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] + """ + keySerializer = keySerializer or self.ctx.defaultSerializer + valSerializer = valSerializer or self.ctx.defaultSerializer + dumpk = keySerializer.dumps + dumpv = valSerializer.dumps + loads = self.serializer.loads + def func(x): + pairs = f(loads(x)) + pickled_pairs = ((dumpk(k), dumpv(v)) for (k, v) in pairs) + return "\n".join(chain.from_iterable(pickled_pairs)) or None + pipe_command = RDD._get_pipe_command("map", [func]) + class_manifest = self._jrdd.classManifest() + python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), pipe_command, + preservesPartitioning, self.ctx.pythonExec, class_manifest) + return PairRDD(python_rdd.asJavaPairRDD(), self.ctx, keySerializer, + valSerializer) + + def filter(self, f): + """ + >>> rdd = sc.parallelize([1, 2, 3, 4, 5]) + >>> rdd.filter(lambda x: x % 2 == 0).collect() + [2, 4] + """ + loads = self.serializer.loads + def filter_func(x): return x if f(loads(x)) else None + return self._builder(self._pipe(filter_func), self.ctx) + + def _pipe(self, functions, command="map"): + class_manifest = self._jrdd.classManifest() + pipe_command = RDD._get_pipe_command(command, functions) + python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command, + False, self.ctx.pythonExec, class_manifest) + return python_rdd.asJavaRDD() + + def _pipePairs(self, functions, command="mapPairs", + preservesPartitioning=False): + class_manifest = self._jrdd.classManifest() + pipe_command = RDD._get_pipe_command(command, functions) + python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), pipe_command, + preservesPartitioning, self.ctx.pythonExec, class_manifest) + return python_rdd.asJavaPairRDD() + + def distinct(self): + """ + >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect()) + [1, 2, 3] + """ + if self.serializer.is_comparable: + return self._builder(self._jrdd.distinct(), self.ctx) + return self.mapPairs(lambda x: (x, "")) \ + .reduceByKey(lambda x, _: x) \ + .map(lambda (x, _): x) + + def sample(self, withReplacement, fraction, seed): + jrdd = self._jrdd.sample(withReplacement, fraction, seed) + return self._builder(jrdd, self.ctx) + + def takeSample(self, withReplacement, num, seed): + vals = self._jrdd.takeSample(withReplacement, num, seed) + return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals] + + def union(self, other): + """ + >>> rdd = sc.parallelize([1, 1, 2, 3]) + >>> rdd.union(rdd).collect() + [1, 1, 2, 3, 1, 1, 2, 3] + """ + return self._builder(self._jrdd.union(other._jrdd), self.ctx) + + # TODO: sort + + # TODO: Overload __add___? + + # TODO: glom + + def cartesian(self, other): + """ + >>> rdd = sc.parallelize([1, 2]) + >>> sorted(rdd.cartesian(rdd).collect()) + [(1, 1), (1, 2), (2, 1), (2, 2)] + """ + return PairRDD(self._jrdd.cartesian(other._jrdd), self.ctx) + + # numsplits + def groupBy(self, f, numSplits=None): + """ + >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8]) + >>> sorted(rdd.groupBy(lambda x: x % 2).collect()) + [(0, [2, 8]), (1, [1, 1, 3, 5])] + """ + return self.mapPairs(lambda x: (f(x), x)).groupByKey(numSplits) + + # TODO: pipe + + # TODO: mapPartitions + + def foreach(self, f): + """ + >>> def f(x): print x + >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) + """ + self.map(f).collect() # Force evaluation + + def collect(self): + vals = self._jrdd.collect() + return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals] + + def reduce(self, f, serializer=None): + """ + >>> import operator + >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(operator.add) + 15 + """ + serializer = serializer or self.ctx.defaultSerializer + loads = self.serializer.loads + dumps = serializer.dumps + def reduceFunction(x, acc): + if acc is None: + return loads(x) + else: + return f(loads(x), acc) + vals = self._pipe([reduceFunction, dumps], command="reduce").collect() + return reduce(f, (serializer.loads(x) for x in vals)) + + # TODO: fold + + # TODO: aggregate + + def count(self): + """ + >>> sc.parallelize([2, 3, 4]).count() + 3L + """ + return self._jrdd.count() + + # TODO: count approx methods + + def take(self, num): + """ + >>> sc.parallelize([2, 3, 4]).take(2) + [2, 3] + """ + vals = self._jrdd.take(num) + return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals] + + def first(self): + """ + >>> sc.parallelize([2, 3, 4]).first() + 2 + """ + return self.serializer.loads(self.ctx.python_dump(self._jrdd.first())) + + # TODO: saveAsTextFile + + # TODO: saveAsObjectFile + + +class PairRDD(RDD): + + def __init__(self, jrdd, ctx, keySerializer=None, valSerializer=None): + RDD.__init__(self, jrdd, ctx) + self.keySerializer = keySerializer or ctx.defaultSerializer + self.valSerializer = valSerializer or ctx.defaultSerializer + self.serializer = \ + PairSerializer(self.keySerializer, self.valSerializer) + + def _builder(self, jrdd, ctx): + return PairRDD(jrdd, ctx, self.keySerializer, self.valSerializer) + + def reduceByKey(self, func, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(x.reduceByKey(lambda a, b: a + b).collect()) + [('a', 2), ('b', 1)] + """ + return self.combineByKey(lambda x: x, func, func, numSplits) + + # TODO: reduceByKeyLocally() + + # TODO: countByKey() + + # TODO: partitionBy + + def join(self, other, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) + >>> y = sc.parallelizePairs([("a", 2), ("a", 3)]) + >>> x.join(y).collect() + [('a', (1, 2)), ('a', (1, 3))] + + Check that we get a PairRDD-like object back: + >>> assert x.join(y).join + """ + assert self.keySerializer.name == other.keySerializer.name + if self.keySerializer.is_comparable: + return PairRDD(self._jrdd.join(other._jrdd), + self.ctx, self.keySerializer, + PairSerializer(self.valSerializer, other.valSerializer)) + else: + return python_join(self, other, numSplits) + + def leftOuterJoin(self, other, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) + >>> y = sc.parallelizePairs([("a", 2)]) + >>> sorted(x.leftOuterJoin(y).collect()) + [('a', (1, 2)), ('b', (4, None))] + """ + assert self.keySerializer.name == other.keySerializer.name + if self.keySerializer.is_comparable: + return PairRDD(self._jrdd.leftOuterJoin(other._jrdd), + self.ctx, self.keySerializer, + PairSerializer(self.valSerializer, + OptionSerializer(other.valSerializer))) + else: + return python_left_outer_join(self, other, numSplits) + + def rightOuterJoin(self, other, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) + >>> y = sc.parallelizePairs([("a", 2)]) + >>> sorted(y.rightOuterJoin(x).collect()) + [('a', (2, 1)), ('b', (None, 4))] + """ + assert self.keySerializer.name == other.keySerializer.name + if self.keySerializer.is_comparable: + return PairRDD(self._jrdd.rightOuterJoin(other._jrdd), + self.ctx, self.keySerializer, + PairSerializer(OptionSerializer(self.valSerializer), + other.valSerializer)) + else: + return python_right_outer_join(self, other, numSplits) + + def combineByKey(self, createCombiner, mergeValue, mergeCombiners, + numSplits=None, serializer=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)]) + >>> def f(x): return x + >>> def add(a, b): return a + str(b) + >>> sorted(x.combineByKey(str, add, add).collect()) + [('a', '11'), ('b', '1')] + """ + serializer = serializer or self.ctx.defaultSerializer + if numSplits is None: + numSplits = self.ctx.defaultParallelism + # Use hash() to create keys that are comparable in Java. + loadkv = self.serializer.loads + def pairify(kv): + # TODO: add method to deserialize only the key or value from + # a PairSerializer? + key = loadkv(kv)[0] + return (str(hash(key)), kv) + partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) + jrdd = self._pipePairs(pairify).partitionBy(partitioner) + pairified = PairRDD(jrdd, self.ctx, NopSerializer, self.serializer) + + loads = PairSerializer(NopSerializer, self.serializer).loads + dumpk = self.keySerializer.dumps + dumpc = serializer.dumps + + functions = [createCombiner, mergeValue, mergeCombiners, loads, dumpk, + dumpc] + jpairs = pairified._pipePairs(functions, "combine_by_key", + preservesPartitioning=True) + return PairRDD(jpairs, self.ctx, self.keySerializer, serializer) + + def groupByKey(self, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(x.groupByKey().collect()) + [('a', [1, 1]), ('b', [1])] + """ + + def createCombiner(x): + return [x] + + def mergeValue(xs, x): + xs.append(x) + return xs + + def mergeCombiners(a, b): + return a + b + + return self.combineByKey(createCombiner, mergeValue, mergeCombiners, + numSplits) + + def collectAsMap(self): + """ + >>> m = sc.parallelizePairs([(1, 2), (3, 4)]).collectAsMap() + >>> m[1] + 2 + >>> m[3] + 4 + """ + m = self._jrdd.collectAsMap() + def loads(x): + (k, v) = x + return (self.keySerializer.loads(k), self.valSerializer.loads(v)) + return dict(loads(x) for x in m.items()) + + def flatMapValues(self, f, valSerializer=None): + flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) + return self.flatMapPairs(flat_map_fn, self.keySerializer, + valSerializer, True) + + def mapValues(self, f, valSerializer=None): + map_values_fn = lambda (k, v): (k, f(v)) + return self.mapPairs(map_values_fn, self.keySerializer, valSerializer, + True) + + # TODO: support varargs cogroup of several RDDs. + def groupWith(self, other): + return self.cogroup(other) + + def cogroup(self, other, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) + >>> y = sc.parallelizePairs([("a", 2)]) + >>> x.cogroup(y).collect() + [('a', ([1], [2])), ('b', ([4], []))] + """ + assert self.keySerializer.name == other.keySerializer.name + resultValSerializer = PairSerializer( + ArraySerializer(self.valSerializer), + ArraySerializer(other.valSerializer)) + if self.keySerializer.is_comparable: + return PairRDD(self._jrdd.cogroup(other._jrdd), + self.ctx, self.keySerializer, resultValSerializer) + else: + return python_cogroup(self, other, numSplits) + + # TODO: `lookup` is disabled because we can't make direct comparisons based + # on the key; we need to compare the hash of the key to the hash of the + # keys in the pairs. This could be an expensive operation, since those + # hashes aren't retained. + + # TODO: file saving + + +class MappedRDDBase(object): + def __init__(self, prev, func, serializer, preservesPartitioning=False): + if isinstance(prev, MappedRDDBase) and not prev.is_cached: + prev_func = prev.func + self.func = lambda x: func(prev_func(x)) + self.preservesPartitioning = \ + prev.preservesPartitioning and preservesPartitioning + self._prev_jrdd = prev._prev_jrdd + self._prev_serializer = prev._prev_serializer + else: + self.func = func + self.preservesPartitioning = preservesPartitioning + self._prev_jrdd = prev._jrdd + self._prev_serializer = prev.serializer + self.serializer = serializer or prev.ctx.defaultSerializer + self.is_cached = False + self.ctx = prev.ctx + self.prev = prev + self._jrdd_val = None + + +class MappedRDD(MappedRDDBase, RDD): + """ + >>> rdd = sc.parallelize([1, 2, 3, 4]) + >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect() + [4, 8, 12, 16] + >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect() + [4, 8, 12, 16] + """ + + @property + def _jrdd(self): + if not self._jrdd_val: + udf = self.func + loads = self._prev_serializer.loads + dumps = self.serializer.dumps + func = lambda x: dumps(udf(loads(x))) + pipe_command = RDD._get_pipe_command("map", [func]) + class_manifest = self._prev_jrdd.classManifest() + python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), + pipe_command, self.preservesPartitioning, self.ctx.pythonExec, + class_manifest) + self._jrdd_val = python_rdd.asJavaRDD() + return self._jrdd_val + + +class PairMappedRDD(MappedRDDBase, PairRDD): + """ + >>> rdd = sc.parallelize([1, 2, 3, 4]) + >>> rdd.mapPairs(lambda x: (x, x)) \\ + ... .mapPairs(lambda (x, y): (2*x, 2*y)) \\ + ... .collect() + [(2, 2), (4, 4), (6, 6), (8, 8)] + >>> rdd.mapPairs(lambda x: (x, x)) \\ + ... .mapPairs(lambda (x, y): (2*x, 2*y)) \\ + ... .map(lambda (x, _): x).collect() + [2, 4, 6, 8] + """ + + def __init__(self, prev, func, keySerializer=None, valSerializer=None, + preservesPartitioning=False): + self.keySerializer = keySerializer or prev.ctx.defaultSerializer + self.valSerializer = valSerializer or prev.ctx.defaultSerializer + serializer = PairSerializer(self.keySerializer, self.valSerializer) + MappedRDDBase.__init__(self, prev, func, serializer, + preservesPartitioning) + + @property + def _jrdd(self): + if not self._jrdd_val: + udf = self.func + loads = self._prev_serializer.loads + dumpk = self.keySerializer.dumps + dumpv = self.valSerializer.dumps + def func(x): + (k, v) = udf(loads(x)) + return (dumpk(k), dumpv(v)) + pipe_command = RDD._get_pipe_command("mapPairs", [func]) + class_manifest = self._prev_jrdd.classManifest() + self._jrdd_val = self.ctx.jvm.PythonPairRDD(self._prev_jrdd.rdd(), + pipe_command, self.preservesPartitioning, self.ctx.pythonExec, + class_manifest).asJavaPairRDD() + return self._jrdd_val + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.serializers import PickleSerializer, JSONSerializer + globs = globals().copy() + globs['sc'] = SparkContext('local', 'PythonTest', + defaultSerializer=JSONSerializer) + doctest.testmod(globs=globs) + globs['sc'].stop() + globs['sc'] = SparkContext('local', 'PythonTest', + defaultSerializer=PickleSerializer) + doctest.testmod(globs=globs) + globs['sc'].stop() + + +if __name__ == "__main__": + _test() diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py new file mode 100644 index 0000000000..b113f5656b --- /dev/null +++ b/pyspark/pyspark/serializers.py @@ -0,0 +1,229 @@ +""" +Data serialization methods. + +The Spark Python API is built on top of the Spark Java API. RDDs created in +Python are stored in Java as RDDs of Strings. Python objects are automatically +serialized/deserialized, so this representation is transparent to the end-user. + +------------------ +Serializer objects +------------------ + +`Serializer` objects are used to customize how an RDD's values are serialized. + +Each `Serializer` is a named tuple with four fields: + + - A `dumps` function, for serializing a Python object to a string. + + - A `loads` function, for deserializing a Python object from a string. + + - An `is_comparable` field, True if equal Python objects are serialized to + equal strings, and False otherwise. + + - A `name` field, used to identify the Serializer. Serializers are + compared for equality by comparing their names. + +The serializer's output should be base64-encoded. + +------------------------------------------------------------------ +`is_comparable`: comparing serialized representations for equality +------------------------------------------------------------------ + +If `is_comparable` is False, the serializer's representations of equal objects +are not required to be equal: + +>>> import pickle +>>> a = {1: 0, 9: 0} +>>> b = {9: 0, 1: 0} +>>> a == b +True +>>> pickle.dumps(a) == pickle.dumps(b) +False + +RDDs with comparable serializers can use native Java implementations of +operations like join() and distinct(), which may lead to better performance by +eliminating deserialization and Python comparisons. + +The default JSONSerializer produces comparable representations of common Python +data structures. + +-------------------------------------- +Examples of serialized representations +-------------------------------------- + +The RDD transformations that use Python UDFs are implemented in terms of +a modified `PipedRDD.pipe()` function. For each record `x` in the RDD, the +`pipe()` function pipes `x.toString()` to a Python worker process, which +deserializes the string into a Python object, executes user-defined functions, +and outputs serialized Python objects. + +The regular `toString()` method returns an ambiguous representation, due to the +way that Scala `Option` instances are printed: + +>>> from context import SparkContext +>>> sc = SparkContext("local", "SerializerDocs") +>>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) +>>> y = sc.parallelizePairs([("a", 2)]) + +>>> print y.rightOuterJoin(x)._jrdd.first().toString() +(ImEi,(Some(Mg==),MQ==)) + +In Java, preprocessing is performed to handle Option instances, so the Python +process receives unambiguous input: + +>>> print sc.python_dump(y.rightOuterJoin(x)._jrdd.first()) +(ImEi,(Mg==,MQ==)) + +The base64-encoding eliminates the need to escape newlines, parentheses and +other special characters. + +---------------------- +Serializer composition +---------------------- + +In order to handle nested structures, which could contain object serialized +with different serializers, the RDD module composes serializers. For example, +the serializers in the previous example are: + +>>> print x.serializer.name +PairSerializer + +>>> print y.serializer.name +PairSerializer + +>>> print y.rightOuterJoin(x).serializer.name +PairSerializer, JSONSerializer>> +""" +from base64 import standard_b64encode, standard_b64decode +from collections import namedtuple +import cPickle +import simplejson + + +Serializer = namedtuple("Serializer", + ["dumps","loads", "is_comparable", "name"]) + + +NopSerializer = Serializer(str, str, True, "NopSerializer") + + +JSONSerializer = Serializer( + lambda obj: standard_b64encode(simplejson.dumps(obj, sort_keys=True, + separators=(',', ':'))), + lambda s: simplejson.loads(standard_b64decode(s)), + True, + "JSONSerializer" +) + + +PickleSerializer = Serializer( + lambda obj: standard_b64encode(cPickle.dumps(obj)), + lambda s: cPickle.loads(standard_b64decode(s)), + False, + "PickleSerializer" +) + + +def OptionSerializer(serializer): + """ + >>> ser = OptionSerializer(NopSerializer) + >>> ser.loads(ser.dumps("Hello, World!")) + 'Hello, World!' + >>> ser.loads(ser.dumps(None)) is None + True + """ + none_placeholder = '*' + + def dumps(x): + if x is None: + return none_placeholder + else: + return serializer.dumps(x) + + def loads(x): + if x == none_placeholder: + return None + else: + return serializer.loads(x) + + name = "OptionSerializer<%s>" % serializer.name + return Serializer(dumps, loads, serializer.is_comparable, name) + + +def PairSerializer(keySerializer, valSerializer): + """ + Returns a Serializer for a (key, value) pair. + + >>> ser = PairSerializer(JSONSerializer, JSONSerializer) + >>> ser.loads(ser.dumps((1, 2))) + (1, 2) + + >>> ser = PairSerializer(JSONSerializer, ser) + >>> ser.loads(ser.dumps((1, (2, 3)))) + (1, (2, 3)) + """ + def loads(kv): + try: + (key, val) = kv[1:-1].split(',', 1) + key = keySerializer.loads(key) + val = valSerializer.loads(val) + return (key, val) + except: + print "Error in deserializing pair from '%s'" % str(kv) + raise + + def dumps(kv): + (key, val) = kv + return"(%s,%s)" % (keySerializer.dumps(key), valSerializer.dumps(val)) + is_comparable = \ + keySerializer.is_comparable and valSerializer.is_comparable + name = "PairSerializer<%s, %s>" % (keySerializer.name, valSerializer.name) + return Serializer(dumps, loads, is_comparable, name) + + +def ArraySerializer(serializer): + """ + >>> ser = ArraySerializer(JSONSerializer) + >>> ser.loads(ser.dumps([1, 2, 3, 4])) + [1, 2, 3, 4] + >>> ser = ArraySerializer(PairSerializer(JSONSerializer, PickleSerializer)) + >>> ser.loads(ser.dumps([('a', 1), ('b', 2)])) + [('a', 1), ('b', 2)] + >>> ser.loads(ser.dumps([('a', 1)])) + [('a', 1)] + >>> ser.loads(ser.dumps([])) + [] + """ + def dumps(arr): + if arr == []: + return '[]' + else: + return '[' + '|'.join(serializer.dumps(x) for x in arr) + ']' + + def loads(s): + if s == '[]': + return [] + items = s[1:-1] + if '|' in items: + items = items.split('|') + else: + items = [items] + return [serializer.loads(x) for x in items] + + name = "ArraySerializer<%s>" % serializer.name + return Serializer(dumps, loads, serializer.is_comparable, name) + + +# TODO: IntegerSerializer + + +# TODO: DoubleSerializer + + +def _test(): + import doctest + doctest.testmod() + + +if __name__ == "__main__": + _test() diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py new file mode 100644 index 0000000000..4d4cc939c3 --- /dev/null +++ b/pyspark/pyspark/worker.py @@ -0,0 +1,97 @@ +""" +Worker that receives input from Piped RDD. +""" +import sys +from base64 import standard_b64decode +# CloudPickler needs to be imported so that depicklers are registered using the +# copy_reg module. +from cloud.serialization.cloudpickle import CloudPickler +import cPickle + + +# Redirect stdout to stderr so that users must return values from functions. +old_stdout = sys.stdout +sys.stdout = sys.stderr + + +def load_function(): + return cPickle.loads(standard_b64decode(sys.stdin.readline().strip())) + + +def output(x): + for line in x.split("\n"): + old_stdout.write(line.rstrip("\r\n") + "\n") + + +def read_input(): + for line in sys.stdin: + yield line.rstrip("\r\n") + + +def do_combine_by_key(): + create_combiner = load_function() + merge_value = load_function() + merge_combiners = load_function() # TODO: not used. + depickler = load_function() + key_pickler = load_function() + combiner_pickler = load_function() + combiners = {} + for line in read_input(): + # Discard the hashcode added in the Python combineByKey() method. + (key, value) = depickler(line)[1] + if key not in combiners: + combiners[key] = create_combiner(value) + else: + combiners[key] = merge_value(combiners[key], value) + for (key, combiner) in combiners.iteritems(): + output(key_pickler(key)) + output(combiner_pickler(combiner)) + + +def do_map(map_pairs=False): + f = load_function() + for line in read_input(): + try: + out = f(line) + if out is not None: + if map_pairs: + for x in out: + output(x) + else: + output(out) + except: + sys.stderr.write("Error processing line '%s'\n" % line) + raise + + +def do_reduce(): + f = load_function() + dumps = load_function() + acc = None + for line in read_input(): + acc = f(line, acc) + output(dumps(acc)) + + +def do_echo(): + old_stdout.writelines(sys.stdin.readlines()) + + +def main(): + command = sys.stdin.readline().strip() + if command == "map": + do_map(map_pairs=False) + elif command == "mapPairs": + do_map(map_pairs=True) + elif command == "combine_by_key": + do_combine_by_key() + elif command == "reduce": + do_reduce() + elif command == "echo": + do_echo() + else: + raise Exception("Unsupported command %s" % command) + + +if __name__ == '__main__': + main() diff --git a/pyspark/requirements.txt b/pyspark/requirements.txt new file mode 100644 index 0000000000..d9b3fe40bd --- /dev/null +++ b/pyspark/requirements.txt @@ -0,0 +1,9 @@ +# The Python API relies on some new features from the Py4J development branch. +# pip can't install Py4J from git because the setup.py file for the Python +# package is not at the root of the git repository. It may be possible to +# install Py4J from git once https://github.com/pypa/pip/pull/526 is merged. + +# git+git://github.com/bartdag/py4j.git@3dbf380d3d2cdeb9aab394454ea74d80c4aba1ea + +simplejson==2.6.1 +cloud==2.5.5 diff --git a/python/tc.py b/python/tc.py new file mode 100644 index 0000000000..5dcc4317e0 --- /dev/null +++ b/python/tc.py @@ -0,0 +1,22 @@ +from rdd import SparkContext + +sc = SparkContext("local", "PythonWordCount") +e = [(1, 2), (2, 3), (4, 1)] + +tc = sc.parallelizePairs(e) + +edges = tc.mapPairs(lambda (x, y): (y, x)) + +oldCount = 0 +nextCount = tc.count() + +def project(x): + return (x[1][1], x[1][0]) + +while nextCount != oldCount: + oldCount = nextCount + tc = tc.union(tc.join(edges).mapPairs(project)).distinct() + nextCount = tc.count() + +print "TC has %i edges" % tc.count() +print tc.collect() -- cgit v1.2.3 From 13b9514966a423f80f672f23f42ec3f0113936fd Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 19 Aug 2012 17:12:51 -0700 Subject: Bundle cloudpickle with pyspark. --- pyspark/pyspark/cloudpickle.py | 974 +++++++++++++++++++++++++++++++++++++++++ pyspark/pyspark/rdd.py | 2 +- pyspark/pyspark/worker.py | 2 +- pyspark/requirements.txt | 3 - 4 files changed, 976 insertions(+), 5 deletions(-) create mode 100644 pyspark/pyspark/cloudpickle.py diff --git a/pyspark/pyspark/cloudpickle.py b/pyspark/pyspark/cloudpickle.py new file mode 100644 index 0000000000..6a7c23a069 --- /dev/null +++ b/pyspark/pyspark/cloudpickle.py @@ -0,0 +1,974 @@ +""" +This class is defined to override standard pickle functionality + +The goals of it follow: +-Serialize lambdas and nested functions to compiled byte code +-Deal with main module correctly +-Deal with other non-serializable objects + +It does not include an unpickler, as standard python unpickling suffices. + +This module was extracted from the `cloud` package, developed by `PiCloud, Inc. +`_. + +Copyright (c) 2012, Regents of the University of California. +Copyright (c) 2009 `PiCloud, Inc. `_. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the University of California, Berkeley nor the + names of its contributors may be used to endorse or promote + products derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + + +import operator +import os +import pickle +import struct +import sys +import types +from functools import partial +import itertools +from copy_reg import _extension_registry, _inverted_registry, _extension_cache +import new +import dis +import traceback + +#relevant opcodes +STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL')) +DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL')) +LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL')) +GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL] + +HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT) +EXTENDED_ARG = chr(dis.EXTENDED_ARG) + +import logging +cloudLog = logging.getLogger("Cloud.Transport") + +try: + import ctypes +except (MemoryError, ImportError): + logging.warning('Exception raised on importing ctypes. Likely python bug.. some functionality will be disabled', exc_info = True) + ctypes = None + PyObject_HEAD = None +else: + + # for reading internal structures + PyObject_HEAD = [ + ('ob_refcnt', ctypes.c_size_t), + ('ob_type', ctypes.c_void_p), + ] + + +try: + from cStringIO import StringIO +except ImportError: + from StringIO import StringIO + +# These helper functions were copied from PiCloud's util module. +def islambda(func): + return getattr(func,'func_name') == '' + +def xrange_params(xrangeobj): + """Returns a 3 element tuple describing the xrange start, step, and len + respectively + + Note: Only guarentees that elements of xrange are the same. parameters may + be different. + e.g. xrange(1,1) is interpretted as xrange(0,0); both behave the same + though w/ iteration + """ + + xrange_len = len(xrangeobj) + if not xrange_len: #empty + return (0,1,0) + start = xrangeobj[0] + if xrange_len == 1: #one element + return start, 1, 1 + return (start, xrangeobj[1] - xrangeobj[0], xrange_len) + +#debug variables intended for developer use: +printSerialization = False +printMemoization = False + +useForcedImports = True #Should I use forced imports for tracking? + + + +class CloudPickler(pickle.Pickler): + + dispatch = pickle.Pickler.dispatch.copy() + savedForceImports = False + savedDjangoEnv = False #hack tro transport django environment + + def __init__(self, file, protocol=None, min_size_to_save= 0): + pickle.Pickler.__init__(self,file,protocol) + self.modules = set() #set of modules needed to depickle + self.globals_ref = {} # map ids to dictionary. used to ensure that functions can share global env + + def dump(self, obj): + # note: not thread safe + # minimal side-effects, so not fixing + recurse_limit = 3000 + base_recurse = sys.getrecursionlimit() + if base_recurse < recurse_limit: + sys.setrecursionlimit(recurse_limit) + self.inject_addons() + try: + return pickle.Pickler.dump(self, obj) + except RuntimeError, e: + if 'recursion' in e.args[0]: + msg = """Could not pickle object as excessively deep recursion required. + Try _fast_serialization=2 or contact PiCloud support""" + raise pickle.PicklingError(msg) + finally: + new_recurse = sys.getrecursionlimit() + if new_recurse == recurse_limit: + sys.setrecursionlimit(base_recurse) + + def save_buffer(self, obj): + """Fallback to save_string""" + pickle.Pickler.save_string(self,str(obj)) + dispatch[buffer] = save_buffer + + #block broken objects + def save_unsupported(self, obj, pack=None): + raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj)) + dispatch[types.GeneratorType] = save_unsupported + + #python2.6+ supports slice pickling. some py2.5 extensions might as well. We just test it + try: + slice(0,1).__reduce__() + except TypeError: #can't pickle - + dispatch[slice] = save_unsupported + + #itertools objects do not pickle! + for v in itertools.__dict__.values(): + if type(v) is type: + dispatch[v] = save_unsupported + + + def save_dict(self, obj): + """hack fix + If the dict is a global, deal with it in a special way + """ + #print 'saving', obj + if obj is __builtins__: + self.save_reduce(_get_module_builtins, (), obj=obj) + else: + pickle.Pickler.save_dict(self, obj) + dispatch[pickle.DictionaryType] = save_dict + + + def save_module(self, obj, pack=struct.pack): + """ + Save a module as an import + """ + #print 'try save import', obj.__name__ + self.modules.add(obj) + self.save_reduce(subimport,(obj.__name__,), obj=obj) + dispatch[types.ModuleType] = save_module #new type + + def save_codeobject(self, obj, pack=struct.pack): + """ + Save a code object + """ + #print 'try to save codeobj: ', obj + args = ( + obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code, + obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name, + obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars + ) + self.save_reduce(types.CodeType, args, obj=obj) + dispatch[types.CodeType] = save_codeobject #new type + + def save_function(self, obj, name=None, pack=struct.pack): + """ Registered with the dispatch to handle all function types. + + Determines what kind of function obj is (e.g. lambda, defined at + interactive prompt, etc) and handles the pickling appropriately. + """ + write = self.write + + name = obj.__name__ + modname = pickle.whichmodule(obj, name) + #print 'which gives %s %s %s' % (modname, obj, name) + try: + themodule = sys.modules[modname] + except KeyError: # eval'd items such as namedtuple give invalid items for their function __module__ + modname = '__main__' + + if modname == '__main__': + themodule = None + + if themodule: + self.modules.add(themodule) + + if not self.savedDjangoEnv: + #hack for django - if we detect the settings module, we transport it + django_settings = os.environ.get('DJANGO_SETTINGS_MODULE', '') + if django_settings: + django_mod = sys.modules.get(django_settings) + if django_mod: + cloudLog.debug('Transporting django settings %s during save of %s', django_mod, name) + self.savedDjangoEnv = True + self.modules.add(django_mod) + write(pickle.MARK) + self.save_reduce(django_settings_load, (django_mod.__name__,), obj=django_mod) + write(pickle.POP_MARK) + + + # if func is lambda, def'ed at prompt, is in main, or is nested, then + # we'll pickle the actual function object rather than simply saving a + # reference (as is done in default pickler), via save_function_tuple. + if islambda(obj) or obj.func_code.co_filename == '' or themodule == None: + #Force server to import modules that have been imported in main + modList = None + if themodule == None and not self.savedForceImports: + mainmod = sys.modules['__main__'] + if useForcedImports and hasattr(mainmod,'___pyc_forcedImports__'): + modList = list(mainmod.___pyc_forcedImports__) + self.savedForceImports = True + self.save_function_tuple(obj, modList) + return + else: # func is nested + klass = getattr(themodule, name, None) + if klass is None or klass is not obj: + self.save_function_tuple(obj, [themodule]) + return + + if obj.__dict__: + # essentially save_reduce, but workaround needed to avoid recursion + self.save(_restore_attr) + write(pickle.MARK + pickle.GLOBAL + modname + '\n' + name + '\n') + self.memoize(obj) + self.save(obj.__dict__) + write(pickle.TUPLE + pickle.REDUCE) + else: + write(pickle.GLOBAL + modname + '\n' + name + '\n') + self.memoize(obj) + dispatch[types.FunctionType] = save_function + + def save_function_tuple(self, func, forced_imports): + """ Pickles an actual func object. + + A func comprises: code, globals, defaults, closure, and dict. We + extract and save these, injecting reducing functions at certain points + to recreate the func object. Keep in mind that some of these pieces + can contain a ref to the func itself. Thus, a naive save on these + pieces could trigger an infinite loop of save's. To get around that, + we first create a skeleton func object using just the code (this is + safe, since this won't contain a ref to the func), and memoize it as + soon as it's created. The other stuff can then be filled in later. + """ + save = self.save + write = self.write + + # save the modules (if any) + if forced_imports: + write(pickle.MARK) + save(_modules_to_main) + #print 'forced imports are', forced_imports + + forced_names = map(lambda m: m.__name__, forced_imports) + save((forced_names,)) + + #save((forced_imports,)) + write(pickle.REDUCE) + write(pickle.POP_MARK) + + code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func) + + save(_fill_function) # skeleton function updater + write(pickle.MARK) # beginning of tuple that _fill_function expects + + # create a skeleton function object and memoize it + save(_make_skel_func) + save((code, len(closure), base_globals)) + write(pickle.REDUCE) + self.memoize(func) + + # save the rest of the func data needed by _fill_function + save(f_globals) + save(defaults) + save(closure) + save(dct) + write(pickle.TUPLE) + write(pickle.REDUCE) # applies _fill_function on the tuple + + @staticmethod + def extract_code_globals(co): + """ + Find all globals names read or written to by codeblock co + """ + code = co.co_code + names = co.co_names + out_names = set() + + n = len(code) + i = 0 + extended_arg = 0 + while i < n: + op = code[i] + + i = i+1 + if op >= HAVE_ARGUMENT: + oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg + extended_arg = 0 + i = i+2 + if op == EXTENDED_ARG: + extended_arg = oparg*65536L + if op in GLOBAL_OPS: + out_names.add(names[oparg]) + #print 'extracted', out_names, ' from ', names + return out_names + + def extract_func_data(self, func): + """ + Turn the function into a tuple of data necessary to recreate it: + code, globals, defaults, closure, dict + """ + code = func.func_code + + # extract all global ref's + func_global_refs = CloudPickler.extract_code_globals(code) + if code.co_consts: # see if nested function have any global refs + for const in code.co_consts: + if type(const) is types.CodeType and const.co_names: + func_global_refs = func_global_refs.union( CloudPickler.extract_code_globals(const)) + # process all variables referenced by global environment + f_globals = {} + for var in func_global_refs: + #Some names, such as class functions are not global - we don't need them + if func.func_globals.has_key(var): + f_globals[var] = func.func_globals[var] + + # defaults requires no processing + defaults = func.func_defaults + + def get_contents(cell): + try: + return cell.cell_contents + except ValueError, e: #cell is empty error on not yet assigned + raise pickle.PicklingError('Function to be pickled has free variables that are referenced before assignment in enclosing scope') + + + # process closure + if func.func_closure: + closure = map(get_contents, func.func_closure) + else: + closure = [] + + # save the dict + dct = func.func_dict + + if printSerialization: + outvars = ['code: ' + str(code) ] + outvars.append('globals: ' + str(f_globals)) + outvars.append('defaults: ' + str(defaults)) + outvars.append('closure: ' + str(closure)) + print 'function ', func, 'is extracted to: ', ', '.join(outvars) + + base_globals = self.globals_ref.get(id(func.func_globals), {}) + self.globals_ref[id(func.func_globals)] = base_globals + + return (code, f_globals, defaults, closure, dct, base_globals) + + def save_global(self, obj, name=None, pack=struct.pack): + write = self.write + memo = self.memo + + if name is None: + name = obj.__name__ + + modname = getattr(obj, "__module__", None) + if modname is None: + modname = pickle.whichmodule(obj, name) + + try: + __import__(modname) + themodule = sys.modules[modname] + except (ImportError, KeyError, AttributeError): #should never occur + raise pickle.PicklingError( + "Can't pickle %r: Module %s cannot be found" % + (obj, modname)) + + if modname == '__main__': + themodule = None + + if themodule: + self.modules.add(themodule) + + sendRef = True + typ = type(obj) + #print 'saving', obj, typ + try: + try: #Deal with case when getattribute fails with exceptions + klass = getattr(themodule, name) + except (AttributeError): + if modname == '__builtin__': #new.* are misrepeported + modname = 'new' + __import__(modname) + themodule = sys.modules[modname] + try: + klass = getattr(themodule, name) + except AttributeError, a: + #print themodule, name, obj, type(obj) + raise pickle.PicklingError("Can't pickle builtin %s" % obj) + else: + raise + + except (ImportError, KeyError, AttributeError): + if typ == types.TypeType or typ == types.ClassType: + sendRef = False + else: #we can't deal with this + raise + else: + if klass is not obj and (typ == types.TypeType or typ == types.ClassType): + sendRef = False + if not sendRef: + #note: Third party types might crash this - add better checks! + d = dict(obj.__dict__) #copy dict proxy to a dict + if not isinstance(d.get('__dict__', None), property): # don't extract dict that are properties + d.pop('__dict__',None) + d.pop('__weakref__',None) + + # hack as __new__ is stored differently in the __dict__ + new_override = d.get('__new__', None) + if new_override: + d['__new__'] = obj.__new__ + + self.save_reduce(type(obj),(obj.__name__,obj.__bases__, + d),obj=obj) + #print 'internal reduce dask %s %s' % (obj, d) + return + + if self.proto >= 2: + code = _extension_registry.get((modname, name)) + if code: + assert code > 0 + if code <= 0xff: + write(pickle.EXT1 + chr(code)) + elif code <= 0xffff: + write("%c%c%c" % (pickle.EXT2, code&0xff, code>>8)) + else: + write(pickle.EXT4 + pack("= 2 and getattr(func, "__name__", "") == "__newobj__": + #Added fix to allow transient + cls = args[0] + if not hasattr(cls, "__new__"): + raise pickle.PicklingError( + "args[0] from __newobj__ args has no __new__") + if obj is not None and cls is not obj.__class__: + raise pickle.PicklingError( + "args[0] from __newobj__ args has the wrong class") + args = args[1:] + save(cls) + + #Don't pickle transient entries + if hasattr(obj, '__transient__'): + transient = obj.__transient__ + state = state.copy() + + for k in list(state.keys()): + if k in transient: + del state[k] + + save(args) + write(pickle.NEWOBJ) + else: + save(func) + save(args) + write(pickle.REDUCE) + + if obj is not None: + self.memoize(obj) + + # More new special cases (that work with older protocols as + # well): when __reduce__ returns a tuple with 4 or 5 items, + # the 4th and 5th item should be iterators that provide list + # items and dict items (as (key, value) tuples), or None. + + if listitems is not None: + self._batch_appends(listitems) + + if dictitems is not None: + self._batch_setitems(dictitems) + + if state is not None: + #print 'obj %s has state %s' % (obj, state) + save(state) + write(pickle.BUILD) + + + def save_xrange(self, obj): + """Save an xrange object in python 2.5 + Python 2.6 supports this natively + """ + range_params = xrange_params(obj) + self.save_reduce(_build_xrange,range_params) + + #python2.6+ supports xrange pickling. some py2.5 extensions might as well. We just test it + try: + xrange(0).__reduce__() + except TypeError: #can't pickle -- use PiCloud pickler + dispatch[xrange] = save_xrange + + def save_partial(self, obj): + """Partial objects do not serialize correctly in python2.x -- this fixes the bugs""" + self.save_reduce(_genpartial, (obj.func, obj.args, obj.keywords)) + + if sys.version_info < (2,7): #2.7 supports partial pickling + dispatch[partial] = save_partial + + + def save_file(self, obj): + """Save a file""" + import StringIO as pystringIO #we can't use cStringIO as it lacks the name attribute + from ..transport.adapter import SerializingAdapter + + if not hasattr(obj, 'name') or not hasattr(obj, 'mode'): + raise pickle.PicklingError("Cannot pickle files that do not map to an actual file") + if obj.name == '': + return self.save_reduce(getattr, (sys,'stdout'), obj=obj) + if obj.name == '': + return self.save_reduce(getattr, (sys,'stderr'), obj=obj) + if obj.name == '': + raise pickle.PicklingError("Cannot pickle standard input") + if hasattr(obj, 'isatty') and obj.isatty(): + raise pickle.PicklingError("Cannot pickle files that map to tty objects") + if 'r' not in obj.mode: + raise pickle.PicklingError("Cannot pickle files that are not opened for reading") + name = obj.name + try: + fsize = os.stat(name).st_size + except OSError: + raise pickle.PicklingError("Cannot pickle file %s as it cannot be stat" % name) + + if obj.closed: + #create an empty closed string io + retval = pystringIO.StringIO("") + retval.close() + elif not fsize: #empty file + retval = pystringIO.StringIO("") + try: + tmpfile = file(name) + tst = tmpfile.read(1) + except IOError: + raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) + tmpfile.close() + if tst != '': + raise pickle.PicklingError("Cannot pickle file %s as it does not appear to map to a physical, real file" % name) + elif fsize > SerializingAdapter.max_transmit_data: + raise pickle.PicklingError("Cannot pickle file %s as it exceeds cloudconf.py's max_transmit_data of %d" % + (name,SerializingAdapter.max_transmit_data)) + else: + try: + tmpfile = file(name) + contents = tmpfile.read(SerializingAdapter.max_transmit_data) + tmpfile.close() + except IOError: + raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) + retval = pystringIO.StringIO(contents) + curloc = obj.tell() + retval.seek(curloc) + + retval.name = name + self.save(retval) #save stringIO + self.memoize(obj) + + dispatch[file] = save_file + """Special functions for Add-on libraries""" + + def inject_numpy(self): + numpy = sys.modules.get('numpy') + if not numpy or not hasattr(numpy, 'ufunc'): + return + self.dispatch[numpy.ufunc] = self.__class__.save_ufunc + + numpy_tst_mods = ['numpy', 'scipy.special'] + def save_ufunc(self, obj): + """Hack function for saving numpy ufunc objects""" + name = obj.__name__ + for tst_mod_name in self.numpy_tst_mods: + tst_mod = sys.modules.get(tst_mod_name, None) + if tst_mod: + if name in tst_mod.__dict__: + self.save_reduce(_getobject, (tst_mod_name, name)) + return + raise pickle.PicklingError('cannot save %s. Cannot resolve what module it is defined in' % str(obj)) + + def inject_timeseries(self): + """Handle bugs with pickling scikits timeseries""" + tseries = sys.modules.get('scikits.timeseries.tseries') + if not tseries or not hasattr(tseries, 'Timeseries'): + return + self.dispatch[tseries.Timeseries] = self.__class__.save_timeseries + + def save_timeseries(self, obj): + import scikits.timeseries.tseries as ts + + func, reduce_args, state = obj.__reduce__() + if func != ts._tsreconstruct: + raise pickle.PicklingError('timeseries using unexpected reconstruction function %s' % str(func)) + state = (1, + obj.shape, + obj.dtype, + obj.flags.fnc, + obj._data.tostring(), + ts.getmaskarray(obj).tostring(), + obj._fill_value, + obj._dates.shape, + obj._dates.__array__().tostring(), + obj._dates.dtype, #added -- preserve type + obj.freq, + obj._optinfo, + ) + return self.save_reduce(_genTimeSeries, (reduce_args, state)) + + def inject_email(self): + """Block email LazyImporters from being saved""" + email = sys.modules.get('email') + if not email: + return + self.dispatch[email.LazyImporter] = self.__class__.save_unsupported + + def inject_addons(self): + """Plug in system. Register additional pickling functions if modules already loaded""" + self.inject_numpy() + self.inject_timeseries() + self.inject_email() + + """Python Imaging Library""" + def save_image(self, obj): + if not obj.im and obj.fp and 'r' in obj.fp.mode and obj.fp.name \ + and not obj.fp.closed and (not hasattr(obj, 'isatty') or not obj.isatty()): + #if image not loaded yet -- lazy load + self.save_reduce(_lazyloadImage,(obj.fp,), obj=obj) + else: + #image is loaded - just transmit it over + self.save_reduce(_generateImage, (obj.size, obj.mode, obj.tostring()), obj=obj) + + """ + def memoize(self, obj): + pickle.Pickler.memoize(self, obj) + if printMemoization: + print 'memoizing ' + str(obj) + """ + + + +# Shorthands for legacy support + +def dump(obj, file, protocol=2): + CloudPickler(file, protocol).dump(obj) + +def dumps(obj, protocol=2): + file = StringIO() + + cp = CloudPickler(file,protocol) + cp.dump(obj) + + #print 'cloud dumped', str(obj), str(cp.modules) + + return file.getvalue() + + +#hack for __import__ not working as desired +def subimport(name): + __import__(name) + return sys.modules[name] + +#hack to load django settings: +def django_settings_load(name): + modified_env = False + + if 'DJANGO_SETTINGS_MODULE' not in os.environ: + os.environ['DJANGO_SETTINGS_MODULE'] = name # must set name first due to circular deps + modified_env = True + try: + module = subimport(name) + except Exception, i: + print >> sys.stderr, 'Cloud not import django settings %s:' % (name) + print_exec(sys.stderr) + if modified_env: + del os.environ['DJANGO_SETTINGS_MODULE'] + else: + #add project directory to sys,path: + if hasattr(module,'__file__'): + dirname = os.path.split(module.__file__)[0] + '/' + sys.path.append(dirname) + +# restores function attributes +def _restore_attr(obj, attr): + for key, val in attr.items(): + setattr(obj, key, val) + return obj + +def _get_module_builtins(): + return pickle.__builtins__ + +def print_exec(stream): + ei = sys.exc_info() + traceback.print_exception(ei[0], ei[1], ei[2], None, stream) + +def _modules_to_main(modList): + """Force every module in modList to be placed into main""" + if not modList: + return + + main = sys.modules['__main__'] + for modname in modList: + if type(modname) is str: + try: + mod = __import__(modname) + except Exception, i: #catch all... + sys.stderr.write('warning: could not import %s\n. Your function may unexpectedly error due to this import failing; \ +A version mismatch is likely. Specific error was:\n' % modname) + print_exec(sys.stderr) + else: + setattr(main,mod.__name__, mod) + else: + #REVERSE COMPATIBILITY FOR CLOUD CLIENT 1.5 (WITH EPD) + #In old version actual module was sent + setattr(main,modname.__name__, modname) + +#object generators: +def _build_xrange(start, step, len): + """Built xrange explicitly""" + return xrange(start, start + step*len, step) + +def _genpartial(func, args, kwds): + if not args: + args = () + if not kwds: + kwds = {} + return partial(func, *args, **kwds) + + +def _fill_function(func, globals, defaults, closure, dict): + """ Fills in the rest of function data into the skeleton function object + that were created via _make_skel_func(). + """ + func.func_globals.update(globals) + func.func_defaults = defaults + func.func_dict = dict + + if len(closure) != len(func.func_closure): + raise pickle.UnpicklingError("closure lengths don't match up") + for i in range(len(closure)): + _change_cell_value(func.func_closure[i], closure[i]) + + return func + +def _make_skel_func(code, num_closures, base_globals = None): + """ Creates a skeleton function object that contains just the provided + code and the correct number of cells in func_closure. All other + func attributes (e.g. func_globals) are empty. + """ + #build closure (cells): + if not ctypes: + raise Exception('ctypes failed to import; cannot build function') + + cellnew = ctypes.pythonapi.PyCell_New + cellnew.restype = ctypes.py_object + cellnew.argtypes = (ctypes.py_object,) + dummy_closure = tuple(map(lambda i: cellnew(None), range(num_closures))) + + if base_globals is None: + base_globals = {} + base_globals['__builtins__'] = __builtins__ + + return types.FunctionType(code, base_globals, + None, None, dummy_closure) + +# this piece of opaque code is needed below to modify 'cell' contents +cell_changer_code = new.code( + 1, 1, 2, 0, + ''.join([ + chr(dis.opmap['LOAD_FAST']), '\x00\x00', + chr(dis.opmap['DUP_TOP']), + chr(dis.opmap['STORE_DEREF']), '\x00\x00', + chr(dis.opmap['RETURN_VALUE']) + ]), + (), (), ('newval',), '', 'cell_changer', 1, '', ('c',), () +) + +def _change_cell_value(cell, newval): + """ Changes the contents of 'cell' object to newval """ + return new.function(cell_changer_code, {}, None, (), (cell,))(newval) + +"""Constructors for 3rd party libraries +Note: These can never be renamed due to client compatibility issues""" + +def _getobject(modname, attribute): + mod = __import__(modname) + return mod.__dict__[attribute] + +def _generateImage(size, mode, str_rep): + """Generate image from string representation""" + import Image + i = Image.new(mode, size) + i.fromstring(str_rep) + return i + +def _lazyloadImage(fp): + import Image + fp.seek(0) #works in almost any case + return Image.open(fp) + +"""Timeseries""" +def _genTimeSeries(reduce_args, state): + import scikits.timeseries.tseries as ts + from numpy import ndarray + from numpy.ma import MaskedArray + + + time_series = ts._tsreconstruct(*reduce_args) + + #from setstate modified + (ver, shp, typ, isf, raw, msk, flv, dsh, dtm, dtyp, frq, infodict) = state + #print 'regenerating %s' % dtyp + + MaskedArray.__setstate__(time_series, (ver, shp, typ, isf, raw, msk, flv)) + _dates = time_series._dates + #_dates.__setstate__((ver, dsh, typ, isf, dtm, frq)) #use remote typ + ndarray.__setstate__(_dates,(dsh,dtyp, isf, dtm)) + _dates.freq = frq + _dates._cachedinfo.update(dict(full=None, hasdups=None, steps=None, + toobj=None, toord=None, tostr=None)) + # Update the _optinfo dictionary + time_series._optinfo.update(infodict) + return time_series + diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index c892e86b93..5579c56de3 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -1,5 +1,5 @@ from base64 import standard_b64encode as b64enc -from cloud.serialization import cloudpickle +from pyspark import cloudpickle from itertools import chain from pyspark.serializers import PairSerializer, NopSerializer, \ diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index 4d4cc939c3..4c4b02fce4 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -5,7 +5,7 @@ import sys from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. -from cloud.serialization.cloudpickle import CloudPickler +from pyspark.cloudpickle import CloudPickler import cPickle diff --git a/pyspark/requirements.txt b/pyspark/requirements.txt index d9b3fe40bd..71e2bc2b89 100644 --- a/pyspark/requirements.txt +++ b/pyspark/requirements.txt @@ -4,6 +4,3 @@ # install Py4J from git once https://github.com/pypa/pip/pull/526 is merged. # git+git://github.com/bartdag/py4j.git@3dbf380d3d2cdeb9aab394454ea74d80c4aba1ea - -simplejson==2.6.1 -cloud==2.5.5 -- cgit v1.2.3 From fd94e5443c99775bfad1928729f5075c900ad0f9 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 18 Aug 2012 16:07:10 -0700 Subject: Use only cPickle for serialization in Python API. Objects serialized with JSON can be compared for equality, but JSON can be slow to serialize and only supports a limited range of data types. --- .../main/scala/spark/api/python/PythonRDD.scala | 192 +++++++--- pyspark/pyspark/context.py | 49 +-- pyspark/pyspark/java_gateway.py | 1 - pyspark/pyspark/join.py | 32 +- pyspark/pyspark/rdd.py | 414 ++++++++------------- pyspark/pyspark/serializers.py | 233 +----------- pyspark/pyspark/worker.py | 64 ++-- 7 files changed, 381 insertions(+), 604 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 660ad48afe..b9a0168d18 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -1,22 +1,26 @@ package spark.api.python -import java.io.PrintWriter +import java.io._ import scala.collection.Map import scala.collection.JavaConversions._ import scala.io.Source import spark._ -import api.java.{JavaPairRDD, JavaRDD} +import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} +import scala.{collection, Some} +import collection.parallel.mutable +import scala.collection import scala.Some trait PythonRDDBase { def compute[T](split: Split, envVars: Map[String, String], - command: Seq[String], parent: RDD[T], pythonExec: String): Iterator[String]= { - val currentEnvVars = new ProcessBuilder().environment() - val SPARK_HOME = currentEnvVars.get("SPARK_HOME") + command: Seq[String], parent: RDD[T], pythonExec: String): Iterator[Array[Byte]] = { + val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME") val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py")) // Add the environmental variables to the process. + val currentEnvVars = pb.environment() + envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) } @@ -41,33 +45,70 @@ trait PythonRDDBase { for (elem <- command) { out.println(elem) } + out.flush() + val dOut = new DataOutputStream(proc.getOutputStream) for (elem <- parent.iterator(split)) { - out.println(PythonRDD.pythonDump(elem)) + if (elem.isInstanceOf[Array[Byte]]) { + val arr = elem.asInstanceOf[Array[Byte]] + dOut.writeInt(arr.length) + dOut.write(arr) + } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) { + val t = elem.asInstanceOf[scala.Tuple2[_, _]] + val t1 = t._1.asInstanceOf[Array[Byte]] + val t2 = t._2.asInstanceOf[Array[Byte]] + val length = t1.length + t2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes + dOut.writeInt(length) + dOut.writeByte(Pickle.PROTO) + dOut.writeByte(Pickle.TWO) + dOut.write(PythonRDD.stripPickle(t1)) + dOut.write(PythonRDD.stripPickle(t2)) + dOut.writeByte(Pickle.TUPLE2) + dOut.writeByte(Pickle.STOP) + } else if (elem.isInstanceOf[String]) { + // For uniformity, strings are wrapped into Pickles. + val s = elem.asInstanceOf[String].getBytes("UTF-8") + val length = 2 + 1 + 4 + s.length + 1 + dOut.writeInt(length) + dOut.writeByte(Pickle.PROTO) + dOut.writeByte(Pickle.TWO) + dOut.writeByte(Pickle.BINUNICODE) + dOut.writeInt(Integer.reverseBytes(s.length)) + dOut.write(s) + dOut.writeByte(Pickle.STOP) + } else { + throw new Exception("Unexpected RDD type") + } } - out.close() + dOut.flush() + out.flush() + proc.getOutputStream.close() } }.start() // Return an iterator that read lines from the process's stdout - val lines: Iterator[String] = Source.fromInputStream(proc.getInputStream).getLines - wrapIterator(lines, proc) - } + val stream = new DataInputStream(proc.getInputStream) + return new Iterator[Array[Byte]] { + def next() = { + val obj = _nextObj + _nextObj = read() + obj + } - def wrapIterator[T](iter: Iterator[T], proc: Process): Iterator[T] = { - return new Iterator[T] { - def next() = iter.next() - - def hasNext = { - if (iter.hasNext) { - true - } else { - val exitStatus = proc.waitFor() - if (exitStatus != 0) { - throw new Exception("Subprocess exited with status " + exitStatus) - } - false + private def read() = { + try { + val length = stream.readInt() + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + } catch { + case eof: EOFException => { new Array[Byte](0) } + case e => throw e } } + + var _nextObj = read() + + def hasNext = _nextObj.length != 0 } } } @@ -75,7 +116,7 @@ trait PythonRDDBase { class PythonRDD[T: ClassManifest]( parent: RDD[T], command: Seq[String], envVars: Map[String, String], preservePartitoning: Boolean, pythonExec: String) - extends RDD[String](parent.context) with PythonRDDBase { + extends RDD[Array[Byte]](parent.context) with PythonRDDBase { def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) = this(parent, command, Map(), preservePartitoning, pythonExec) @@ -91,16 +132,16 @@ class PythonRDD[T: ClassManifest]( override val partitioner = if (preservePartitoning) parent.partitioner else None - override def compute(split: Split): Iterator[String] = + override def compute(split: Split): Iterator[Array[Byte]] = compute(split, envVars, command, parent, pythonExec) - val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this) + val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } class PythonPairRDD[T: ClassManifest] ( parent: RDD[T], command: Seq[String], envVars: Map[String, String], preservePartitoning: Boolean, pythonExec: String) - extends RDD[(String, String)](parent.context) with PythonRDDBase { + extends RDD[(Array[Byte], Array[Byte])](parent.context) with PythonRDDBase { def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) = this(parent, command, Map(), preservePartitoning, pythonExec) @@ -116,32 +157,95 @@ class PythonPairRDD[T: ClassManifest] ( override val partitioner = if (preservePartitoning) parent.partitioner else None - override def compute(split: Split): Iterator[(String, String)] = { + override def compute(split: Split): Iterator[(Array[Byte], Array[Byte])] = { compute(split, envVars, command, parent, pythonExec).grouped(2).map { case Seq(a, b) => (a, b) - case x => throw new Exception("Unexpected value: " + x) + case x => throw new Exception("PythonPairRDD: unexpected value: " + x) } } - val asJavaPairRDD : JavaPairRDD[String, String] = JavaPairRDD.fromRDD(this) + val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) } + object PythonRDD { - def pythonDump[T](x: T): String = { - if (x.isInstanceOf[scala.Option[_]]) { - val t = x.asInstanceOf[scala.Option[_]] - t match { - case None => "*" - case Some(z) => pythonDump(z) - } - } else if (x.isInstanceOf[scala.Tuple2[_, _]]) { - val t = x.asInstanceOf[scala.Tuple2[_, _]] - "(" + pythonDump(t._1) + "," + pythonDump(t._2) + ")" - } else if (x.isInstanceOf[java.util.List[_]]) { - val objs = asScalaBuffer(x.asInstanceOf[java.util.List[_]]).map(pythonDump) - "[" + objs.mkString("|") + "]" + + /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */ + def stripPickle(arr: Array[Byte]) : Array[Byte] = { + arr.slice(2, arr.length - 1) + } + + def asPickle(elem: Any) : Array[Byte] = { + val baos = new ByteArrayOutputStream(); + val dOut = new DataOutputStream(baos); + if (elem.isInstanceOf[Array[Byte]]) { + elem.asInstanceOf[Array[Byte]] + } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) { + val t = elem.asInstanceOf[scala.Tuple2[_, _]] + val t1 = t._1.asInstanceOf[Array[Byte]] + val t2 = t._2.asInstanceOf[Array[Byte]] + dOut.writeByte(Pickle.PROTO) + dOut.writeByte(Pickle.TWO) + dOut.write(PythonRDD.stripPickle(t1)) + dOut.write(PythonRDD.stripPickle(t2)) + dOut.writeByte(Pickle.TUPLE2) + dOut.writeByte(Pickle.STOP) + baos.toByteArray() + } else if (elem.isInstanceOf[String]) { + // For uniformity, strings are wrapped into Pickles. + val s = elem.asInstanceOf[String].getBytes("UTF-8") + dOut.writeByte(Pickle.PROTO) + dOut.writeByte(Pickle.TWO) + dOut.write(Pickle.BINUNICODE) + dOut.writeInt(Integer.reverseBytes(s.length)) + dOut.write(s) + dOut.writeByte(Pickle.STOP) + baos.toByteArray() } else { - x.toString + throw new Exception("Unexpected RDD type") } } + + def pickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) : + JavaRDD[Array[Byte]] = { + val file = new DataInputStream(new FileInputStream(filename)) + val objs = new collection.mutable.ArrayBuffer[Array[Byte]] + try { + while (true) { + val length = file.readInt() + val obj = new Array[Byte](length) + file.readFully(obj) + objs.append(obj) + } + } catch { + case eof: EOFException => {} + case e => throw e + } + JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) + } + + def arrayAsPickle(arr : Any) : Array[Byte] = { + val pickles : Array[Byte] = arr.asInstanceOf[Array[Any]].map(asPickle).map(stripPickle).flatten + + Array[Byte](Pickle.PROTO, Pickle.TWO, Pickle.EMPTY_LIST, Pickle.MARK) ++ pickles ++ + Array[Byte] (Pickle.APPENDS, Pickle.STOP) + } +} + +private object Pickle { + def b(x: Int): Byte = x.asInstanceOf[Byte] + val PROTO: Byte = b(0x80) + val TWO: Byte = b(0x02) + val BINUNICODE : Byte = 'X' + val STOP : Byte = '.' + val TUPLE2 : Byte = b(0x86) + val EMPTY_LIST : Byte = ']' + val MARK : Byte = '(' + val APPENDS : Byte = 'e' +} +class ExtractValue extends spark.api.java.function.Function[(Array[Byte], + Array[Byte]), Array[Byte]] { + + override def call(pair: (Array[Byte], Array[Byte])) : Array[Byte] = pair._2 + } diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index 587ab12b5f..ac7e4057e9 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -3,22 +3,24 @@ import atexit from tempfile import NamedTemporaryFile from pyspark.java_gateway import launch_gateway -from pyspark.serializers import JSONSerializer, NopSerializer -from pyspark.rdd import RDD, PairRDD +from pyspark.serializers import PickleSerializer, dumps +from pyspark.rdd import RDD class SparkContext(object): gateway = launch_gateway() jvm = gateway.jvm - python_dump = jvm.spark.api.python.PythonRDD.pythonDump + pickleFile = jvm.spark.api.python.PythonRDD.pickleFile + asPickle = jvm.spark.api.python.PythonRDD.asPickle + arrayAsPickle = jvm.spark.api.python.PythonRDD.arrayAsPickle - def __init__(self, master, name, defaultSerializer=JSONSerializer, - defaultParallelism=None, pythonExec='python'): + + def __init__(self, master, name, defaultParallelism=None, + pythonExec='python'): self.master = master self.name = name self._jsc = self.jvm.JavaSparkContext(master, name) - self.defaultSerializer = defaultSerializer self.defaultParallelism = \ defaultParallelism or self._jsc.sc().defaultParallelism() self.pythonExec = pythonExec @@ -31,39 +33,26 @@ class SparkContext(object): self._jsc.stop() self._jsc = None - def parallelize(self, c, numSlices=None, serializer=None): - serializer = serializer or self.defaultSerializer - numSlices = numSlices or self.defaultParallelism - # Calling the Java parallelize() method with an ArrayList is too slow, - # because it sends O(n) Py4J commands. As an alternative, serialized - # objects are written to a file and loaded through textFile(). - tempFile = NamedTemporaryFile(delete=False) - tempFile.writelines(serializer.dumps(x) + '\n' for x in c) - tempFile.close() - atexit.register(lambda: os.unlink(tempFile.name)) - return self.textFile(tempFile.name, numSlices, serializer) - - def parallelizePairs(self, c, numSlices=None, keySerializer=None, - valSerializer=None): + def parallelize(self, c, numSlices=None): """ >>> sc = SparkContext("local", "test") - >>> rdd = sc.parallelizePairs([(1, 2), (3, 4)]) + >>> rdd = sc.parallelize([(1, 2), (3, 4)]) >>> rdd.collect() [(1, 2), (3, 4)] """ - keySerializer = keySerializer or self.defaultSerializer - valSerializer = valSerializer or self.defaultSerializer numSlices = numSlices or self.defaultParallelism + # Calling the Java parallelize() method with an ArrayList is too slow, + # because it sends O(n) Py4J commands. As an alternative, serialized + # objects are written to a file and loaded through textFile(). tempFile = NamedTemporaryFile(delete=False) - for (k, v) in c: - tempFile.write(keySerializer.dumps(k).rstrip('\r\n') + '\n') - tempFile.write(valSerializer.dumps(v).rstrip('\r\n') + '\n') + for x in c: + dumps(PickleSerializer.dumps(x), tempFile) tempFile.close() atexit.register(lambda: os.unlink(tempFile.name)) - jrdd = self.textFile(tempFile.name, numSlices)._pipePairs([], "echo") - return PairRDD(jrdd, self, keySerializer, valSerializer) + jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices) + return RDD(jrdd, self) - def textFile(self, name, numSlices=None, serializer=NopSerializer): + def textFile(self, name, numSlices=None): numSlices = numSlices or self.defaultParallelism jrdd = self._jsc.textFile(name, numSlices) - return RDD(jrdd, self, serializer) + return RDD(jrdd, self) diff --git a/pyspark/pyspark/java_gateway.py b/pyspark/pyspark/java_gateway.py index 2df80aee85..bcb405ba72 100644 --- a/pyspark/pyspark/java_gateway.py +++ b/pyspark/pyspark/java_gateway.py @@ -16,5 +16,4 @@ def launch_gateway(): java_import(gateway.jvm, "spark.api.java.*") java_import(gateway.jvm, "spark.api.python.*") java_import(gateway.jvm, "scala.Tuple2") - java_import(gateway.jvm, "spark.api.python.PythonRDD.pythonDump") return gateway diff --git a/pyspark/pyspark/join.py b/pyspark/pyspark/join.py index c67520fce8..7036c47980 100644 --- a/pyspark/pyspark/join.py +++ b/pyspark/pyspark/join.py @@ -30,15 +30,12 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """ -from pyspark.serializers import PairSerializer, OptionSerializer, \ - ArraySerializer -def _do_python_join(rdd, other, numSplits, dispatch, valSerializer): - vs = rdd.mapPairs(lambda (k, v): (k, (1, v))) - ws = other.mapPairs(lambda (k, v): (k, (2, v))) - return vs.union(ws).groupByKey(numSplits) \ - .flatMapValues(dispatch, valSerializer) +def _do_python_join(rdd, other, numSplits, dispatch): + vs = rdd.map(lambda (k, v): (k, (1, v))) + ws = other.map(lambda (k, v): (k, (2, v))) + return vs.union(ws).groupByKey(numSplits).flatMapValues(dispatch) def python_join(rdd, other, numSplits): @@ -50,8 +47,7 @@ def python_join(rdd, other, numSplits): elif n == 2: wbuf.append(v) return [(v, w) for v in vbuf for w in wbuf] - valSerializer = PairSerializer(rdd.valSerializer, other.valSerializer) - return _do_python_join(rdd, other, numSplits, dispatch, valSerializer) + return _do_python_join(rdd, other, numSplits, dispatch) def python_right_outer_join(rdd, other, numSplits): @@ -65,9 +61,7 @@ def python_right_outer_join(rdd, other, numSplits): if not vbuf: vbuf.append(None) return [(v, w) for v in vbuf for w in wbuf] - valSerializer = PairSerializer(OptionSerializer(rdd.valSerializer), - other.valSerializer) - return _do_python_join(rdd, other, numSplits, dispatch, valSerializer) + return _do_python_join(rdd, other, numSplits, dispatch) def python_left_outer_join(rdd, other, numSplits): @@ -81,17 +75,12 @@ def python_left_outer_join(rdd, other, numSplits): if not wbuf: wbuf.append(None) return [(v, w) for v in vbuf for w in wbuf] - valSerializer = PairSerializer(rdd.valSerializer, - OptionSerializer(other.valSerializer)) - return _do_python_join(rdd, other, numSplits, dispatch, valSerializer) + return _do_python_join(rdd, other, numSplits, dispatch) def python_cogroup(rdd, other, numSplits): - resultValSerializer = PairSerializer( - ArraySerializer(rdd.valSerializer), - ArraySerializer(other.valSerializer)) - vs = rdd.mapPairs(lambda (k, v): (k, (1, v))) - ws = other.mapPairs(lambda (k, v): (k, (2, v))) + vs = rdd.map(lambda (k, v): (k, (1, v))) + ws = other.map(lambda (k, v): (k, (2, v))) def dispatch(seq): vbuf, wbuf = [], [] for (n, v) in seq: @@ -100,5 +89,4 @@ def python_cogroup(rdd, other, numSplits): elif n == 2: wbuf.append(v) return (vbuf, wbuf) - return vs.union(ws).groupByKey(numSplits) \ - .mapValues(dispatch, resultValSerializer) + return vs.union(ws).groupByKey(numSplits).mapValues(dispatch) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 5579c56de3..8eccddc0a2 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -1,31 +1,17 @@ from base64 import standard_b64encode as b64enc -from pyspark import cloudpickle -from itertools import chain -from pyspark.serializers import PairSerializer, NopSerializer, \ - OptionSerializer, ArraySerializer +from pyspark import cloudpickle +from pyspark.serializers import PickleSerializer from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup class RDD(object): - def __init__(self, jrdd, ctx, serializer=None): + def __init__(self, jrdd, ctx): self._jrdd = jrdd self.is_cached = False self.ctx = ctx - self.serializer = serializer or ctx.defaultSerializer - - def _builder(self, jrdd, ctx): - return RDD(jrdd, ctx, self.serializer) - - @property - def id(self): - return self._jrdd.id() - - @property - def splits(self): - return self._jrdd.splits() @classmethod def _get_pipe_command(cls, command, functions): @@ -41,55 +27,18 @@ class RDD(object): self._jrdd.cache() return self - def map(self, f, serializer=None, preservesPartitioning=False): - return MappedRDD(self, f, serializer, preservesPartitioning) - - def mapPairs(self, f, keySerializer=None, valSerializer=None, - preservesPartitioning=False): - return PairMappedRDD(self, f, keySerializer, valSerializer, - preservesPartitioning) + def map(self, f, preservesPartitioning=False): + return MappedRDD(self, f, preservesPartitioning) - def flatMap(self, f, serializer=None): + def flatMap(self, f): """ >>> rdd = sc.parallelize([2, 3, 4]) >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect()) [1, 1, 1, 2, 2, 3] - """ - serializer = serializer or self.ctx.defaultSerializer - dumps = serializer.dumps - loads = self.serializer.loads - def func(x): - pickled_elems = (dumps(y) for y in f(loads(x))) - return "\n".join(pickled_elems) or None - pipe_command = RDD._get_pipe_command("map", [func]) - class_manifest = self._jrdd.classManifest() - jrdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command, - False, self.ctx.pythonExec, - class_manifest).asJavaRDD() - return RDD(jrdd, self.ctx, serializer) - - def flatMapPairs(self, f, keySerializer=None, valSerializer=None, - preservesPartitioning=False): - """ - >>> rdd = sc.parallelize([2, 3, 4]) - >>> sorted(rdd.flatMapPairs(lambda x: [(x, x), (x, x)]).collect()) + >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect()) [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ - keySerializer = keySerializer or self.ctx.defaultSerializer - valSerializer = valSerializer or self.ctx.defaultSerializer - dumpk = keySerializer.dumps - dumpv = valSerializer.dumps - loads = self.serializer.loads - def func(x): - pairs = f(loads(x)) - pickled_pairs = ((dumpk(k), dumpv(v)) for (k, v) in pairs) - return "\n".join(chain.from_iterable(pickled_pairs)) or None - pipe_command = RDD._get_pipe_command("map", [func]) - class_manifest = self._jrdd.classManifest() - python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), pipe_command, - preservesPartitioning, self.ctx.pythonExec, class_manifest) - return PairRDD(python_rdd.asJavaPairRDD(), self.ctx, keySerializer, - valSerializer) + return MappedRDD(self, f, preservesPartitioning=False, command='flatmap') def filter(self, f): """ @@ -97,9 +46,8 @@ class RDD(object): >>> rdd.filter(lambda x: x % 2 == 0).collect() [2, 4] """ - loads = self.serializer.loads - def filter_func(x): return x if f(loads(x)) else None - return self._builder(self._pipe(filter_func), self.ctx) + def filter_func(x): return x if f(x) else None + return RDD(self._pipe(filter_func), self.ctx) def _pipe(self, functions, command="map"): class_manifest = self._jrdd.classManifest() @@ -108,32 +56,22 @@ class RDD(object): False, self.ctx.pythonExec, class_manifest) return python_rdd.asJavaRDD() - def _pipePairs(self, functions, command="mapPairs", - preservesPartitioning=False): - class_manifest = self._jrdd.classManifest() - pipe_command = RDD._get_pipe_command(command, functions) - python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), pipe_command, - preservesPartitioning, self.ctx.pythonExec, class_manifest) - return python_rdd.asJavaPairRDD() - def distinct(self): """ >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect()) [1, 2, 3] """ - if self.serializer.is_comparable: - return self._builder(self._jrdd.distinct(), self.ctx) - return self.mapPairs(lambda x: (x, "")) \ + return self.map(lambda x: (x, "")) \ .reduceByKey(lambda x, _: x) \ .map(lambda (x, _): x) def sample(self, withReplacement, fraction, seed): jrdd = self._jrdd.sample(withReplacement, fraction, seed) - return self._builder(jrdd, self.ctx) + return RDD(jrdd, self.ctx) def takeSample(self, withReplacement, num, seed): vals = self._jrdd.takeSample(withReplacement, num, seed) - return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals] + return [PickleSerializer.loads(x) for x in vals] def union(self, other): """ @@ -141,7 +79,7 @@ class RDD(object): >>> rdd.union(rdd).collect() [1, 1, 2, 3, 1, 1, 2, 3] """ - return self._builder(self._jrdd.union(other._jrdd), self.ctx) + return RDD(self._jrdd.union(other._jrdd), self.ctx) # TODO: sort @@ -155,16 +93,17 @@ class RDD(object): >>> sorted(rdd.cartesian(rdd).collect()) [(1, 1), (1, 2), (2, 1), (2, 2)] """ - return PairRDD(self._jrdd.cartesian(other._jrdd), self.ctx) + return RDD(self._jrdd.cartesian(other._jrdd), self.ctx) # numsplits def groupBy(self, f, numSplits=None): """ >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8]) - >>> sorted(rdd.groupBy(lambda x: x % 2).collect()) + >>> result = rdd.groupBy(lambda x: x % 2).collect() + >>> sorted([(x, sorted(y)) for (x, y) in result]) [(0, [2, 8]), (1, [1, 1, 3, 5])] """ - return self.mapPairs(lambda x: (f(x), x)).groupByKey(numSplits) + return self.map(lambda x: (f(x), x)).groupByKey(numSplits) # TODO: pipe @@ -178,25 +117,19 @@ class RDD(object): self.map(f).collect() # Force evaluation def collect(self): - vals = self._jrdd.collect() - return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals] + pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().collect()) + return PickleSerializer.loads(bytes(pickle)) - def reduce(self, f, serializer=None): + def reduce(self, f): """ - >>> import operator - >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(operator.add) + >>> from operator import add + >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add) 15 + >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add) + 10 """ - serializer = serializer or self.ctx.defaultSerializer - loads = self.serializer.loads - dumps = serializer.dumps - def reduceFunction(x, acc): - if acc is None: - return loads(x) - else: - return f(loads(x), acc) - vals = self._pipe([reduceFunction, dumps], command="reduce").collect() - return reduce(f, (serializer.loads(x) for x in vals)) + vals = MappedRDD(self, f, command="reduce", preservesPartitioning=False).collect() + return reduce(f, vals) # TODO: fold @@ -216,36 +149,35 @@ class RDD(object): >>> sc.parallelize([2, 3, 4]).take(2) [2, 3] """ - vals = self._jrdd.take(num) - return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals] + pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().take(num)) + return PickleSerializer.loads(bytes(pickle)) def first(self): """ >>> sc.parallelize([2, 3, 4]).first() 2 """ - return self.serializer.loads(self.ctx.python_dump(self._jrdd.first())) + return PickleSerializer.loads(bytes(self.ctx.asPickle(self._jrdd.first()))) # TODO: saveAsTextFile # TODO: saveAsObjectFile + # Pair functions -class PairRDD(RDD): - - def __init__(self, jrdd, ctx, keySerializer=None, valSerializer=None): - RDD.__init__(self, jrdd, ctx) - self.keySerializer = keySerializer or ctx.defaultSerializer - self.valSerializer = valSerializer or ctx.defaultSerializer - self.serializer = \ - PairSerializer(self.keySerializer, self.valSerializer) - - def _builder(self, jrdd, ctx): - return PairRDD(jrdd, ctx, self.keySerializer, self.valSerializer) + def collectAsMap(self): + """ + >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap() + >>> m[1] + 2 + >>> m[3] + 4 + """ + return dict(self.collect()) def reduceByKey(self, func, numSplits=None): """ - >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)]) + >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> sorted(x.reduceByKey(lambda a, b: a + b).collect()) [('a', 2), ('b', 1)] """ @@ -259,90 +191,67 @@ class PairRDD(RDD): def join(self, other, numSplits=None): """ - >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) - >>> y = sc.parallelizePairs([("a", 2), ("a", 3)]) - >>> x.join(y).collect() + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2), ("a", 3)]) + >>> sorted(x.join(y).collect()) [('a', (1, 2)), ('a', (1, 3))] - - Check that we get a PairRDD-like object back: - >>> assert x.join(y).join """ - assert self.keySerializer.name == other.keySerializer.name - if self.keySerializer.is_comparable: - return PairRDD(self._jrdd.join(other._jrdd), - self.ctx, self.keySerializer, - PairSerializer(self.valSerializer, other.valSerializer)) - else: - return python_join(self, other, numSplits) + return python_join(self, other, numSplits) def leftOuterJoin(self, other, numSplits=None): """ - >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) - >>> y = sc.parallelizePairs([("a", 2)]) + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2)]) >>> sorted(x.leftOuterJoin(y).collect()) [('a', (1, 2)), ('b', (4, None))] """ - assert self.keySerializer.name == other.keySerializer.name - if self.keySerializer.is_comparable: - return PairRDD(self._jrdd.leftOuterJoin(other._jrdd), - self.ctx, self.keySerializer, - PairSerializer(self.valSerializer, - OptionSerializer(other.valSerializer))) - else: - return python_left_outer_join(self, other, numSplits) + return python_left_outer_join(self, other, numSplits) def rightOuterJoin(self, other, numSplits=None): """ - >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) - >>> y = sc.parallelizePairs([("a", 2)]) + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2)]) >>> sorted(y.rightOuterJoin(x).collect()) [('a', (2, 1)), ('b', (None, 4))] """ - assert self.keySerializer.name == other.keySerializer.name - if self.keySerializer.is_comparable: - return PairRDD(self._jrdd.rightOuterJoin(other._jrdd), - self.ctx, self.keySerializer, - PairSerializer(OptionSerializer(self.valSerializer), - other.valSerializer)) - else: - return python_right_outer_join(self, other, numSplits) + return python_right_outer_join(self, other, numSplits) + + # TODO: pipelining + # TODO: optimizations + def shuffle(self, numSplits): + if numSplits is None: + numSplits = self.ctx.defaultParallelism + pipe_command = RDD._get_pipe_command('shuffle_map_step', []) + class_manifest = self._jrdd.classManifest() + python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), + pipe_command, False, self.ctx.pythonExec, class_manifest) + partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) + jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner) + jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) + # TODO: extract second value. + return RDD(jrdd, self.ctx) + + def combineByKey(self, createCombiner, mergeValue, mergeCombiners, - numSplits=None, serializer=None): + numSplits=None): """ - >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)]) + >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> def f(x): return x >>> def add(a, b): return a + str(b) >>> sorted(x.combineByKey(str, add, add).collect()) [('a', '11'), ('b', '1')] """ - serializer = serializer or self.ctx.defaultSerializer if numSplits is None: numSplits = self.ctx.defaultParallelism - # Use hash() to create keys that are comparable in Java. - loadkv = self.serializer.loads - def pairify(kv): - # TODO: add method to deserialize only the key or value from - # a PairSerializer? - key = loadkv(kv)[0] - return (str(hash(key)), kv) - partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) - jrdd = self._pipePairs(pairify).partitionBy(partitioner) - pairified = PairRDD(jrdd, self.ctx, NopSerializer, self.serializer) - - loads = PairSerializer(NopSerializer, self.serializer).loads - dumpk = self.keySerializer.dumps - dumpc = serializer.dumps - - functions = [createCombiner, mergeValue, mergeCombiners, loads, dumpk, - dumpc] - jpairs = pairified._pipePairs(functions, "combine_by_key", - preservesPartitioning=True) - return PairRDD(jpairs, self.ctx, self.keySerializer, serializer) + shuffled = self.shuffle(numSplits) + functions = [createCombiner, mergeValue, mergeCombiners] + jpairs = shuffled._pipe(functions, "combine_by_key") + return RDD(jpairs, self.ctx) def groupByKey(self, numSplits=None): """ - >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)]) + >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> sorted(x.groupByKey().collect()) [('a', [1, 1]), ('b', [1])] """ @@ -360,29 +269,15 @@ class PairRDD(RDD): return self.combineByKey(createCombiner, mergeValue, mergeCombiners, numSplits) - def collectAsMap(self): - """ - >>> m = sc.parallelizePairs([(1, 2), (3, 4)]).collectAsMap() - >>> m[1] - 2 - >>> m[3] - 4 - """ - m = self._jrdd.collectAsMap() - def loads(x): - (k, v) = x - return (self.keySerializer.loads(k), self.valSerializer.loads(v)) - return dict(loads(x) for x in m.items()) - - def flatMapValues(self, f, valSerializer=None): + def flatMapValues(self, f): flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) - return self.flatMapPairs(flat_map_fn, self.keySerializer, - valSerializer, True) + return self.flatMap(flat_map_fn) - def mapValues(self, f, valSerializer=None): + def mapValues(self, f): map_values_fn = lambda (k, v): (k, f(v)) - return self.mapPairs(map_values_fn, self.keySerializer, valSerializer, - True) + return self.map(map_values_fn, preservesPartitioning=True) + + # TODO: implement shuffle. # TODO: support varargs cogroup of several RDDs. def groupWith(self, other): @@ -390,20 +285,12 @@ class PairRDD(RDD): def cogroup(self, other, numSplits=None): """ - >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) - >>> y = sc.parallelizePairs([("a", 2)]) + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2)]) >>> x.cogroup(y).collect() [('a', ([1], [2])), ('b', ([4], []))] """ - assert self.keySerializer.name == other.keySerializer.name - resultValSerializer = PairSerializer( - ArraySerializer(self.valSerializer), - ArraySerializer(other.valSerializer)) - if self.keySerializer.is_comparable: - return PairRDD(self._jrdd.cogroup(other._jrdd), - self.ctx, self.keySerializer, resultValSerializer) - else: - return python_cogroup(self, other, numSplits) + return python_cogroup(self, other, numSplits) # TODO: `lookup` is disabled because we can't make direct comparisons based # on the key; we need to compare the hash of the key to the hash of the @@ -413,44 +300,84 @@ class PairRDD(RDD): # TODO: file saving -class MappedRDDBase(object): - def __init__(self, prev, func, serializer, preservesPartitioning=False): - if isinstance(prev, MappedRDDBase) and not prev.is_cached: +class MappedRDD(RDD): + """ + Pipelined maps: + >>> rdd = sc.parallelize([1, 2, 3, 4]) + >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect() + [4, 8, 12, 16] + >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect() + [4, 8, 12, 16] + + Pipelined reduces: + >>> from operator import add + >>> rdd.map(lambda x: 2 * x).reduce(add) + 20 + >>> rdd.flatMap(lambda x: [x, x]).reduce(add) + 20 + """ + def __init__(self, prev, func, preservesPartitioning=False, command='map'): + if isinstance(prev, MappedRDD) and not prev.is_cached: prev_func = prev.func - self.func = lambda x: func(prev_func(x)) + if command == 'reduce': + if prev.command == 'flatmap': + def flatmap_reduce_func(x, acc): + values = prev_func(x) + if values is None: + return acc + if not acc: + if len(values) == 1: + return values[0] + else: + return reduce(func, values[1:], values[0]) + else: + return reduce(func, values, acc) + self.func = flatmap_reduce_func + else: + def reduce_func(x, acc): + val = prev_func(x) + if not val: + return acc + if acc is None: + return val + else: + return func(val, acc) + self.func = reduce_func + else: + if prev.command == 'flatmap': + command = 'flatmap' + self.func = lambda x: (func(y) for y in prev_func(x)) + else: + self.func = lambda x: func(prev_func(x)) + self.preservesPartitioning = \ prev.preservesPartitioning and preservesPartitioning self._prev_jrdd = prev._prev_jrdd - self._prev_serializer = prev._prev_serializer + self.is_pipelined = True else: - self.func = func + if command == 'reduce': + def reduce_func(val, acc): + if acc is None: + return val + else: + return func(val, acc) + self.func = reduce_func + else: + self.func = func self.preservesPartitioning = preservesPartitioning self._prev_jrdd = prev._jrdd - self._prev_serializer = prev.serializer - self.serializer = serializer or prev.ctx.defaultSerializer + self.is_pipelined = False self.is_cached = False self.ctx = prev.ctx self.prev = prev self._jrdd_val = None - - -class MappedRDD(MappedRDDBase, RDD): - """ - >>> rdd = sc.parallelize([1, 2, 3, 4]) - >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect() - [4, 8, 12, 16] - >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect() - [4, 8, 12, 16] - """ + self.command = command @property def _jrdd(self): if not self._jrdd_val: - udf = self.func - loads = self._prev_serializer.loads - dumps = self.serializer.dumps - func = lambda x: dumps(udf(loads(x))) - pipe_command = RDD._get_pipe_command("map", [func]) + funcs = [self.func] + pipe_command = RDD._get_pipe_command(self.command, funcs) class_manifest = self._prev_jrdd.classManifest() python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), pipe_command, self.preservesPartitioning, self.ctx.pythonExec, @@ -459,56 +386,11 @@ class MappedRDD(MappedRDDBase, RDD): return self._jrdd_val -class PairMappedRDD(MappedRDDBase, PairRDD): - """ - >>> rdd = sc.parallelize([1, 2, 3, 4]) - >>> rdd.mapPairs(lambda x: (x, x)) \\ - ... .mapPairs(lambda (x, y): (2*x, 2*y)) \\ - ... .collect() - [(2, 2), (4, 4), (6, 6), (8, 8)] - >>> rdd.mapPairs(lambda x: (x, x)) \\ - ... .mapPairs(lambda (x, y): (2*x, 2*y)) \\ - ... .map(lambda (x, _): x).collect() - [2, 4, 6, 8] - """ - - def __init__(self, prev, func, keySerializer=None, valSerializer=None, - preservesPartitioning=False): - self.keySerializer = keySerializer or prev.ctx.defaultSerializer - self.valSerializer = valSerializer or prev.ctx.defaultSerializer - serializer = PairSerializer(self.keySerializer, self.valSerializer) - MappedRDDBase.__init__(self, prev, func, serializer, - preservesPartitioning) - - @property - def _jrdd(self): - if not self._jrdd_val: - udf = self.func - loads = self._prev_serializer.loads - dumpk = self.keySerializer.dumps - dumpv = self.valSerializer.dumps - def func(x): - (k, v) = udf(loads(x)) - return (dumpk(k), dumpv(v)) - pipe_command = RDD._get_pipe_command("mapPairs", [func]) - class_manifest = self._prev_jrdd.classManifest() - self._jrdd_val = self.ctx.jvm.PythonPairRDD(self._prev_jrdd.rdd(), - pipe_command, self.preservesPartitioning, self.ctx.pythonExec, - class_manifest).asJavaPairRDD() - return self._jrdd_val - - def _test(): import doctest from pyspark.context import SparkContext - from pyspark.serializers import PickleSerializer, JSONSerializer globs = globals().copy() - globs['sc'] = SparkContext('local', 'PythonTest', - defaultSerializer=JSONSerializer) - doctest.testmod(globs=globs) - globs['sc'].stop() - globs['sc'] = SparkContext('local', 'PythonTest', - defaultSerializer=PickleSerializer) + globs['sc'] = SparkContext('local', 'PythonTest') doctest.testmod(globs=globs) globs['sc'].stop() diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py index b113f5656b..7b3e6966e1 100644 --- a/pyspark/pyspark/serializers.py +++ b/pyspark/pyspark/serializers.py @@ -2,228 +2,35 @@ Data serialization methods. The Spark Python API is built on top of the Spark Java API. RDDs created in -Python are stored in Java as RDDs of Strings. Python objects are automatically -serialized/deserialized, so this representation is transparent to the end-user. - ------------------- -Serializer objects ------------------- - -`Serializer` objects are used to customize how an RDD's values are serialized. - -Each `Serializer` is a named tuple with four fields: - - - A `dumps` function, for serializing a Python object to a string. - - - A `loads` function, for deserializing a Python object from a string. - - - An `is_comparable` field, True if equal Python objects are serialized to - equal strings, and False otherwise. - - - A `name` field, used to identify the Serializer. Serializers are - compared for equality by comparing their names. - -The serializer's output should be base64-encoded. - ------------------------------------------------------------------- -`is_comparable`: comparing serialized representations for equality ------------------------------------------------------------------- - -If `is_comparable` is False, the serializer's representations of equal objects -are not required to be equal: - ->>> import pickle ->>> a = {1: 0, 9: 0} ->>> b = {9: 0, 1: 0} ->>> a == b -True ->>> pickle.dumps(a) == pickle.dumps(b) -False - -RDDs with comparable serializers can use native Java implementations of -operations like join() and distinct(), which may lead to better performance by -eliminating deserialization and Python comparisons. - -The default JSONSerializer produces comparable representations of common Python -data structures. - --------------------------------------- -Examples of serialized representations --------------------------------------- - -The RDD transformations that use Python UDFs are implemented in terms of -a modified `PipedRDD.pipe()` function. For each record `x` in the RDD, the -`pipe()` function pipes `x.toString()` to a Python worker process, which -deserializes the string into a Python object, executes user-defined functions, -and outputs serialized Python objects. - -The regular `toString()` method returns an ambiguous representation, due to the -way that Scala `Option` instances are printed: - ->>> from context import SparkContext ->>> sc = SparkContext("local", "SerializerDocs") ->>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) ->>> y = sc.parallelizePairs([("a", 2)]) - ->>> print y.rightOuterJoin(x)._jrdd.first().toString() -(ImEi,(Some(Mg==),MQ==)) - -In Java, preprocessing is performed to handle Option instances, so the Python -process receives unambiguous input: - ->>> print sc.python_dump(y.rightOuterJoin(x)._jrdd.first()) -(ImEi,(Mg==,MQ==)) - -The base64-encoding eliminates the need to escape newlines, parentheses and -other special characters. - ----------------------- -Serializer composition ----------------------- - -In order to handle nested structures, which could contain object serialized -with different serializers, the RDD module composes serializers. For example, -the serializers in the previous example are: - ->>> print x.serializer.name -PairSerializer - ->>> print y.serializer.name -PairSerializer - ->>> print y.rightOuterJoin(x).serializer.name -PairSerializer, JSONSerializer>> +Python are stored in Java as RDD[Array[Byte]]. Python objects are +automatically serialized/deserialized, so this representation is transparent to +the end-user. """ -from base64 import standard_b64encode, standard_b64decode from collections import namedtuple import cPickle -import simplejson - - -Serializer = namedtuple("Serializer", - ["dumps","loads", "is_comparable", "name"]) - - -NopSerializer = Serializer(str, str, True, "NopSerializer") +import struct -JSONSerializer = Serializer( - lambda obj: standard_b64encode(simplejson.dumps(obj, sort_keys=True, - separators=(',', ':'))), - lambda s: simplejson.loads(standard_b64decode(s)), - True, - "JSONSerializer" -) +Serializer = namedtuple("Serializer", ["dumps","loads"]) PickleSerializer = Serializer( - lambda obj: standard_b64encode(cPickle.dumps(obj)), - lambda s: cPickle.loads(standard_b64decode(s)), - False, - "PickleSerializer" -) - - -def OptionSerializer(serializer): - """ - >>> ser = OptionSerializer(NopSerializer) - >>> ser.loads(ser.dumps("Hello, World!")) - 'Hello, World!' - >>> ser.loads(ser.dumps(None)) is None - True - """ - none_placeholder = '*' - - def dumps(x): - if x is None: - return none_placeholder - else: - return serializer.dumps(x) - - def loads(x): - if x == none_placeholder: - return None - else: - return serializer.loads(x) - - name = "OptionSerializer<%s>" % serializer.name - return Serializer(dumps, loads, serializer.is_comparable, name) - - -def PairSerializer(keySerializer, valSerializer): - """ - Returns a Serializer for a (key, value) pair. - - >>> ser = PairSerializer(JSONSerializer, JSONSerializer) - >>> ser.loads(ser.dumps((1, 2))) - (1, 2) - - >>> ser = PairSerializer(JSONSerializer, ser) - >>> ser.loads(ser.dumps((1, (2, 3)))) - (1, (2, 3)) - """ - def loads(kv): - try: - (key, val) = kv[1:-1].split(',', 1) - key = keySerializer.loads(key) - val = valSerializer.loads(val) - return (key, val) - except: - print "Error in deserializing pair from '%s'" % str(kv) - raise - - def dumps(kv): - (key, val) = kv - return"(%s,%s)" % (keySerializer.dumps(key), valSerializer.dumps(val)) - is_comparable = \ - keySerializer.is_comparable and valSerializer.is_comparable - name = "PairSerializer<%s, %s>" % (keySerializer.name, valSerializer.name) - return Serializer(dumps, loads, is_comparable, name) - - -def ArraySerializer(serializer): - """ - >>> ser = ArraySerializer(JSONSerializer) - >>> ser.loads(ser.dumps([1, 2, 3, 4])) - [1, 2, 3, 4] - >>> ser = ArraySerializer(PairSerializer(JSONSerializer, PickleSerializer)) - >>> ser.loads(ser.dumps([('a', 1), ('b', 2)])) - [('a', 1), ('b', 2)] - >>> ser.loads(ser.dumps([('a', 1)])) - [('a', 1)] - >>> ser.loads(ser.dumps([])) - [] - """ - def dumps(arr): - if arr == []: - return '[]' - else: - return '[' + '|'.join(serializer.dumps(x) for x in arr) + ']' - - def loads(s): - if s == '[]': - return [] - items = s[1:-1] - if '|' in items: - items = items.split('|') - else: - items = [items] - return [serializer.loads(x) for x in items] - - name = "ArraySerializer<%s>" % serializer.name - return Serializer(dumps, loads, serializer.is_comparable, name) - - -# TODO: IntegerSerializer - - -# TODO: DoubleSerializer + lambda obj: cPickle.dumps(obj, -1), + cPickle.loads) -def _test(): - import doctest - doctest.testmod() +def dumps(obj, stream): + # TODO: determining the length of non-byte objects. + stream.write(struct.pack("!i", len(obj))) + stream.write(obj) -if __name__ == "__main__": - _test() +def loads(stream): + length = stream.read(4) + if length == "": + raise EOFError + length = struct.unpack("!i", length)[0] + obj = stream.read(length) + if obj == "": + raise EOFError + return obj diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index 4c4b02fce4..21ff84fb17 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -6,9 +6,9 @@ from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. from pyspark.cloudpickle import CloudPickler +from pyspark.serializers import dumps, loads, PickleSerializer import cPickle - # Redirect stdout to stderr so that users must return values from functions. old_stdout = sys.stdout sys.stdout = sys.stderr @@ -19,58 +19,64 @@ def load_function(): def output(x): - for line in x.split("\n"): - old_stdout.write(line.rstrip("\r\n") + "\n") + dumps(x, old_stdout) def read_input(): - for line in sys.stdin: - yield line.rstrip("\r\n") - + try: + while True: + yield loads(sys.stdin) + except EOFError: + return def do_combine_by_key(): create_combiner = load_function() merge_value = load_function() merge_combiners = load_function() # TODO: not used. - depickler = load_function() - key_pickler = load_function() - combiner_pickler = load_function() combiners = {} - for line in read_input(): - # Discard the hashcode added in the Python combineByKey() method. - (key, value) = depickler(line)[1] + for obj in read_input(): + (key, value) = PickleSerializer.loads(obj) if key not in combiners: combiners[key] = create_combiner(value) else: combiners[key] = merge_value(combiners[key], value) for (key, combiner) in combiners.iteritems(): - output(key_pickler(key)) - output(combiner_pickler(combiner)) + output(PickleSerializer.dumps((key, combiner))) -def do_map(map_pairs=False): +def do_map(flat=False): f = load_function() - for line in read_input(): + for obj in read_input(): try: - out = f(line) + #from pickletools import dis + #print repr(obj) + #print dis(obj) + out = f(PickleSerializer.loads(obj)) if out is not None: - if map_pairs: + if flat: for x in out: - output(x) + output(PickleSerializer.dumps(x)) else: - output(out) + output(PickleSerializer.dumps(out)) except: - sys.stderr.write("Error processing line '%s'\n" % line) + sys.stderr.write("Error processing obj %s\n" % repr(obj)) raise +def do_shuffle_map_step(): + for obj in read_input(): + key = PickleSerializer.loads(obj)[1] + output(str(hash(key))) + output(obj) + + def do_reduce(): f = load_function() - dumps = load_function() acc = None - for line in read_input(): - acc = f(line, acc) - output(dumps(acc)) + for obj in read_input(): + acc = f(PickleSerializer.loads(obj), acc) + if acc is not None: + output(PickleSerializer.dumps(acc)) def do_echo(): @@ -80,13 +86,15 @@ def do_echo(): def main(): command = sys.stdin.readline().strip() if command == "map": - do_map(map_pairs=False) - elif command == "mapPairs": - do_map(map_pairs=True) + do_map(flat=False) + elif command == "flatmap": + do_map(flat=True) elif command == "combine_by_key": do_combine_by_key() elif command == "reduce": do_reduce() + elif command == "shuffle_map_step": + do_shuffle_map_step() elif command == "echo": do_echo() else: -- cgit v1.2.3 From 607b53abfca049e7d9139e2d29893a3bb252de19 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 22 Aug 2012 00:43:55 -0700 Subject: Use numpy in Python k-means example. --- .../main/scala/spark/api/python/PythonRDD.scala | 8 +++++++- pyspark/pyspark/examples/kmeans.py | 23 ++++++++-------------- pyspark/pyspark/rdd.py | 9 +++------ pyspark/pyspark/worker.py | 8 +++----- 4 files changed, 21 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index b9a0168d18..93847e2f14 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -101,7 +101,13 @@ trait PythonRDDBase { stream.readFully(obj) obj } catch { - case eof: EOFException => { new Array[Byte](0) } + case eof: EOFException => { + val exitStatus = proc.waitFor() + if (exitStatus != 0) { + throw new Exception("Subprocess exited with status " + exitStatus) + } + new Array[Byte](0) + } case e => throw e } } diff --git a/pyspark/pyspark/examples/kmeans.py b/pyspark/pyspark/examples/kmeans.py index 0761d6e395..9cc366f03c 100644 --- a/pyspark/pyspark/examples/kmeans.py +++ b/pyspark/pyspark/examples/kmeans.py @@ -1,25 +1,18 @@ import sys from pyspark.context import SparkContext +from numpy import array, sum as np_sum def parseVector(line): - return [float(x) for x in line.split(' ')] - - -def addVec(x, y): - return [a + b for (a, b) in zip(x, y)] - - -def squaredDist(x, y): - return sum((a - b) ** 2 for (a, b) in zip(x, y)) + return array([float(x) for x in line.split(' ')]) def closestPoint(p, centers): bestIndex = 0 closest = float("+inf") for i in range(len(centers)): - tempDist = squaredDist(p, centers[i]) + tempDist = np_sum((p - centers[i]) ** 2) if tempDist < closest: closest = tempDist bestIndex = i @@ -41,14 +34,14 @@ if __name__ == "__main__": tempDist = 1.0 while tempDist > convergeDist: - closest = data.mapPairs( + closest = data.map( lambda p : (closestPoint(p, kPoints), (p, 1))) pointStats = closest.reduceByKey( - lambda (x1, y1), (x2, y2): (addVec(x1, x2), y1 + y2)) - newPoints = pointStats.mapPairs( - lambda (x, (y, z)): (x, [a / z for a in y])).collect() + lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2)) + newPoints = pointStats.map( + lambda (x, (y, z)): (x, y / z)).collect() - tempDist = sum(squaredDist(kPoints[x], y) for (x, y) in newPoints) + tempDist = sum(np_sum((kPoints[x] - y) ** 2) for (x, y) in newPoints) for (x, y) in newPoints: kPoints[x] = y diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 8eccddc0a2..ff9c483032 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -71,7 +71,7 @@ class RDD(object): def takeSample(self, withReplacement, num, seed): vals = self._jrdd.takeSample(withReplacement, num, seed) - return [PickleSerializer.loads(x) for x in vals] + return [PickleSerializer.loads(bytes(x)) for x in vals] def union(self, other): """ @@ -218,17 +218,16 @@ class RDD(object): # TODO: pipelining # TODO: optimizations - def shuffle(self, numSplits): + def shuffle(self, numSplits, hashFunc=hash): if numSplits is None: numSplits = self.ctx.defaultParallelism - pipe_command = RDD._get_pipe_command('shuffle_map_step', []) + pipe_command = RDD._get_pipe_command('shuffle_map_step', [hashFunc]) class_manifest = self._jrdd.classManifest() python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), pipe_command, False, self.ctx.pythonExec, class_manifest) partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner) jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) - # TODO: extract second value. return RDD(jrdd, self.ctx) @@ -277,8 +276,6 @@ class RDD(object): map_values_fn = lambda (k, v): (k, f(v)) return self.map(map_values_fn, preservesPartitioning=True) - # TODO: implement shuffle. - # TODO: support varargs cogroup of several RDDs. def groupWith(self, other): return self.cogroup(other) diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index 21ff84fb17..b13ed5699a 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -48,9 +48,6 @@ def do_map(flat=False): f = load_function() for obj in read_input(): try: - #from pickletools import dis - #print repr(obj) - #print dis(obj) out = f(PickleSerializer.loads(obj)) if out is not None: if flat: @@ -64,9 +61,10 @@ def do_map(flat=False): def do_shuffle_map_step(): + hashFunc = load_function() for obj in read_input(): - key = PickleSerializer.loads(obj)[1] - output(str(hash(key))) + key = PickleSerializer.loads(obj)[0] + output(str(hashFunc(key))) output(obj) -- cgit v1.2.3 From 091b1438f53b810fa366c3e913174d380865cd0c Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 24 Aug 2012 16:43:59 -0700 Subject: Fix WordCount job name --- streaming/src/main/scala/spark/streaming/examples/WordCount.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount.scala index ba7bc63d6a..6c53007145 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount.scala @@ -11,7 +11,7 @@ object WordCount { } // Create the context and set the batch size - val ssc = new SparkStreamContext(args(0), "ExampleTwo") + val ssc = new SparkStreamContext(args(0), "WordCount") ssc.setBatchDuration(Seconds(2)) // Create the FileInputDStream on the directory and use the -- cgit v1.2.3 From e7a5cbb543bcf23517ba01040ff74cab0153d678 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 24 Aug 2012 16:45:01 -0700 Subject: Reduce log4j verbosity for streaming --- streaming/src/test/resources/log4j.properties | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 streaming/src/test/resources/log4j.properties diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties new file mode 100644 index 0000000000..02fe16866e --- /dev/null +++ b/streaming/src/test/resources/log4j.properties @@ -0,0 +1,8 @@ +# Set everything to be logged to the console +log4j.rootCategory=WARN, console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN -- cgit v1.2.3 From 4b523004877cf94152225484de7683e9d17cdb56 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 23 Aug 2012 15:54:15 -0700 Subject: Fix options parsing in Python pi example. --- pyspark/pyspark/examples/pi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyspark/pyspark/examples/pi.py b/pyspark/pyspark/examples/pi.py index ad77694c41..fe63d2c952 100644 --- a/pyspark/pyspark/examples/pi.py +++ b/pyspark/pyspark/examples/pi.py @@ -10,7 +10,7 @@ if __name__ == "__main__": "Usage: PythonPi []" exit(-1) sc = SparkContext(sys.argv[1], "PythonKMeans") - slices = sys.argv[2] if len(sys.argv) > 2 else 2 + slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 n = 100000 * slices def f(_): x = random() * 2 - 1 -- cgit v1.2.3 From f3b852ce66d193e3421eeecef71ea27bff73a94b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 24 Aug 2012 19:38:50 -0700 Subject: Refactor Python MappedRDD to use iterator pipelines. --- pyspark/pyspark/rdd.py | 83 +++++++++++++++-------------------------------- pyspark/pyspark/worker.py | 55 +++++++++---------------------- 2 files changed, 41 insertions(+), 97 deletions(-) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index ff9c483032..7d280d8844 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -1,4 +1,5 @@ from base64 import standard_b64encode as b64enc +from itertools import chain, ifilter, imap from pyspark import cloudpickle from pyspark.serializers import PickleSerializer @@ -15,8 +16,6 @@ class RDD(object): @classmethod def _get_pipe_command(cls, command, functions): - if functions and not isinstance(functions, (list, tuple)): - functions = [functions] worker_args = [command] for f in functions: worker_args.append(b64enc(cloudpickle.dumps(f))) @@ -28,7 +27,8 @@ class RDD(object): return self def map(self, f, preservesPartitioning=False): - return MappedRDD(self, f, preservesPartitioning) + def func(iterator): return imap(f, iterator) + return PipelinedRDD(self, func, preservesPartitioning) def flatMap(self, f): """ @@ -38,7 +38,8 @@ class RDD(object): >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect()) [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ - return MappedRDD(self, f, preservesPartitioning=False, command='flatmap') + def func(iterator): return chain.from_iterable(imap(f, iterator)) + return PipelinedRDD(self, func) def filter(self, f): """ @@ -46,10 +47,10 @@ class RDD(object): >>> rdd.filter(lambda x: x % 2 == 0).collect() [2, 4] """ - def filter_func(x): return x if f(x) else None - return RDD(self._pipe(filter_func), self.ctx) + def func(iterator): return ifilter(f, iterator) + return PipelinedRDD(self, func) - def _pipe(self, functions, command="map"): + def _pipe(self, functions, command): class_manifest = self._jrdd.classManifest() pipe_command = RDD._get_pipe_command(command, functions) python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command, @@ -128,7 +129,16 @@ class RDD(object): >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add) 10 """ - vals = MappedRDD(self, f, command="reduce", preservesPartitioning=False).collect() + def func(iterator): + acc = None + for obj in iterator: + if acc is None: + acc = obj + else: + acc = f(obj, acc) + if acc is not None: + yield acc + vals = PipelinedRDD(self, func).collect() return reduce(f, vals) # TODO: fold @@ -230,8 +240,6 @@ class RDD(object): jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) return RDD(jrdd, self.ctx) - - def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numSplits=None): """ @@ -297,7 +305,7 @@ class RDD(object): # TODO: file saving -class MappedRDD(RDD): +class PipelinedRDD(RDD): """ Pipelined maps: >>> rdd = sc.parallelize([1, 2, 3, 4]) @@ -313,68 +321,29 @@ class MappedRDD(RDD): >>> rdd.flatMap(lambda x: [x, x]).reduce(add) 20 """ - def __init__(self, prev, func, preservesPartitioning=False, command='map'): - if isinstance(prev, MappedRDD) and not prev.is_cached: + def __init__(self, prev, func, preservesPartitioning=False): + if isinstance(prev, PipelinedRDD) and not prev.is_cached: prev_func = prev.func - if command == 'reduce': - if prev.command == 'flatmap': - def flatmap_reduce_func(x, acc): - values = prev_func(x) - if values is None: - return acc - if not acc: - if len(values) == 1: - return values[0] - else: - return reduce(func, values[1:], values[0]) - else: - return reduce(func, values, acc) - self.func = flatmap_reduce_func - else: - def reduce_func(x, acc): - val = prev_func(x) - if not val: - return acc - if acc is None: - return val - else: - return func(val, acc) - self.func = reduce_func - else: - if prev.command == 'flatmap': - command = 'flatmap' - self.func = lambda x: (func(y) for y in prev_func(x)) - else: - self.func = lambda x: func(prev_func(x)) - + def pipeline_func(iterator): + return func(prev_func(iterator)) + self.func = pipeline_func self.preservesPartitioning = \ prev.preservesPartitioning and preservesPartitioning self._prev_jrdd = prev._prev_jrdd - self.is_pipelined = True else: - if command == 'reduce': - def reduce_func(val, acc): - if acc is None: - return val - else: - return func(val, acc) - self.func = reduce_func - else: - self.func = func + self.func = func self.preservesPartitioning = preservesPartitioning self._prev_jrdd = prev._jrdd - self.is_pipelined = False self.is_cached = False self.ctx = prev.ctx self.prev = prev self._jrdd_val = None - self.command = command @property def _jrdd(self): if not self._jrdd_val: funcs = [self.func] - pipe_command = RDD._get_pipe_command(self.command, funcs) + pipe_command = RDD._get_pipe_command("pipeline", funcs) class_manifest = self._prev_jrdd.classManifest() python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), pipe_command, self.preservesPartitioning, self.ctx.pythonExec, diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index b13ed5699a..76b09918e7 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -25,17 +25,17 @@ def output(x): def read_input(): try: while True: - yield loads(sys.stdin) + yield cPickle.loads(loads(sys.stdin)) except EOFError: return + def do_combine_by_key(): create_combiner = load_function() merge_value = load_function() merge_combiners = load_function() # TODO: not used. combiners = {} - for obj in read_input(): - (key, value) = PickleSerializer.loads(obj) + for (key, value) in read_input(): if key not in combiners: combiners[key] = create_combiner(value) else: @@ -44,57 +44,32 @@ def do_combine_by_key(): output(PickleSerializer.dumps((key, combiner))) -def do_map(flat=False): +def do_pipeline(): f = load_function() - for obj in read_input(): - try: - out = f(PickleSerializer.loads(obj)) - if out is not None: - if flat: - for x in out: - output(PickleSerializer.dumps(x)) - else: - output(PickleSerializer.dumps(out)) - except: - sys.stderr.write("Error processing obj %s\n" % repr(obj)) - raise + for obj in f(read_input()): + output(PickleSerializer.dumps(obj)) def do_shuffle_map_step(): hashFunc = load_function() - for obj in read_input(): - key = PickleSerializer.loads(obj)[0] + while True: + try: + pickled = loads(sys.stdin) + except EOFError: + return + key = cPickle.loads(pickled)[0] output(str(hashFunc(key))) - output(obj) - - -def do_reduce(): - f = load_function() - acc = None - for obj in read_input(): - acc = f(PickleSerializer.loads(obj), acc) - if acc is not None: - output(PickleSerializer.dumps(acc)) - - -def do_echo(): - old_stdout.writelines(sys.stdin.readlines()) + output(pickled) def main(): command = sys.stdin.readline().strip() - if command == "map": - do_map(flat=False) - elif command == "flatmap": - do_map(flat=True) + if command == "pipeline": + do_pipeline() elif command == "combine_by_key": do_combine_by_key() - elif command == "reduce": - do_reduce() elif command == "shuffle_map_step": do_shuffle_map_step() - elif command == "echo": - do_echo() else: raise Exception("Unsupported command %s" % command) -- cgit v1.2.3 From 741899b21e4e6439459fcf4966076661c851ed07 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 16:26:06 -0700 Subject: Fix sendMessageReliablySync --- core/src/main/scala/spark/network/ConnectionManager.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 1a22d06cc8..66b822117f 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -14,7 +14,8 @@ import scala.collection.mutable.SynchronizedQueue import scala.collection.mutable.Queue import scala.collection.mutable.ArrayBuffer -import akka.dispatch.{Promise, ExecutionContext, Future} +import akka.dispatch.{Await, Promise, ExecutionContext, Future} +import akka.util.Duration case class ConnectionManagerId(host: String, port: Int) { def toSocketAddress() = new InetSocketAddress(host, port) @@ -325,7 +326,7 @@ class ConnectionManager(port: Int) extends Logging { } def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, message: Message): Option[Message] = { - sendMessageReliably(connectionManagerId, message)() + Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf) } def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) { -- cgit v1.2.3 From ad6537321ebda407f28a7c32845708ad66684c55 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 16:27:23 -0700 Subject: Make Time serializable --- streaming/src/main/scala/spark/streaming/Time.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index 5c476f02c3..3901a26286 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -1,6 +1,6 @@ package spark.streaming -class Time(private var millis: Long) { +class Time(private var millis: Long) extends Serializable { def copy() = new Time(this.millis) -- cgit v1.2.3 From 06ef7c3d1bf8446d4d6ef8f3a055dd1e6d32ca3a Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 16:29:20 -0700 Subject: Less debug info --- core/src/main/scala/spark/network/ConnectionManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 66b822117f..bd0980029a 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -248,7 +248,7 @@ class ConnectionManager(port: Int) extends Logging { } private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) { - logInfo("Handling [" + message + "] from [" + connectionManagerId + "]") + logDebug("Handling [" + message + "] from [" + connectionManagerId + "]") message match { case bufferMessage: BufferMessage => { if (bufferMessage.hasAckId) { -- cgit v1.2.3 From b08ff710af9b6592e3b43308ec4598bd3e6da084 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 26 Aug 2012 23:40:50 +0000 Subject: Added sliding word count, and some fixes to reduce window DStream --- .../spark/streaming/ReducedWindowedDStream.scala | 3 +- .../spark/streaming/examples/WordCount2.scala | 98 ++++++++++++++++++++++ .../spark/streaming/util/SenderReceiverTest.scala | 9 +- 3 files changed, 106 insertions(+), 4 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount2.scala diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala index 11fa4e5443..d097896d0a 100644 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -75,11 +75,12 @@ extends DStream[(K,V)](parent.ssc) { val previousWindow = getAdjustedWindow(currentTime - slideTime, windowTime) logInfo("Current window = " + currentWindow) + logInfo("Slide time = " + slideTime) logInfo("Previous window = " + previousWindow) logInfo("Parent.zeroTime = " + parent.zeroTime) if (allowPartialWindows) { - if (currentTime - slideTime == parent.zeroTime) { + if (currentTime - slideTime <= parent.zeroTime) { reducedStream.getOrCompute(currentTime) match { case Some(rdd) => return Some(rdd) case None => throw new Exception("Could not get first reduced RDD for time " + currentTime) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala new file mode 100644 index 0000000000..83cbd31283 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -0,0 +1,98 @@ +package spark.streaming.examples + +import spark.SparkContext +import SparkContext._ +import spark.streaming._ +import SparkStreamContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.Queue +import scala.collection.JavaConversions.mapAsScalaMap + +import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} + + +object WordCount2_ExtraFunctions { + + def add(v1: Long, v2: Long) = (v1 + v2) + + def subtract(v1: Long, v2: Long) = (v1 - v2) + + def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, Long)] = { + //val map = new java.util.HashMap[String, Long] + val map = new OLMap[String] + var i = 0 + var j = 0 + while (iter.hasNext) { + val s = iter.next() + i = 0 + while (i < s.length) { + j = i + while (j < s.length && s.charAt(j) != ' ') { + j += 1 + } + if (j > i) { + val w = s.substring(i, j) + val c = map.getLong(w) + map.put(w, c + 1) +/* + if (c == null) { + map.put(w, 1) + } else { + map.put(w, c + 1) + } +*/ + } + i = j + while (i < s.length && s.charAt(i) == ' ') { + i += 1 + } + } + } + map.toIterator.map{case (k, v) => (k, v)} + } +} + +object WordCount2 { + + def moreWarmup(sc: SparkContext) { + (0 until 40).foreach {i => + sc.parallelize(1 to 20000000, 1000) + .map(_ % 1331).map(_.toString) + .mapPartitions(WordCount2_ExtraFunctions.splitAndCountPartitions).reduceByKey(_ + _, 10) + .collect() + } + } + + def main (args: Array[String]) { + + if (args.length < 2) { + println ("Usage: SparkStreamContext ") + System.exit(1) + } + + val ssc = new SparkStreamContext(args(0), "WordCount2") + ssc.setBatchDuration(Seconds(1)) + + val sentences = new ConstantInputDStream(ssc, ssc.sc.textFile(args(1)).cache()) + ssc.inputStreams += sentences + + import WordCount2_ExtraFunctions._ + + val windowedCounts = sentences + .mapPartitions(splitAndCountPartitions) + .reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 10) + windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(10)) + windowedCounts.print() + + ssc.start() + + while(true) { Thread.sleep(1000) } + } +} + + diff --git a/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala b/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala index 9fb1924798..3922dfbad6 100644 --- a/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala +++ b/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala @@ -19,9 +19,12 @@ object Receiver { val is = new DataInputStream(new BufferedInputStream(socket.getInputStream)) var loop = true var string: String = null - while((string = is.readUTF) != null) { - count += 28 - } + do { + string = is.readUTF() + if (string != null) { + count += 28 + } + } while (string != null) } catch { case e: Exception => e.printStackTrace() } -- cgit v1.2.3 From b120e24fe04b987cfb5c487251fc10e36b377d6b Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 16:45:14 -0700 Subject: Add equals and hashCode to Time --- streaming/src/main/scala/spark/streaming/Time.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index 3901a26286..7835bdd5e8 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -55,6 +55,14 @@ class Time(private var millis: Long) extends Serializable { def toFormattedString = millis.toString def milliseconds = millis + + override def hashCode = millis.toInt + + override def equals(other: Any): Boolean = other match { + case null => false + case t: Time => t.millis == millis + case _ => false + } } object Time { -- cgit v1.2.3 From 22b1a20e61e0573163b8896a604de78ebec65071 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 17:04:34 -0700 Subject: Made Time and Interval immutable --- .../src/main/scala/spark/streaming/DStream.scala | 2 +- .../src/main/scala/spark/streaming/Interval.scala | 8 +--- .../spark/streaming/ReducedWindowedDStream.scala | 2 +- .../src/main/scala/spark/streaming/Time.scala | 50 ++++------------------ .../scala/spark/streaming/WindowedDStream.scala | 2 +- 5 files changed, 13 insertions(+), 51 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index c63c043415..03a5b84210 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -130,7 +130,7 @@ extends Logging with Serializable { newRDD.persist(storageLevel) logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) } - generatedRDDs.put(time.copy(), newRDD) + generatedRDDs.put(time, newRDD) Some(newRDD) case None => None diff --git a/streaming/src/main/scala/spark/streaming/Interval.scala b/streaming/src/main/scala/spark/streaming/Interval.scala index 088cbe4376..87b8437b3d 100644 --- a/streaming/src/main/scala/spark/streaming/Interval.scala +++ b/streaming/src/main/scala/spark/streaming/Interval.scala @@ -1,15 +1,9 @@ package spark.streaming -case class Interval (beginTime: Time, endTime: Time) { +case class Interval(beginTime: Time, endTime: Time) { def this(beginMs: Long, endMs: Long) = this(Time(beginMs), new Time(endMs)) def duration(): Time = endTime - beginTime - - def += (time: Time) { - beginTime += time - endTime += time - this - } def + (time: Time): Interval = { new Interval(beginTime + time, endTime + time) diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala index d097896d0a..7833fa6189 100644 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -70,7 +70,7 @@ extends DStream[(K,V)](parent.ssc) { Interval(beginTime, endTime) } - val currentTime = validTime.copy + val currentTime = validTime val currentWindow = getAdjustedWindow(currentTime, windowTime) val previousWindow = getAdjustedWindow(currentTime - slideTime, windowTime) diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index 7835bdd5e8..6937e8b52f 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -1,43 +1,21 @@ package spark.streaming -class Time(private var millis: Long) extends Serializable { - - def copy() = new Time(this.millis) - +case class Time(millis: Long) { def zero = Time.zero - def < (that: Time): Boolean = - (this.millis < that.millis) + def < (that: Time): Boolean = (this.millis < that.millis) - def <= (that: Time) = (this < that || this == that) + def <= (that: Time) = (this.millis <= that.millis) - def > (that: Time) = !(this <= that) + def > (that: Time) = (this.millis > that.millis) - def >= (that: Time) = !(this < that) - - def += (that: Time): Time = { - this.millis += that.millis - this - } - - def -= (that: Time): Time = { - this.millis -= that.millis - this - } + def >= (that: Time) = (this.millis >= that.millis) - def + (that: Time) = this.copy() += that + def + (that: Time) = new Time(millis + that.millis) - def - (that: Time) = this.copy() -= that + def - (that: Time) = new Time(millis - that.millis) - def * (times: Int) = { - var count = 0 - var result = this.copy() - while (count < times) { - result += this - count += 1 - } - result - } + def * (times: Int) = new Time(millis * times) def floor(that: Time): Time = { val t = that.millis @@ -55,21 +33,11 @@ class Time(private var millis: Long) extends Serializable { def toFormattedString = millis.toString def milliseconds = millis - - override def hashCode = millis.toInt - - override def equals(other: Any): Boolean = other match { - case null => false - case t: Time => t.millis == millis - case _ => false - } } object Time { val zero = new Time(0) - - def apply(milliseconds: Long) = new Time(milliseconds) - + implicit def toTime(long: Long) = Time(long) implicit def toLong(time: Time) = time.milliseconds diff --git a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala index 9a6617a1ee..17f3f3d952 100644 --- a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala @@ -32,7 +32,7 @@ class WindowedDStream[T: ClassManifest]( override def compute(validTime: Time): Option[RDD[T]] = { val parentRDDs = new ArrayBuffer[RDD[T]]() - val windowEndTime = validTime.copy() + val windowEndTime = validTime val windowStartTime = if (allowPartialWindows && windowEndTime - windowTime < parent.zeroTime) { parent.zeroTime } else { -- cgit v1.2.3 From 57796b183e18730f671f8b970fbf8880875d9f03 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 17:25:22 -0700 Subject: Code style --- .../src/main/scala/spark/streaming/Time.scala | 26 ++++++++++------------ 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index 6937e8b52f..7841a71cd5 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -1,42 +1,40 @@ package spark.streaming case class Time(millis: Long) { - def zero = Time.zero - def < (that: Time): Boolean = (this.millis < that.millis) - def <= (that: Time) = (this.millis <= that.millis) + def <= (that: Time): Boolean = (this.millis <= that.millis) - def > (that: Time) = (this.millis > that.millis) + def > (that: Time): Boolean = (this.millis > that.millis) - def >= (that: Time) = (this.millis >= that.millis) + def >= (that: Time): Boolean = (this.millis >= that.millis) - def + (that: Time) = new Time(millis + that.millis) + def + (that: Time): Time = Time(millis + that.millis) - def - (that: Time) = new Time(millis - that.millis) + def - (that: Time): Time = Time(millis - that.millis) - def * (times: Int) = new Time(millis * times) + def * (times: Int): Time = Time(millis * times) def floor(that: Time): Time = { val t = that.millis val m = math.floor(this.millis / t).toLong - new Time(m * t) + Time(m * t) } def isMultipleOf(that: Time): Boolean = (this.millis % that.millis == 0) - def isZero = (this.millis == 0) + def isZero: Boolean = (this.millis == 0) - override def toString = (millis.toString + " ms") + override def toString: String = (millis.toString + " ms") - def toFormattedString = millis.toString + def toFormattedString: String = millis.toString - def milliseconds = millis + def milliseconds: Long = millis } object Time { - val zero = new Time(0) + val zero = Time(0) implicit def toTime(long: Long) = Time(long) -- cgit v1.2.3 From 9de1c3abf90fff82901c1ee13297d436d2d7a25d Mon Sep 17 00:00:00 2001 From: root Date: Mon, 27 Aug 2012 00:57:00 +0000 Subject: Tweaks to WordCount2 --- .../scala/spark/streaming/examples/WordCount2.scala | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index 83cbd31283..87b62817ea 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -70,23 +70,30 @@ object WordCount2 { def main (args: Array[String]) { - if (args.length < 2) { - println ("Usage: SparkStreamContext ") + if (args.length < 4) { + println ("Usage: SparkStreamContext ") System.exit(1) } + + val Array(master, file, mapTasks, reduceTasks) = args - val ssc = new SparkStreamContext(args(0), "WordCount2") + val ssc = new SparkStreamContext(master, "WordCount2") ssc.setBatchDuration(Seconds(1)) + + val data = ssc.sc.textFile(file, mapTasks.toInt).persist(StorageLevel.MEMORY_ONLY_DESER_2) + println("Data count: " + data.count()) + println("Data count: " + data.count()) + println("Data count: " + data.count()) - val sentences = new ConstantInputDStream(ssc, ssc.sc.textFile(args(1)).cache()) + val sentences = new ConstantInputDStream(ssc, data) ssc.inputStreams += sentences import WordCount2_ExtraFunctions._ val windowedCounts = sentences .mapPartitions(splitAndCountPartitions) - .reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), 10) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER, Seconds(10)) + .reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), reduceTasks.toInt) + windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(10)) windowedCounts.print() ssc.start() -- cgit v1.2.3 From 29e83f39e90b4d3cbeeb40d5ec0c19bd003c1840 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 18:16:25 -0700 Subject: Fix replication with MEMORY_ONLY_DESER_2 --- core/src/main/scala/spark/storage/BlockManager.scala | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index ff9914ae25..45f99717bc 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -364,6 +364,12 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m val startTimeMs = System.currentTimeMillis var bytes: ByteBuffer = null + + // If we need to replicate the data, we'll want access to the values, but because our + // put will read the whole iterator, there will be no values left. For the case where + // the put serializes data, we'll remember the bytes, above; but for the case where + // it doesn't, such as MEMORY_ONLY_DESER, let's rely on the put returning an Iterator. + var valuesAfterPut: Iterator[Any] = null locker.getLock(blockId).synchronized { logDebug("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) @@ -391,7 +397,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // If only save to memory memoryStore.putValues(blockId, values, level) match { case Right(newBytes) => bytes = newBytes - case _ => + case Left(newIterator) => valuesAfterPut = newIterator } } else { // If only save to disk @@ -408,8 +414,13 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // Replicate block if required if (level.replication > 1) { + // Serialize the block if not already done if (bytes == null) { - bytes = dataSerialize(values) // serialize the block if not already done + if (valuesAfterPut == null) { + throw new SparkException( + "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.") + } + bytes = dataSerialize(valuesAfterPut) } replicate(blockId, bytes, level) } -- cgit v1.2.3 From 26dfd20c9a5139bd682a9902267b9d54a11ae20f Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 18:56:56 -0700 Subject: Detect disconnected slaves in StandaloneScheduler --- .../cluster/StandaloneSchedulerBackend.scala | 38 ++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 013671c1c8..83e7c6e036 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -2,13 +2,14 @@ package spark.scheduler.cluster import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import akka.actor.{Props, Actor, ActorRef, ActorSystem} +import akka.actor._ import akka.util.duration._ import akka.pattern.ask import spark.{SparkException, Logging, TaskState} import akka.dispatch.Await import java.util.concurrent.atomic.AtomicInteger +import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent} /** * A standalone scheduler backend, which waits for standalone executors to connect to it through @@ -23,8 +24,16 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor class MasterActor(sparkProperties: Seq[(String, String)]) extends Actor { val slaveActor = new HashMap[String, ActorRef] + val slaveAddress = new HashMap[String, Address] val slaveHost = new HashMap[String, String] val freeCores = new HashMap[String, Int] + val actorToSlaveId = new HashMap[ActorRef, String] + val addressToSlaveId = new HashMap[Address, String] + + override def preStart() { + // Listen for remote client disconnection events, since they don't go through Akka's watch() + context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) + } def receive = { case RegisterSlave(slaveId, host, cores) => @@ -33,9 +42,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } else { logInfo("Registered slave: " + sender + " with ID " + slaveId) sender ! RegisteredSlave(sparkProperties) + context.watch(sender) slaveActor(slaveId) = sender slaveHost(slaveId) = host freeCores(slaveId) = cores + slaveAddress(slaveId) = sender.path.address + actorToSlaveId(sender) = slaveId + addressToSlaveId(sender.path.address) = slaveId totalCoreCount.addAndGet(cores) makeOffers() } @@ -54,7 +67,14 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor sender ! true context.stop(self) - // TODO: Deal with nodes disconnecting too! (Including decreasing totalCoreCount) + case Terminated(actor) => + actorToSlaveId.get(actor).foreach(removeSlave) + + case RemoteClientDisconnected(transport, address) => + addressToSlaveId.get(address).foreach(removeSlave) + + case RemoteClientShutdown(transport, address) => + addressToSlaveId.get(address).foreach(removeSlave) } // Make fake resource offers on all slaves @@ -76,6 +96,20 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor slaveActor(task.slaveId) ! LaunchTask(task) } } + + // Remove a disconnected slave from the cluster + def removeSlave(slaveId: String) { + logInfo("Slave " + slaveId + " disconnected, so removing it") + val numCores = freeCores(slaveId) + actorToSlaveId -= slaveActor(slaveId) + addressToSlaveId -= slaveAddress(slaveId) + slaveActor -= slaveId + slaveHost -= slaveId + freeCores -= slaveId + slaveHost -= slaveId + totalCoreCount.addAndGet(-numCores) + scheduler.slaveLost(slaveId) + } } var masterActor: ActorRef = null -- cgit v1.2.3 From 3c9c44a8d36c0f2dff40a50f1f1e3bc3dac7be7e Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 19:37:43 -0700 Subject: More helpful log messages --- core/src/main/scala/spark/MapOutputTracker.scala | 3 ++- core/src/main/scala/spark/network/Connection.scala | 4 ++-- core/src/main/scala/spark/network/ConnectionManager.scala | 2 +- core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala | 1 + 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 0c97cd44a1..e249430905 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -116,7 +116,7 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg def getServerAddresses(shuffleId: Int): Array[BlockManagerId] = { val locs = bmAddresses.get(shuffleId) if (locs == null) { - logInfo("Don't have map outputs for shuffe " + shuffleId + ", fetching them") + logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") fetching.synchronized { if (fetching.contains(shuffleId)) { // Someone else is fetching it; wait for them to be done @@ -158,6 +158,7 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg def incrementGeneration() { generationLock.synchronized { generation += 1 + logInfo("Increasing generation to " + generation) } } diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index 451faee66e..da8aff9dd5 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -111,7 +111,7 @@ extends Connection(SocketChannel.open, selector_) { messages.synchronized{ /*messages += message*/ messages.enqueue(message) - logInfo("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]") + logDebug("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]") } } @@ -136,7 +136,7 @@ extends Connection(SocketChannel.open, selector_) { return chunk } /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/ - logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken ) + logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken ) } } None diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index bd0980029a..0e764fff81 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -306,7 +306,7 @@ class ConnectionManager(port: Int) extends Logging { } val connection = connectionsById.getOrElse(connectionManagerId, startNewConnection()) message.senderAddress = id.toSocketAddress() - logInfo("Sending [" + message + "] to [" + connectionManagerId + "]") + logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") /*connection.send(message)*/ sendMessageRequests.synchronized { sendMessageRequests += ((message, connection)) diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index be24316e80..5412e8d8c0 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -88,6 +88,7 @@ class TaskSetManager( // Figure out the current map output tracker generation and set it on all tasks val generation = sched.mapOutputTracker.getGeneration + logInfo("Generation for " + taskSet.id + ": " + generation) for (t <- tasks) { t.generation = generation } -- cgit v1.2.3 From 117e3f8c8602c1303fa0e31840d85d1a7a6e3d9d Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 19:52:56 -0700 Subject: Fix a bug that was causing FetchFailedException not to be thrown --- core/src/main/scala/spark/BlockStoreShuffleFetcher.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala index 3431ad2258..45a14c8290 100644 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala @@ -48,8 +48,9 @@ class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { } } } catch { + // TODO: this is really ugly -- let's find a better way of throwing a FetchFailedException case be: BlockException => { - val regex = "shuffledid_([0-9]*)_([0-9]*)_([0-9]]*)".r + val regex = "shuffleid_([0-9]*)_([0-9]*)_([0-9]]*)".r be.blockId match { case regex(sId, mId, rId) => { val address = addresses(mId.toInt) -- cgit v1.2.3 From 69c2ab04083972e4ecf1393ffd0cb0acb56b4f7d Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 20:00:58 -0700 Subject: logging --- core/src/main/scala/spark/executor/Executor.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 9e335c25f7..dba209ac27 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -63,6 +63,7 @@ class Executor extends Logging { Thread.currentThread.setContextClassLoader(classLoader) Accumulators.clear() val task = ser.deserialize[Task[Any]](serializedTask, classLoader) + logInfo("Its generation is " + task.generation) env.mapOutputTracker.updateGeneration(task.generation) val value = task.run(taskId.toInt) val accumUpdates = Accumulators.values -- cgit v1.2.3 From b914cd0dfa21b615c29d2ce935f623f209afa8f4 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 26 Aug 2012 20:07:59 -0700 Subject: Serialize generation correctly in ShuffleMapTask --- core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index f78e0e5fb2..73479bff01 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -90,6 +90,7 @@ class ShuffleMapTask( out.writeInt(bytes.length) out.write(bytes) out.writeInt(partition) + out.writeLong(generation) out.writeObject(split) } @@ -102,6 +103,7 @@ class ShuffleMapTask( rdd = rdd_ dep = dep_ partition = in.readInt() + generation = in.readLong() split = in.readObject().asInstanceOf[Split] } -- cgit v1.2.3 From e2cf197a0a878f9c942dafe98a70bdaefb5df58d Mon Sep 17 00:00:00 2001 From: root Date: Mon, 27 Aug 2012 03:34:15 +0000 Subject: Made WordCount2 even more configurable --- .../src/main/scala/spark/streaming/examples/WordCount2.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index 87b62817ea..1afe87e723 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -70,15 +70,17 @@ object WordCount2 { def main (args: Array[String]) { - if (args.length < 4) { - println ("Usage: SparkStreamContext ") + if (args.length != 5) { + println ("Usage: SparkStreamContext ") System.exit(1) } - val Array(master, file, mapTasks, reduceTasks) = args + val Array(master, file, mapTasks, reduceTasks, batchMillis) = args + + val BATCH_DURATION = Milliseconds(batchMillis.toLong) val ssc = new SparkStreamContext(master, "WordCount2") - ssc.setBatchDuration(Seconds(1)) + ssc.setBatchDuration(BATCH_DURATION) val data = ssc.sc.textFile(file, mapTasks.toInt).persist(StorageLevel.MEMORY_ONLY_DESER_2) println("Data count: " + data.count()) @@ -92,7 +94,7 @@ object WordCount2 { val windowedCounts = sentences .mapPartitions(splitAndCountPartitions) - .reduceByKeyAndWindow(add _, subtract _, Seconds(10), Seconds(1), reduceTasks.toInt) + .reduceByKeyAndWindow(add _, subtract _, Seconds(10), BATCH_DURATION, reduceTasks.toInt) windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(10)) windowedCounts.print() -- cgit v1.2.3 From 65e8406029a0fe1e1c5c5d033d335b43f6743a04 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 24 Aug 2012 21:07:26 -0700 Subject: Implement fold() in Python API. --- pyspark/pyspark/rdd.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 7d280d8844..af7703fdfc 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -141,7 +141,25 @@ class RDD(object): vals = PipelinedRDD(self, func).collect() return reduce(f, vals) - # TODO: fold + def fold(self, zeroValue, op): + """ + Aggregate the elements of each partition, and then the results for all + the partitions, using a given associative function and a neutral "zero + value." The function op(t1, t2) is allowed to modify t1 and return it + as its result value to avoid object allocation; however, it should not + modify t2. + + >>> from operator import add + >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) + 15 + """ + def func(iterator): + acc = zeroValue + for obj in iterator: + acc = op(obj, acc) + yield acc + vals = PipelinedRDD(self, func).collect() + return reduce(op, vals, zeroValue) # TODO: aggregate -- cgit v1.2.3 From f79a1e4d2a8643157136de69b8d7de84f0034712 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 25 Aug 2012 13:59:01 -0700 Subject: Add broadcast variables to Python API. --- .../main/scala/spark/api/python/PythonRDD.scala | 43 ++++++++++++-------- pyspark/pyspark/broadcast.py | 46 ++++++++++++++++++++++ pyspark/pyspark/context.py | 17 ++++++-- pyspark/pyspark/rdd.py | 27 ++++++++----- pyspark/pyspark/worker.py | 6 +++ 5 files changed, 110 insertions(+), 29 deletions(-) create mode 100644 pyspark/pyspark/broadcast.py diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 93847e2f14..5163812df4 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -7,14 +7,13 @@ import scala.collection.JavaConversions._ import scala.io.Source import spark._ import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} -import scala.{collection, Some} -import collection.parallel.mutable +import broadcast.Broadcast import scala.collection -import scala.Some trait PythonRDDBase { def compute[T](split: Split, envVars: Map[String, String], - command: Seq[String], parent: RDD[T], pythonExec: String): Iterator[Array[Byte]] = { + command: Seq[String], parent: RDD[T], pythonExec: String, + broadcastVars: java.util.List[Broadcast[Array[Byte]]]): Iterator[Array[Byte]] = { val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME") val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py")) @@ -42,11 +41,18 @@ trait PythonRDDBase { override def run() { SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) + val dOut = new DataOutputStream(proc.getOutputStream) + out.println(broadcastVars.length) + for (broadcast <- broadcastVars) { + out.print(broadcast.uuid.toString) + dOut.writeInt(broadcast.value.length) + dOut.write(broadcast.value) + dOut.flush() + } for (elem <- command) { out.println(elem) } out.flush() - val dOut = new DataOutputStream(proc.getOutputStream) for (elem <- parent.iterator(split)) { if (elem.isInstanceOf[Array[Byte]]) { val arr = elem.asInstanceOf[Array[Byte]] @@ -121,16 +127,17 @@ trait PythonRDDBase { class PythonRDD[T: ClassManifest]( parent: RDD[T], command: Seq[String], envVars: Map[String, String], - preservePartitoning: Boolean, pythonExec: String) + preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) extends RDD[Array[Byte]](parent.context) with PythonRDDBase { - def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) = - this(parent, command, Map(), preservePartitoning, pythonExec) + def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, + pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = + this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars) // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String) = - this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec) + def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = + this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars) override def splits = parent.splits @@ -139,23 +146,25 @@ class PythonRDD[T: ClassManifest]( override val partitioner = if (preservePartitoning) parent.partitioner else None override def compute(split: Split): Iterator[Array[Byte]] = - compute(split, envVars, command, parent, pythonExec) + compute(split, envVars, command, parent, pythonExec, broadcastVars) val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } class PythonPairRDD[T: ClassManifest] ( parent: RDD[T], command: Seq[String], envVars: Map[String, String], - preservePartitoning: Boolean, pythonExec: String) + preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) extends RDD[(Array[Byte], Array[Byte])](parent.context) with PythonRDDBase { - def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) = - this(parent, command, Map(), preservePartitoning, pythonExec) + def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, + pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = + this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars) // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String) = - this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec) + def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String, + broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = + this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars) override def splits = parent.splits @@ -164,7 +173,7 @@ class PythonPairRDD[T: ClassManifest] ( override val partitioner = if (preservePartitoning) parent.partitioner else None override def compute(split: Split): Iterator[(Array[Byte], Array[Byte])] = { - compute(split, envVars, command, parent, pythonExec).grouped(2).map { + compute(split, envVars, command, parent, pythonExec, broadcastVars).grouped(2).map { case Seq(a, b) => (a, b) case x => throw new Exception("PythonPairRDD: unexpected value: " + x) } diff --git a/pyspark/pyspark/broadcast.py b/pyspark/pyspark/broadcast.py new file mode 100644 index 0000000000..1ea17d59af --- /dev/null +++ b/pyspark/pyspark/broadcast.py @@ -0,0 +1,46 @@ +""" +>>> from pyspark.context import SparkContext +>>> sc = SparkContext('local', 'test') +>>> b = sc.broadcast([1, 2, 3, 4, 5]) +>>> b.value +[1, 2, 3, 4, 5] + +>>> from pyspark.broadcast import _broadcastRegistry +>>> _broadcastRegistry[b.uuid] = b +>>> from cPickle import dumps, loads +>>> loads(dumps(b)).value +[1, 2, 3, 4, 5] + +>>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() +[1, 2, 3, 4, 5, 1, 2, 3, 4, 5] +""" +# Holds broadcasted data received from Java, keyed by UUID. +_broadcastRegistry = {} + + +def _from_uuid(uuid): + from pyspark.broadcast import _broadcastRegistry + if uuid not in _broadcastRegistry: + raise Exception("Broadcast variable '%s' not loaded!" % uuid) + return _broadcastRegistry[uuid] + + +class Broadcast(object): + def __init__(self, uuid, value, java_broadcast=None, pickle_registry=None): + self.value = value + self.uuid = uuid + self._jbroadcast = java_broadcast + self._pickle_registry = pickle_registry + + def __reduce__(self): + self._pickle_registry.add(self) + return (_from_uuid, (self.uuid, )) + + +def _test(): + import doctest + doctest.testmod() + + +if __name__ == "__main__": + _test() diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index ac7e4057e9..6f87206665 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -2,6 +2,7 @@ import os import atexit from tempfile import NamedTemporaryFile +from pyspark.broadcast import Broadcast from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, dumps from pyspark.rdd import RDD @@ -24,6 +25,11 @@ class SparkContext(object): self.defaultParallelism = \ defaultParallelism or self._jsc.sc().defaultParallelism() self.pythonExec = pythonExec + # Broadcast's __reduce__ method stores Broadcast instances here. + # This allows other code to determine which Broadcast instances have + # been pickled, so it can determine which Java broadcast objects to + # send. + self._pickled_broadcast_vars = set() def __del__(self): if self._jsc: @@ -52,7 +58,12 @@ class SparkContext(object): jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices) return RDD(jrdd, self) - def textFile(self, name, numSlices=None): - numSlices = numSlices or self.defaultParallelism - jrdd = self._jsc.textFile(name, numSlices) + def textFile(self, name, minSplits=None): + minSplits = minSplits or min(self.defaultParallelism, 2) + jrdd = self._jsc.textFile(name, minSplits) return RDD(jrdd, self) + + def broadcast(self, value): + jbroadcast = self._jsc.broadcast(bytearray(PickleSerializer.dumps(value))) + return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast, + self._pickled_broadcast_vars) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index af7703fdfc..4459095391 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -6,6 +6,8 @@ from pyspark.serializers import PickleSerializer from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup +from py4j.java_collections import ListConverter + class RDD(object): @@ -15,11 +17,15 @@ class RDD(object): self.ctx = ctx @classmethod - def _get_pipe_command(cls, command, functions): + def _get_pipe_command(cls, ctx, command, functions): worker_args = [command] for f in functions: worker_args.append(b64enc(cloudpickle.dumps(f))) - return " ".join(worker_args) + broadcast_vars = [x._jbroadcast for x in ctx._pickled_broadcast_vars] + broadcast_vars = ListConverter().convert(broadcast_vars, + ctx.gateway._gateway_client) + ctx._pickled_broadcast_vars.clear() + return (" ".join(worker_args), broadcast_vars) def cache(self): self.is_cached = True @@ -52,9 +58,10 @@ class RDD(object): def _pipe(self, functions, command): class_manifest = self._jrdd.classManifest() - pipe_command = RDD._get_pipe_command(command, functions) + (pipe_command, broadcast_vars) = \ + RDD._get_pipe_command(self.ctx, command, functions) python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command, - False, self.ctx.pythonExec, class_manifest) + False, self.ctx.pythonExec, broadcast_vars, class_manifest) return python_rdd.asJavaRDD() def distinct(self): @@ -249,10 +256,12 @@ class RDD(object): def shuffle(self, numSplits, hashFunc=hash): if numSplits is None: numSplits = self.ctx.defaultParallelism - pipe_command = RDD._get_pipe_command('shuffle_map_step', [hashFunc]) + (pipe_command, broadcast_vars) = \ + RDD._get_pipe_command(self.ctx, 'shuffle_map_step', [hashFunc]) class_manifest = self._jrdd.classManifest() python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), - pipe_command, False, self.ctx.pythonExec, class_manifest) + pipe_command, False, self.ctx.pythonExec, broadcast_vars, + class_manifest) partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner) jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) @@ -360,12 +369,12 @@ class PipelinedRDD(RDD): @property def _jrdd(self): if not self._jrdd_val: - funcs = [self.func] - pipe_command = RDD._get_pipe_command("pipeline", funcs) + (pipe_command, broadcast_vars) = \ + RDD._get_pipe_command(self.ctx, "pipeline", [self.func]) class_manifest = self._prev_jrdd.classManifest() python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), pipe_command, self.preservesPartitioning, self.ctx.pythonExec, - class_manifest) + broadcast_vars, class_manifest) self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index 76b09918e7..7402897ac8 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -5,6 +5,7 @@ import sys from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. +from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.serializers import dumps, loads, PickleSerializer import cPickle @@ -63,6 +64,11 @@ def do_shuffle_map_step(): def main(): + num_broadcast_variables = int(sys.stdin.readline().strip()) + for _ in range(num_broadcast_variables): + uuid = sys.stdin.read(36) + value = loads(sys.stdin) + _broadcastRegistry[uuid] = Broadcast(uuid, cPickle.loads(value)) command = sys.stdin.readline().strip() if command == "pipeline": do_pipeline() -- cgit v1.2.3 From 08b201d810c0dc0933d00d78ec2c1d9135e100c3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 24 Aug 2012 22:51:45 -0700 Subject: Add mapPartitions(), glom(), countByValue() to Python API. --- pyspark/pyspark/rdd.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 4459095391..f0d665236a 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -1,4 +1,5 @@ from base64 import standard_b64encode as b64enc +from collections import Counter from itertools import chain, ifilter, imap from pyspark import cloudpickle @@ -47,6 +48,15 @@ class RDD(object): def func(iterator): return chain.from_iterable(imap(f, iterator)) return PipelinedRDD(self, func) + def mapPartitions(self, f): + """ + >>> rdd = sc.parallelize([1, 2, 3, 4], 2) + >>> def f(iterator): yield sum(iterator) + >>> rdd.mapPartitions(f).collect() + [3, 7] + """ + return PipelinedRDD(self, f) + def filter(self, f): """ >>> rdd = sc.parallelize([1, 2, 3, 4, 5]) @@ -93,7 +103,14 @@ class RDD(object): # TODO: Overload __add___? - # TODO: glom + def glom(self): + """ + >>> rdd = sc.parallelize([1, 2, 3, 4], 2) + >>> rdd.glom().first() + [1, 2] + """ + def func(iterator): yield list(iterator) + return PipelinedRDD(self, func) def cartesian(self, other): """ @@ -115,8 +132,6 @@ class RDD(object): # TODO: pipe - # TODO: mapPartitions - def foreach(self, f): """ >>> def f(x): print x @@ -177,7 +192,16 @@ class RDD(object): """ return self._jrdd.count() - # TODO: count approx methods + def countByValue(self): + """ + >>> sc.parallelize([1, 2, 1, 2, 2]).countByValue().most_common() + [(2, 3), (1, 2)] + """ + def countPartition(iterator): + yield Counter(iterator) + def mergeMaps(m1, m2): + return m1 + m2 + return self.mapPartitions(countPartition).reduce(mergeMaps) def take(self, num): """ -- cgit v1.2.3 From 8b64b7ecd80c52f2f09a517f1517c0ece7a3d57f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 24 Aug 2012 23:09:15 -0700 Subject: Add countByKey(), reduceByKeyLocally() to Python API --- pyspark/pyspark/rdd.py | 52 +++++++++++++++++++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 13 deletions(-) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index f0d665236a..fd41ea0b17 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -99,9 +99,17 @@ class RDD(object): """ return RDD(self._jrdd.union(other._jrdd), self.ctx) - # TODO: sort + def __add__(self, other): + """ + >>> rdd = sc.parallelize([1, 1, 2, 3]) + >>> (rdd + rdd).collect() + [1, 1, 2, 3, 1, 1, 2, 3] + """ + if not isinstance(other, RDD): + raise TypeError + return self.union(other) - # TODO: Overload __add___? + # TODO: sort def glom(self): """ @@ -120,7 +128,6 @@ class RDD(object): """ return RDD(self._jrdd.cartesian(other._jrdd), self.ctx) - # numsplits def groupBy(self, f, numSplits=None): """ >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8]) @@ -236,17 +243,38 @@ class RDD(object): def reduceByKey(self, func, numSplits=None): """ - >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> sorted(x.reduceByKey(lambda a, b: a + b).collect()) + >>> from operator import add + >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(rdd.reduceByKey(add).collect()) [('a', 2), ('b', 1)] """ return self.combineByKey(lambda x: x, func, func, numSplits) - # TODO: reduceByKeyLocally() - - # TODO: countByKey() + def reduceByKeyLocally(self, func): + """ + >>> from operator import add + >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(rdd.reduceByKeyLocally(add).items()) + [('a', 2), ('b', 1)] + """ + def reducePartition(iterator): + m = {} + for (k, v) in iterator: + m[k] = v if k not in m else func(m[k], v) + yield m + def mergeMaps(m1, m2): + for (k, v) in m2.iteritems(): + m1[k] = v if k not in m1 else func(m1[k], v) + return m1 + return self.mapPartitions(reducePartition).reduce(mergeMaps) - # TODO: partitionBy + def countByKey(self): + """ + >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> rdd.countByKey().most_common() + [('a', 2), ('b', 1)] + """ + return self.map(lambda x: x[0]).countByValue() def join(self, other, numSplits=None): """ @@ -277,7 +305,7 @@ class RDD(object): # TODO: pipelining # TODO: optimizations - def shuffle(self, numSplits, hashFunc=hash): + def partitionBy(self, numSplits, hashFunc=hash): if numSplits is None: numSplits = self.ctx.defaultParallelism (pipe_command, broadcast_vars) = \ @@ -302,7 +330,7 @@ class RDD(object): """ if numSplits is None: numSplits = self.ctx.defaultParallelism - shuffled = self.shuffle(numSplits) + shuffled = self.partitionBy(numSplits) functions = [createCombiner, mergeValue, mergeCombiners] jpairs = shuffled._pipe(functions, "combine_by_key") return RDD(jpairs, self.ctx) @@ -353,8 +381,6 @@ class RDD(object): # keys in the pairs. This could be an expensive operation, since those # hashes aren't retained. - # TODO: file saving - class PipelinedRDD(RDD): """ -- cgit v1.2.3 From 6904cb77d4306a14891cc71338c8f9f966d009f1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 25 Aug 2012 14:19:07 -0700 Subject: Use local combiners in Python API combineByKey(). --- pyspark/pyspark/rdd.py | 33 ++++++++++++++++++++++++--------- pyspark/pyspark/worker.py | 16 ---------------- 2 files changed, 24 insertions(+), 25 deletions(-) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index fd41ea0b17..3528b8f308 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -46,7 +46,7 @@ class RDD(object): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(iterator): return chain.from_iterable(imap(f, iterator)) - return PipelinedRDD(self, func) + return self.mapPartitions(func) def mapPartitions(self, f): """ @@ -64,7 +64,7 @@ class RDD(object): [2, 4] """ def func(iterator): return ifilter(f, iterator) - return PipelinedRDD(self, func) + return self.mapPartitions(func) def _pipe(self, functions, command): class_manifest = self._jrdd.classManifest() @@ -118,7 +118,7 @@ class RDD(object): [1, 2] """ def func(iterator): yield list(iterator) - return PipelinedRDD(self, func) + return self.mapPartitions(func) def cartesian(self, other): """ @@ -167,7 +167,7 @@ class RDD(object): acc = f(obj, acc) if acc is not None: yield acc - vals = PipelinedRDD(self, func).collect() + vals = self.mapPartitions(func).collect() return reduce(f, vals) def fold(self, zeroValue, op): @@ -187,7 +187,7 @@ class RDD(object): for obj in iterator: acc = op(obj, acc) yield acc - vals = PipelinedRDD(self, func).collect() + vals = self.mapPartitions(func).collect() return reduce(op, vals, zeroValue) # TODO: aggregate @@ -330,10 +330,25 @@ class RDD(object): """ if numSplits is None: numSplits = self.ctx.defaultParallelism - shuffled = self.partitionBy(numSplits) - functions = [createCombiner, mergeValue, mergeCombiners] - jpairs = shuffled._pipe(functions, "combine_by_key") - return RDD(jpairs, self.ctx) + def combineLocally(iterator): + combiners = {} + for (k, v) in iterator: + if k not in combiners: + combiners[k] = createCombiner(v) + else: + combiners[k] = mergeValue(combiners[k], v) + return combiners.iteritems() + locally_combined = self.mapPartitions(combineLocally) + shuffled = locally_combined.partitionBy(numSplits) + def _mergeCombiners(iterator): + combiners = {} + for (k, v) in iterator: + if not k in combiners: + combiners[k] = v + else: + combiners[k] = mergeCombiners(combiners[k], v) + return combiners.iteritems() + return shuffled.mapPartitions(_mergeCombiners) def groupByKey(self, numSplits=None): """ diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index 7402897ac8..0f90c6ff46 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -31,20 +31,6 @@ def read_input(): return -def do_combine_by_key(): - create_combiner = load_function() - merge_value = load_function() - merge_combiners = load_function() # TODO: not used. - combiners = {} - for (key, value) in read_input(): - if key not in combiners: - combiners[key] = create_combiner(value) - else: - combiners[key] = merge_value(combiners[key], value) - for (key, combiner) in combiners.iteritems(): - output(PickleSerializer.dumps((key, combiner))) - - def do_pipeline(): f = load_function() for obj in f(read_input()): @@ -72,8 +58,6 @@ def main(): command = sys.stdin.readline().strip() if command == "pipeline": do_pipeline() - elif command == "combine_by_key": - do_combine_by_key() elif command == "shuffle_map_step": do_shuffle_map_step() else: -- cgit v1.2.3 From 200d248dcc5903295296bf897211cf543b37f8c1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 25 Aug 2012 16:46:07 -0700 Subject: Simplify Python worker; pipeline the map step of partitionBy(). --- .../main/scala/spark/api/python/PythonRDD.scala | 34 +++-------- pyspark/pyspark/context.py | 9 ++- pyspark/pyspark/rdd.py | 70 +++++++++------------- pyspark/pyspark/serializers.py | 23 ++----- pyspark/pyspark/worker.py | 50 +++++----------- 5 files changed, 59 insertions(+), 127 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 5163812df4..b9091fd436 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -151,38 +151,18 @@ class PythonRDD[T: ClassManifest]( val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } -class PythonPairRDD[T: ClassManifest] ( - parent: RDD[T], command: Seq[String], envVars: Map[String, String], - preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) - extends RDD[(Array[Byte], Array[Byte])](parent.context) with PythonRDDBase { - - def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, - pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = - this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars) - - // Similar to Runtime.exec(), if we are given a single string, split it into words - // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String, - broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = - this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars) - - override def splits = parent.splits - - override val dependencies = List(new OneToOneDependency(parent)) - - override val partitioner = if (preservePartitoning) parent.partitioner else None - - override def compute(split: Split): Iterator[(Array[Byte], Array[Byte])] = { - compute(split, envVars, command, parent, pythonExec, broadcastVars).grouped(2).map { +private class PairwiseRDD(prev: RDD[Array[Byte]]) extends + RDD[(Array[Byte], Array[Byte])](prev.context) { + override def splits = prev.splits + override val dependencies = List(new OneToOneDependency(prev)) + override def compute(split: Split) = + prev.iterator(split).grouped(2).map { case Seq(a, b) => (a, b) - case x => throw new Exception("PythonPairRDD: unexpected value: " + x) + case x => throw new Exception("PairwiseRDD: unexpected value: " + x) } - } - val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) } - object PythonRDD { /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */ diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index 6f87206665..b8490019e3 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -4,7 +4,7 @@ from tempfile import NamedTemporaryFile from pyspark.broadcast import Broadcast from pyspark.java_gateway import launch_gateway -from pyspark.serializers import PickleSerializer, dumps +from pyspark.serializers import dump_pickle, write_with_length from pyspark.rdd import RDD @@ -16,9 +16,8 @@ class SparkContext(object): asPickle = jvm.spark.api.python.PythonRDD.asPickle arrayAsPickle = jvm.spark.api.python.PythonRDD.arrayAsPickle - def __init__(self, master, name, defaultParallelism=None, - pythonExec='python'): + pythonExec='python'): self.master = master self.name = name self._jsc = self.jvm.JavaSparkContext(master, name) @@ -52,7 +51,7 @@ class SparkContext(object): # objects are written to a file and loaded through textFile(). tempFile = NamedTemporaryFile(delete=False) for x in c: - dumps(PickleSerializer.dumps(x), tempFile) + write_with_length(dump_pickle(x), tempFile) tempFile.close() atexit.register(lambda: os.unlink(tempFile.name)) jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices) @@ -64,6 +63,6 @@ class SparkContext(object): return RDD(jrdd, self) def broadcast(self, value): - jbroadcast = self._jsc.broadcast(bytearray(PickleSerializer.dumps(value))) + jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value))) return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast, self._pickled_broadcast_vars) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 3528b8f308..21e822ba9f 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -3,7 +3,7 @@ from collections import Counter from itertools import chain, ifilter, imap from pyspark import cloudpickle -from pyspark.serializers import PickleSerializer +from pyspark.serializers import dump_pickle, load_pickle from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup @@ -17,17 +17,6 @@ class RDD(object): self.is_cached = False self.ctx = ctx - @classmethod - def _get_pipe_command(cls, ctx, command, functions): - worker_args = [command] - for f in functions: - worker_args.append(b64enc(cloudpickle.dumps(f))) - broadcast_vars = [x._jbroadcast for x in ctx._pickled_broadcast_vars] - broadcast_vars = ListConverter().convert(broadcast_vars, - ctx.gateway._gateway_client) - ctx._pickled_broadcast_vars.clear() - return (" ".join(worker_args), broadcast_vars) - def cache(self): self.is_cached = True self._jrdd.cache() @@ -66,14 +55,6 @@ class RDD(object): def func(iterator): return ifilter(f, iterator) return self.mapPartitions(func) - def _pipe(self, functions, command): - class_manifest = self._jrdd.classManifest() - (pipe_command, broadcast_vars) = \ - RDD._get_pipe_command(self.ctx, command, functions) - python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command, - False, self.ctx.pythonExec, broadcast_vars, class_manifest) - return python_rdd.asJavaRDD() - def distinct(self): """ >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect()) @@ -89,7 +70,7 @@ class RDD(object): def takeSample(self, withReplacement, num, seed): vals = self._jrdd.takeSample(withReplacement, num, seed) - return [PickleSerializer.loads(bytes(x)) for x in vals] + return [load_pickle(bytes(x)) for x in vals] def union(self, other): """ @@ -148,7 +129,7 @@ class RDD(object): def collect(self): pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().collect()) - return PickleSerializer.loads(bytes(pickle)) + return load_pickle(bytes(pickle)) def reduce(self, f): """ @@ -216,19 +197,17 @@ class RDD(object): [2, 3] """ pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().take(num)) - return PickleSerializer.loads(bytes(pickle)) + return load_pickle(bytes(pickle)) def first(self): """ >>> sc.parallelize([2, 3, 4]).first() 2 """ - return PickleSerializer.loads(bytes(self.ctx.asPickle(self._jrdd.first()))) + return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first()))) # TODO: saveAsTextFile - # TODO: saveAsObjectFile - # Pair functions def collectAsMap(self): @@ -303,19 +282,18 @@ class RDD(object): """ return python_right_outer_join(self, other, numSplits) - # TODO: pipelining - # TODO: optimizations def partitionBy(self, numSplits, hashFunc=hash): if numSplits is None: numSplits = self.ctx.defaultParallelism - (pipe_command, broadcast_vars) = \ - RDD._get_pipe_command(self.ctx, 'shuffle_map_step', [hashFunc]) - class_manifest = self._jrdd.classManifest() - python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), - pipe_command, False, self.ctx.pythonExec, broadcast_vars, - class_manifest) + def add_shuffle_key(iterator): + for (k, v) in iterator: + yield str(hashFunc(k)) + yield dump_pickle((k, v)) + keyed = PipelinedRDD(self, add_shuffle_key) + keyed._bypass_serializer = True + pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) - jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner) + jrdd = pairRDD.partitionBy(partitioner) jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) return RDD(jrdd, self.ctx) @@ -430,17 +408,23 @@ class PipelinedRDD(RDD): self.ctx = prev.ctx self.prev = prev self._jrdd_val = None + self._bypass_serializer = False @property def _jrdd(self): - if not self._jrdd_val: - (pipe_command, broadcast_vars) = \ - RDD._get_pipe_command(self.ctx, "pipeline", [self.func]) - class_manifest = self._prev_jrdd.classManifest() - python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), - pipe_command, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, class_manifest) - self._jrdd_val = python_rdd.asJavaRDD() + if self._jrdd_val: + return self._jrdd_val + funcs = [self.func, self._bypass_serializer] + pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in funcs) + broadcast_vars = ListConverter().convert( + [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], + self.ctx.gateway._gateway_client) + self.ctx._pickled_broadcast_vars.clear() + class_manifest = self._prev_jrdd.classManifest() + python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), + pipe_command, self.preservesPartitioning, self.ctx.pythonExec, + broadcast_vars, class_manifest) + self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py index 7b3e6966e1..faa1e683c7 100644 --- a/pyspark/pyspark/serializers.py +++ b/pyspark/pyspark/serializers.py @@ -1,31 +1,20 @@ -""" -Data serialization methods. - -The Spark Python API is built on top of the Spark Java API. RDDs created in -Python are stored in Java as RDD[Array[Byte]]. Python objects are -automatically serialized/deserialized, so this representation is transparent to -the end-user. -""" -from collections import namedtuple -import cPickle import struct +import cPickle -Serializer = namedtuple("Serializer", ["dumps","loads"]) +def dump_pickle(obj): + return cPickle.dumps(obj, 2) -PickleSerializer = Serializer( - lambda obj: cPickle.dumps(obj, -1), - cPickle.loads) +load_pickle = cPickle.loads -def dumps(obj, stream): - # TODO: determining the length of non-byte objects. +def write_with_length(obj, stream): stream.write(struct.pack("!i", len(obj))) stream.write(obj) -def loads(stream): +def read_with_length(stream): length = stream.read(4) if length == "": raise EOFError diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index 0f90c6ff46..a9ed71892f 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -7,61 +7,41 @@ from base64 import standard_b64decode # copy_reg module. from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler -from pyspark.serializers import dumps, loads, PickleSerializer -import cPickle +from pyspark.serializers import write_with_length, read_with_length, \ + dump_pickle, load_pickle + # Redirect stdout to stderr so that users must return values from functions. old_stdout = sys.stdout sys.stdout = sys.stderr -def load_function(): - return cPickle.loads(standard_b64decode(sys.stdin.readline().strip())) - - -def output(x): - dumps(x, old_stdout) +def load_obj(): + return load_pickle(standard_b64decode(sys.stdin.readline().strip())) def read_input(): try: while True: - yield cPickle.loads(loads(sys.stdin)) + yield load_pickle(read_with_length(sys.stdin)) except EOFError: return -def do_pipeline(): - f = load_function() - for obj in f(read_input()): - output(PickleSerializer.dumps(obj)) - - -def do_shuffle_map_step(): - hashFunc = load_function() - while True: - try: - pickled = loads(sys.stdin) - except EOFError: - return - key = cPickle.loads(pickled)[0] - output(str(hashFunc(key))) - output(pickled) - - def main(): num_broadcast_variables = int(sys.stdin.readline().strip()) for _ in range(num_broadcast_variables): uuid = sys.stdin.read(36) - value = loads(sys.stdin) - _broadcastRegistry[uuid] = Broadcast(uuid, cPickle.loads(value)) - command = sys.stdin.readline().strip() - if command == "pipeline": - do_pipeline() - elif command == "shuffle_map_step": - do_shuffle_map_step() + value = read_with_length(sys.stdin) + _broadcastRegistry[uuid] = Broadcast(uuid, load_pickle(value)) + func = load_obj() + bypassSerializer = load_obj() + if bypassSerializer: + dumps = lambda x: x else: - raise Exception("Unsupported command %s" % command) + dumps = dump_pickle + for obj in func(read_input()): + write_with_length(dumps(obj), old_stdout) if __name__ == '__main__': -- cgit v1.2.3 From bff6a46359131a8f9bc38b93149b22baa7c711cd Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 25 Aug 2012 18:00:25 -0700 Subject: Add pipe(), saveAsTextFile(), sc.union() to Python API. --- .../main/scala/spark/api/python/PythonRDD.scala | 8 +++++-- pyspark/pyspark/context.py | 14 ++++++------ pyspark/pyspark/rdd.py | 25 ++++++++++++++++++++-- 3 files changed, 37 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index b9091fd436..4d3bdb3963 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -9,6 +9,7 @@ import spark._ import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import broadcast.Broadcast import scala.collection +import java.nio.charset.Charset trait PythonRDDBase { def compute[T](split: Split, envVars: Map[String, String], @@ -238,9 +239,12 @@ private object Pickle { val MARK : Byte = '(' val APPENDS : Byte = 'e' } -class ExtractValue extends spark.api.java.function.Function[(Array[Byte], - Array[Byte]), Array[Byte]] { +private class ExtractValue extends spark.api.java.function.Function[(Array[Byte], + Array[Byte]), Array[Byte]] { override def call(pair: (Array[Byte], Array[Byte])) : Array[Byte] = pair._2 +} +private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] { + override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") } diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index b8490019e3..04932c93f2 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -7,6 +7,8 @@ from pyspark.java_gateway import launch_gateway from pyspark.serializers import dump_pickle, write_with_length from pyspark.rdd import RDD +from py4j.java_collections import ListConverter + class SparkContext(object): @@ -39,12 +41,6 @@ class SparkContext(object): self._jsc = None def parallelize(self, c, numSlices=None): - """ - >>> sc = SparkContext("local", "test") - >>> rdd = sc.parallelize([(1, 2), (3, 4)]) - >>> rdd.collect() - [(1, 2), (3, 4)] - """ numSlices = numSlices or self.defaultParallelism # Calling the Java parallelize() method with an ArrayList is too slow, # because it sends O(n) Py4J commands. As an alternative, serialized @@ -62,6 +58,12 @@ class SparkContext(object): jrdd = self._jsc.textFile(name, minSplits) return RDD(jrdd, self) + def union(self, rdds): + first = rdds[0]._jrdd + rest = [x._jrdd for x in rdds[1:]] + rest = ListConverter().convert(rest, self.gateway._gateway_client) + return RDD(self._jsc.union(first, rest), self) + def broadcast(self, value): jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value))) return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast, diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 21e822ba9f..8477f6dd02 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -1,6 +1,9 @@ from base64 import standard_b64encode as b64enc from collections import Counter from itertools import chain, ifilter, imap +import shlex +from subprocess import Popen, PIPE +from threading import Thread from pyspark import cloudpickle from pyspark.serializers import dump_pickle, load_pickle @@ -118,7 +121,20 @@ class RDD(object): """ return self.map(lambda x: (f(x), x)).groupByKey(numSplits) - # TODO: pipe + def pipe(self, command, env={}): + """ + >>> sc.parallelize([1, 2, 3]).pipe('cat').collect() + ['1', '2', '3'] + """ + def func(iterator): + pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE) + def pipe_objs(out): + for obj in iterator: + out.write(str(obj).rstrip('\n') + '\n') + out.close() + Thread(target=pipe_objs, args=[pipe.stdin]).start() + return (x.rstrip('\n') for x in pipe.stdout) + return self.mapPartitions(func) def foreach(self, f): """ @@ -206,7 +222,12 @@ class RDD(object): """ return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first()))) - # TODO: saveAsTextFile + def saveAsTextFile(self, path): + def func(iterator): + return (str(x).encode("utf-8") for x in iterator) + keyed = PipelinedRDD(self, func) + keyed._bypass_serializer = True + keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path) # Pair functions -- cgit v1.2.3 From 414367850982c4f8fc5e63cc94caa422eb736db5 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 27 Aug 2012 00:13:19 -0700 Subject: Fix minor bugs in Python API examples. --- pyspark/pyspark/examples/pi.py | 2 +- pyspark/pyspark/examples/tc.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyspark/pyspark/examples/pi.py b/pyspark/pyspark/examples/pi.py index fe63d2c952..348bbc5dce 100644 --- a/pyspark/pyspark/examples/pi.py +++ b/pyspark/pyspark/examples/pi.py @@ -9,7 +9,7 @@ if __name__ == "__main__": print >> sys.stderr, \ "Usage: PythonPi []" exit(-1) - sc = SparkContext(sys.argv[1], "PythonKMeans") + sc = SparkContext(sys.argv[1], "PythonPi") slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 n = 100000 * slices def f(_): diff --git a/pyspark/pyspark/examples/tc.py b/pyspark/pyspark/examples/tc.py index 2796fdc6ad..9630e72b47 100644 --- a/pyspark/pyspark/examples/tc.py +++ b/pyspark/pyspark/examples/tc.py @@ -22,9 +22,9 @@ if __name__ == "__main__": print >> sys.stderr, \ "Usage: PythonTC []" exit(-1) - sc = SparkContext(sys.argv[1], "PythonKMeans") + sc = SparkContext(sys.argv[1], "PythonTC") slices = sys.argv[2] if len(sys.argv) > 2 else 2 - tc = sc.parallelizePairs(generateGraph(), slices).cache() + tc = sc.parallelize(generateGraph(), slices).cache() # Linear transitive closure: each round grows paths by one edge, # by joining the graph's edges with the already-discovered paths. @@ -32,7 +32,7 @@ if __name__ == "__main__": # the graph to obtain the path (x, z). # Because join() joins on keys, the edges are stored in reversed order. - edges = tc.mapPairs(lambda (x, y): (y, x)) + edges = tc.map(lambda (x, y): (y, x)) oldCount = 0L nextCount = tc.count() @@ -40,7 +40,7 @@ if __name__ == "__main__": oldCount = nextCount # Perform the join, obtaining an RDD of (y, (z, x)) pairs, # then project the result to obtain the new (x, z) paths. - new_edges = tc.join(edges).mapPairs(lambda (_, (a, b)): (b, a)) + new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a)) tc = tc.union(new_edges).distinct().cache() nextCount = tc.count() if nextCount == oldCount: -- cgit v1.2.3 From b4a2214218eeb9ebd95b59d88c2212fe717efd9e Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 27 Aug 2012 22:49:29 -0700 Subject: More fault tolerance fixes to catch lost tasks --- core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala | 2 ++ .../main/scala/spark/scheduler/cluster/TaskSetManager.scala | 13 +++++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala index 0fc1d8ed30..65e59841a9 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala @@ -20,6 +20,8 @@ class TaskInfo(val taskId: Long, val index: Int, val launchTime: Long, val host: def successful: Boolean = finished && !failed + def running: Boolean = !finished + def duration: Long = { if (!finished) { throw new UnsupportedOperationException("duration() called on unfinished tasks") diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index 5412e8d8c0..17317e80df 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -265,6 +265,11 @@ class TaskSetManager( def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) { val info = taskInfos(tid) + if (info.failed) { + // We might get two task-lost messages for the same task in coarse-grained Mesos mode, + // or even from Mesos itself when acks get delayed. + return + } val index = info.index info.markFailed() if (!finished(index)) { @@ -341,7 +346,7 @@ class TaskSetManager( } def hostLost(hostname: String) { - logInfo("Re-queueing tasks for " + hostname) + logInfo("Re-queueing tasks for " + hostname + " from TaskSet " + taskSet.id) // If some task has preferred locations only on hostname, put it in the no-prefs list // to avoid the wait from delay scheduling for (index <- getPendingTasksForHost(hostname)) { @@ -350,7 +355,7 @@ class TaskSetManager( pendingTasksWithNoPrefs += index } } - // Also re-enqueue any tasks that ran on the failed host if this is a shuffle map stage + // Re-enqueue any tasks that ran on the failed host if this is a shuffle map stage if (tasks(0).isInstanceOf[ShuffleMapTask]) { for ((tid, info) <- taskInfos if info.host == hostname) { val index = taskInfos(tid).index @@ -365,6 +370,10 @@ class TaskSetManager( } } } + // Also re-enqueue any tasks that were running on the node + for ((tid, info) <- taskInfos if info.running && info.host == hostname) { + taskLost(tid, TaskState.KILLED, null) + } } /** -- cgit v1.2.3 From 17af2df0cdcdc4f02013bd7b4351e0a9d9ee9b25 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 27 Aug 2012 23:07:32 -0700 Subject: Log levels --- core/src/main/scala/spark/MapOutputTracker.scala | 2 +- core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index e249430905..de23eb6f48 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -158,7 +158,7 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg def incrementGeneration() { generationLock.synchronized { generation += 1 - logInfo("Increasing generation to " + generation) + logDebug("Increasing generation to " + generation) } } diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index 17317e80df..5a7df6040c 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -88,7 +88,7 @@ class TaskSetManager( // Figure out the current map output tracker generation and set it on all tasks val generation = sched.mapOutputTracker.getGeneration - logInfo("Generation for " + taskSet.id + ": " + generation) + logDebug("Generation for " + taskSet.id + ": " + generation) for (t <- tasks) { t.generation = generation } -- cgit v1.2.3 From b5b93a621c6386929cee97e799d18b29a53bafbc Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 28 Aug 2012 12:35:19 -0700 Subject: Added capabllity to take streaming input from network. Renamed SparkStreamContext to StreamingContext. --- .../spark/streaming/ConstantInputDStream.scala | 2 +- .../src/main/scala/spark/streaming/DStream.scala | 10 +- .../scala/spark/streaming/FileInputDStream.scala | 84 +------ .../main/scala/spark/streaming/JobManager.scala | 4 +- .../spark/streaming/NetworkInputDStream.scala | 36 +++ .../spark/streaming/NetworkInputReceiver.scala | 248 +++++++++++++++++++++ .../spark/streaming/NetworkInputTracker.scala | 110 +++++++++ .../spark/streaming/PairDStreamFunctions.scala | 2 +- .../scala/spark/streaming/QueueInputDStream.scala | 2 +- .../spark/streaming/ReducedWindowedDStream.scala | 2 +- .../src/main/scala/spark/streaming/Scheduler.scala | 3 +- .../scala/spark/streaming/SparkStreamContext.scala | 173 -------------- .../scala/spark/streaming/StreamingContext.scala | 192 ++++++++++++++++ .../src/main/scala/spark/streaming/Time.scala | 2 +- .../scala/spark/streaming/WindowedDStream.scala | 2 +- .../spark/streaming/examples/ExampleOne.scala | 6 +- .../spark/streaming/examples/ExampleTwo.scala | 6 +- .../scala/spark/streaming/examples/WordCount.scala | 25 --- .../spark/streaming/examples/WordCountHdfs.scala | 26 +++ .../streaming/examples/WordCountNetwork.scala | 25 +++ .../test/scala/spark/streaming/DStreamSuite.scala | 4 +- 21 files changed, 662 insertions(+), 302 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/NetworkInputReceiver.scala create mode 100644 streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala delete mode 100644 streaming/src/main/scala/spark/streaming/SparkStreamContext.scala create mode 100644 streaming/src/main/scala/spark/streaming/StreamingContext.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala diff --git a/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala b/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala index 6a2be34633..9bc204dd09 100644 --- a/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala @@ -5,7 +5,7 @@ import spark.RDD /** * An input stream that always returns the same RDD on each timestep. Useful for testing. */ -class ConstantInputDStream[T: ClassManifest](ssc: SparkStreamContext, rdd: RDD[T]) +class ConstantInputDStream[T: ClassManifest](ssc: StreamingContext, rdd: RDD[T]) extends InputDStream[T](ssc) { override def start() {} diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index c63c043415..0e45acabc3 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -1,6 +1,6 @@ package spark.streaming -import spark.streaming.SparkStreamContext._ +import spark.streaming.StreamingContext._ import spark.RDD import spark.BlockRDD @@ -15,7 +15,7 @@ import scala.collection.mutable.HashMap import java.util.concurrent.ArrayBlockingQueue -abstract class DStream[T: ClassManifest] (@transient val ssc: SparkStreamContext) +abstract class DStream[T: ClassManifest] (@transient val ssc: StreamingContext) extends Logging with Serializable { initLogging() @@ -142,7 +142,7 @@ extends Logging with Serializable { } /** - * This method generates a SparkStream job for the given time + * This method generates a SparkStreaming job for the given time * and may require to be overriden by subclasses */ def generateJob(time: Time): Option[Job] = { @@ -249,7 +249,7 @@ extends Logging with Serializable { abstract class InputDStream[T: ClassManifest] ( - ssc: SparkStreamContext) + @transient ssc: StreamingContext) extends DStream[T](ssc) { override def dependencies = List() @@ -397,7 +397,7 @@ extends DStream[T](parents(0).ssc) { } if (parents.map(_.ssc).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different SparkStreamContexts") + throw new IllegalArgumentException("Array of parents have different StreamingContexts") } if (parents.map(_.slideTime).distinct.size > 1) { diff --git a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala index 88aa375289..96a64f0018 100644 --- a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala @@ -19,7 +19,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( - ssc: SparkStreamContext, + ssc: StreamingContext, directory: Path, filter: PathFilter = FileInputDStream.defaultPathFilter, newFilesOnly: Boolean = true) @@ -28,7 +28,7 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K val fs = directory.getFileSystem(new Configuration()) var lastModTime: Long = 0 - override def start() { + override def start() { if (newFilesOnly) { lastModTime = System.currentTimeMillis() } else { @@ -82,83 +82,3 @@ object FileInputDStream { } } -/* -class NetworkInputDStream[T: ClassManifest]( - val networkInputName: String, - val addresses: Array[InetSocketAddress], - batchDuration: Time, - ssc: SparkStreamContext) -extends InputDStream[T](networkInputName, batchDuration, ssc) { - - - // TODO(Haoyuan): This is for the performance test. - @transient var rdd: RDD[T] = null - - if (System.getProperty("spark.fake", "false") == "true") { - logInfo("Running initial count to cache fake RDD") - rdd = ssc.sc.textFile(SparkContext.inputFile, - SparkContext.idealPartitions).asInstanceOf[RDD[T]] - val fakeCacheLevel = System.getProperty("spark.fake.cache", "") - if (fakeCacheLevel == "MEMORY_ONLY_2") { - rdd.persist(StorageLevel.MEMORY_ONLY_2) - } else if (fakeCacheLevel == "MEMORY_ONLY_DESER_2") { - rdd.persist(StorageLevel.MEMORY_ONLY_2) - } else if (fakeCacheLevel != "") { - logError("Invalid fake cache level: " + fakeCacheLevel) - System.exit(1) - } - rdd.count() - } - - @transient val references = new HashMap[Time,String] - - override def compute(validTime: Time): Option[RDD[T]] = { - if (System.getProperty("spark.fake", "false") == "true") { - logInfo("Returning fake RDD at " + validTime) - return Some(rdd) - } - references.get(validTime) match { - case Some(reference) => - if (reference.startsWith("file") || reference.startsWith("hdfs")) { - logInfo("Reading from file " + reference + " for time " + validTime) - Some(ssc.sc.textFile(reference).asInstanceOf[RDD[T]]) - } else { - logInfo("Getting from BlockManager " + reference + " for time " + validTime) - Some(new BlockRDD(ssc.sc, Array(reference))) - } - case None => - throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") - None - } - } - - def setReference(time: Time, reference: AnyRef) { - references += ((time, reference.toString)) - logInfo("Reference added for time " + time + " - " + reference.toString) - } -} - - -class TestInputDStream( - val testInputName: String, - batchDuration: Time, - ssc: SparkStreamContext) -extends InputDStream[String](testInputName, batchDuration, ssc) { - - @transient val references = new HashMap[Time,Array[String]] - - override def compute(validTime: Time): Option[RDD[String]] = { - references.get(validTime) match { - case Some(reference) => - Some(new BlockRDD[String](ssc.sc, reference)) - case None => - throw new Exception(this.toString + ": Reference missing for time " + validTime + "!!!") - None - } - } - - def setReference(time: Time, reference: AnyRef) { - references += ((time, reference.asInstanceOf[Array[String]])) - } -} -*/ diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index c37fe1e9ad..2a4fe3dd11 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -5,9 +5,9 @@ import spark.SparkEnv import java.util.concurrent.Executors -class JobManager(ssc: SparkStreamContext, numThreads: Int = 1) extends Logging { +class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { - class JobHandler(ssc: SparkStreamContext, job: Job) extends Runnable { + class JobHandler(ssc: StreamingContext, job: Job) extends Runnable { def run() { SparkEnv.set(ssc.env) try { diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala new file mode 100644 index 0000000000..ee09324c8c --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala @@ -0,0 +1,36 @@ +package spark.streaming + +import akka.actor._ +import akka.pattern.ask +import akka.util.duration._ +import akka.dispatch._ + +import spark.RDD +import spark.BlockRDD +import spark.Logging + +import java.io.InputStream + + +class NetworkInputDStream[T: ClassManifest]( + @transient ssc: StreamingContext, + val host: String, + val port: Int, + val bytesToObjects: InputStream => Iterator[T] + ) extends InputDStream[T](ssc) with Logging { + + val id = ssc.getNewNetworkStreamId() + + def start() { } + + def stop() { } + + override def compute(validTime: Time): Option[RDD[T]] = { + val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) + return Some(new BlockRDD[T](ssc.sc, blockIds)) + } + + def createReceiver(): NetworkInputReceiver[T] = { + new NetworkInputReceiver(id, host, port, bytesToObjects) + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputReceiver.scala b/streaming/src/main/scala/spark/streaming/NetworkInputReceiver.scala new file mode 100644 index 0000000000..7add6246b7 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/NetworkInputReceiver.scala @@ -0,0 +1,248 @@ +package spark.streaming + +import spark.Logging +import spark.storage.BlockManager +import spark.storage.StorageLevel +import spark.SparkEnv +import spark.streaming.util.SystemClock +import spark.streaming.util.RecurringTimer + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.Queue +import scala.collection.mutable.SynchronizedPriorityQueue +import scala.math.Ordering + +import java.net.InetSocketAddress +import java.net.Socket +import java.io.InputStream +import java.io.BufferedInputStream +import java.io.DataInputStream +import java.io.EOFException +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.ArrayBlockingQueue + +import akka.actor._ +import akka.pattern.ask +import akka.util.duration._ +import akka.dispatch._ + +trait NetworkInputReceiverMessage +case class GetBlockIds(time: Long) extends NetworkInputReceiverMessage +case class GotBlockIds(streamId: Int, blocksIds: Array[String]) extends NetworkInputReceiverMessage +case class StopReceiver() extends NetworkInputReceiverMessage + +class NetworkInputReceiver[T: ClassManifest](streamId: Int, host: String, port: Int, bytesToObjects: InputStream => Iterator[T]) +extends Logging { + + class ReceiverActor extends Actor { + override def preStart() = { + logInfo("Attempting to register") + val ip = System.getProperty("spark.master.host", "localhost") + val port = System.getProperty("spark.master.port", "7077").toInt + val actorName: String = "NetworkInputTracker" + val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) + val trackerActor = env.actorSystem.actorFor(url) + val timeout = 100.milliseconds + val future = trackerActor.ask(RegisterReceiver(streamId, self))(timeout) + Await.result(future, timeout) + } + + def receive = { + case GetBlockIds(time) => { + logInfo("Got request for block ids for " + time) + sender ! GotBlockIds(streamId, dataHandler.getPushedBlocks()) + } + + case StopReceiver() => { + if (receivingThread != null) { + receivingThread.interrupt() + } + sender ! true + } + } + } + + class DataHandler { + + class Block(val time: Long, val iterator: Iterator[T]) { + val blockId = "input-" + streamId + "-" + time + var pushed = true + override def toString() = "input block " + blockId + } + + val clock = new SystemClock() + val blockInterval = 200L + val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) + val blockOrdering = new Ordering[Block] { + def compare(b1: Block, b2: Block) = (b1.time - b2.time).toInt + } + val blockStorageLevel = StorageLevel.DISK_AND_MEMORY + val blocksForPushing = new ArrayBlockingQueue[Block](1000) + val blocksForReporting = new SynchronizedPriorityQueue[Block]()(blockOrdering) + val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } + + var currentBuffer = new ArrayBuffer[T] + + def start() { + blockIntervalTimer.start() + blockPushingThread.start() + logInfo("Data handler started") + } + + def stop() { + blockIntervalTimer.stop() + blockPushingThread.interrupt() + } + + def += (obj: T) { + currentBuffer += obj + } + + def updateCurrentBuffer(time: Long) { + val newBlockBuffer = currentBuffer + currentBuffer = new ArrayBuffer[T] + if (newBlockBuffer.size > 0) { + val newBlock = new Block(time - blockInterval, newBlockBuffer.toIterator) + blocksForPushing.add(newBlock) + blocksForReporting.enqueue(newBlock) + } + } + + def keepPushingBlocks() { + logInfo("Block pushing thread started") + try { + while(true) { + val block = blocksForPushing.take() + if (blockManager != null) { + blockManager.put(block.blockId, block.iterator, blockStorageLevel) + block.pushed = true + } else { + logWarning(block + " not put as block manager is null") + } + } + } catch { + case ie: InterruptedException => println("Block pushing thread interrupted") + case e: Exception => e.printStackTrace() + } + } + + def getPushedBlocks(): Array[String] = { + val pushedBlocks = new ArrayBuffer[String]() + var loop = true + while(loop && !blocksForReporting.isEmpty) { + val block = blocksForReporting.dequeue() + if (block == null) { + loop = false + } else if (!block.pushed) { + blocksForReporting.enqueue(block) + } else { + pushedBlocks += block.blockId + } + } + logInfo("Got " + pushedBlocks.size + " blocks") + pushedBlocks.toArray + } + } + + val blockManager = if (SparkEnv.get != null) SparkEnv.get.blockManager else null + val dataHandler = new DataHandler() + val env = SparkEnv.get + + var receiverActor: ActorRef = null + var receivingThread: Thread = null + + def run() { + initLogging() + var socket: Socket = null + try { + if (SparkEnv.get != null) { + receiverActor = SparkEnv.get.actorSystem.actorOf(Props(new ReceiverActor), "ReceiverActor-" + streamId) + } + dataHandler.start() + socket = connect() + receivingThread = Thread.currentThread() + receive(socket) + } catch { + case ie: InterruptedException => logInfo("Receiver interrupted") + } finally { + receivingThread = null + if (socket != null) socket.close() + dataHandler.stop() + } + } + + def connect(): Socket = { + logInfo("Connecting to " + host + ":" + port) + val socket = new Socket(host, port) + logInfo("Connected to " + host + ":" + port) + socket + } + + def receive(socket: Socket) { + val iterator = bytesToObjects(socket.getInputStream()) + while(iterator.hasNext) { + val obj = iterator.next + dataHandler += obj + } + } +} + + +object NetworkInputReceiver { + + def bytesToLines(inputStream: InputStream): Iterator[String] = { + val bufferedInputStream = new BufferedInputStream(inputStream) + val dataInputStream = new DataInputStream(bufferedInputStream) + + val iterator = new Iterator[String] { + var gotNext = false + var finished = false + var nextValue: String = null + + private def getNext() { + try { + nextValue = dataInputStream.readLine() + println("[" + nextValue + "]") + } catch { + case eof: EOFException => + finished = true + } + gotNext = true + } + + override def hasNext: Boolean = { + if (!gotNext) { + getNext() + } + if (finished) { + dataInputStream.close() + } + !finished + } + + + override def next(): String = { + if (!gotNext) { + getNext() + } + if (finished) { + throw new NoSuchElementException("End of stream") + } + gotNext = false + nextValue + } + } + iterator + } + + def main(args: Array[String]) { + if (args.length < 2) { + println("NetworkReceiver ") + System.exit(1) + } + val host = args(0) + val port = args(1).toInt + val receiver = new NetworkInputReceiver(0, host, port, bytesToLines) + receiver.run() + } +} diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala new file mode 100644 index 0000000000..07758665c9 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -0,0 +1,110 @@ +package spark.streaming + +import spark.Logging +import spark.SparkEnv + +import scala.collection.mutable.HashMap + +import akka.actor._ +import akka.pattern.ask +import akka.util.duration._ +import akka.dispatch._ + +trait NetworkInputTrackerMessage +case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage + +class NetworkInputTracker( + @transient ssc: StreamingContext, + @transient networkInputStreams: Array[NetworkInputDStream[_]]) +extends Logging { + + class TrackerActor extends Actor { + def receive = { + case RegisterReceiver(streamId, receiverActor) => { + if (!networkInputStreamIds.contains(streamId)) { + throw new Exception("Register received for unexpected id " + streamId) + } + receiverInfo += ((streamId, receiverActor)) + logInfo("Registered receiver for network stream " + streamId) + sender ! true + } + } + } + + class ReceiverExecutor extends Thread { + val env = ssc.env + + override def run() { + try { + SparkEnv.set(env) + startReceivers() + } catch { + case ie: InterruptedException => logInfo("ReceiverExecutor interrupted") + } finally { + stopReceivers() + } + } + + def startReceivers() { + val tempRDD = ssc.sc.makeRDD(networkInputStreams, networkInputStreams.size) + + val startReceiver = (iterator: Iterator[NetworkInputDStream[_]]) => { + if (!iterator.hasNext) { + throw new Exception("Could not start receiver as details not found.") + } + val stream = iterator.next + val receiver = stream.createReceiver() + receiver.run() + } + + ssc.sc.runJob(tempRDD, startReceiver) + } + + def stopReceivers() { + implicit val ec = env.actorSystem.dispatcher + val message = new StopReceiver() + val listOfFutures = receiverInfo.values.map(_.ask(message)(timeout)).toList + val futureOfList = Future.sequence(listOfFutures) + Await.result(futureOfList, timeout) + } + } + + val networkInputStreamIds = networkInputStreams.map(_.id).toArray + val receiverExecutor = new ReceiverExecutor() + val receiverInfo = new HashMap[Int, ActorRef] + val receivedBlockIds = new HashMap[Int, Array[String]] + val timeout = 1000.milliseconds + + + var currentTime: Time = null + + def start() { + ssc.env.actorSystem.actorOf(Props(new TrackerActor), "NetworkInputTracker") + receiverExecutor.start() + } + + def stop() { + // stop the actor + receiverExecutor.interrupt() + } + + def getBlockIds(receiverId: Int, time: Time): Array[String] = synchronized { + if (currentTime == null || time > currentTime) { + logInfo("Getting block ids from receivers for " + time) + implicit val ec = ssc.env.actorSystem.dispatcher + receivedBlockIds.clear() + val message = new GetBlockIds(time) + val listOfFutures = receiverInfo.values.map( + _.ask(message)(timeout).mapTo[GotBlockIds] + ).toList + val futureOfList = Future.sequence(listOfFutures) + val allBlockIds = Await.result(futureOfList, timeout) + receivedBlockIds ++= allBlockIds.map(x => (x.streamId, x.blocksIds)) + if (receivedBlockIds.size != receiverInfo.size) { + throw new Exception("Unexpected number of the Block IDs received") + } + currentTime = time + } + receivedBlockIds.getOrElse(receiverId, Array[String]()) + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index 0cf296f21a..d2887c3aea 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -1,7 +1,7 @@ package spark.streaming import scala.collection.mutable.ArrayBuffer -import spark.streaming.SparkStreamContext._ +import spark.streaming.StreamingContext._ class PairDStreamFunctions[K: ClassManifest, V: ClassManifest](stream: DStream[(K,V)]) extends Serializable { diff --git a/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala b/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala index c78abd1a87..bab48ff954 100644 --- a/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala @@ -7,7 +7,7 @@ import scala.collection.mutable.Queue import scala.collection.mutable.ArrayBuffer class QueueInputDStream[T: ClassManifest]( - ssc: SparkStreamContext, + ssc: StreamingContext, val queue: Queue[RDD[T]], oneAtATime: Boolean, defaultRDD: RDD[T] diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala index 11fa4e5443..9d48870877 100644 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -1,6 +1,6 @@ package spark.streaming -import spark.streaming.SparkStreamContext._ +import spark.streaming.StreamingContext._ import spark.RDD import spark.UnionRDD diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index 309bd95525..da50e22719 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -12,7 +12,7 @@ sealed trait SchedulerMessage case class InputGenerated(inputName: String, interval: Interval, reference: AnyRef = null) extends SchedulerMessage class Scheduler( - ssc: SparkStreamContext, + ssc: StreamingContext, inputStreams: Array[InputDStream[_]], outputStreams: Array[DStream[_]]) extends Logging { @@ -40,6 +40,7 @@ extends Logging { } def generateRDDs (time: Time) { + println("\n-----------------------------------------------------\n") logInfo("Generating RDDs for time " + time) outputStreams.foreach(outputStream => { outputStream.generateJob(time) match { diff --git a/streaming/src/main/scala/spark/streaming/SparkStreamContext.scala b/streaming/src/main/scala/spark/streaming/SparkStreamContext.scala deleted file mode 100644 index 2bec1091c0..0000000000 --- a/streaming/src/main/scala/spark/streaming/SparkStreamContext.scala +++ /dev/null @@ -1,173 +0,0 @@ -package spark.streaming - -import spark.RDD -import spark.Logging -import spark.SparkEnv -import spark.SparkContext - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.Queue - -import java.io.IOException -import java.net.InetSocketAddress -import java.util.concurrent.atomic.AtomicInteger - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.io.LongWritable -import org.apache.hadoop.io.Text -import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import org.apache.hadoop.mapreduce.lib.input.TextInputFormat - -class SparkStreamContext ( - master: String, - frameworkName: String, - val sparkHome: String = null, - val jars: Seq[String] = Nil) - extends Logging { - - initLogging() - - val sc = new SparkContext(master, frameworkName, sparkHome, jars) - val env = SparkEnv.get - - val inputStreams = new ArrayBuffer[InputDStream[_]]() - val outputStreams = new ArrayBuffer[DStream[_]]() - var batchDuration: Time = null - var scheduler: Scheduler = null - - def setBatchDuration(duration: Long) { - setBatchDuration(Time(duration)) - } - - def setBatchDuration(duration: Time) { - batchDuration = duration - } - - /* - def createNetworkStream[T: ClassManifest]( - name: String, - addresses: Array[InetSocketAddress], - batchDuration: Time): DStream[T] = { - - val inputStream = new NetworkinputStream[T](this, addresses) - inputStreams += inputStream - inputStream - } - - def createNetworkStream[T: ClassManifest]( - name: String, - addresses: Array[String], - batchDuration: Long): DStream[T] = { - - def stringToInetSocketAddress (str: String): InetSocketAddress = { - val parts = str.split(":") - if (parts.length != 2) { - throw new IllegalArgumentException ("Address format error") - } - new InetSocketAddress(parts(0), parts(1).toInt) - } - - readNetworkStream( - name, - addresses.map(stringToInetSocketAddress).toArray, - LongTime(batchDuration)) - } - */ - - /** - * This function creates a input stream that monitors a Hadoop-compatible - * for new files and executes the necessary processing on them. - */ - def createFileStream[ - K: ClassManifest, - V: ClassManifest, - F <: NewInputFormat[K, V]: ClassManifest - ](directory: String): DStream[(K, V)] = { - val inputStream = new FileInputDStream[K, V, F](this, new Path(directory)) - inputStreams += inputStream - inputStream - } - - def createTextFileStream(directory: String): DStream[String] = { - createFileStream[LongWritable, Text, TextInputFormat](directory).map(_._2.toString) - } - - /** - * This function create a input stream from an queue of RDDs. In each batch, - * it will process either one or all of the RDDs returned by the queue - */ - def createQueueStream[T: ClassManifest]( - queue: Queue[RDD[T]], - oneAtATime: Boolean = true, - defaultRDD: RDD[T] = null - ): DStream[T] = { - val inputStream = new QueueInputDStream(this, queue, oneAtATime, defaultRDD) - inputStreams += inputStream - inputStream - } - - def createQueueStream[T: ClassManifest](iterator: Iterator[RDD[T]]): DStream[T] = { - val queue = new Queue[RDD[T]] - val inputStream = createQueueStream(queue, true, null) - queue ++= iterator - inputStream - } - - - /** - * This function registers a DStream as an output stream that will be - * computed every interval. - */ - def registerOutputStream (outputStream: DStream[_]) { - outputStreams += outputStream - } - - /** - * This function verify whether the stream computation is eligible to be executed. - */ - def verify() { - if (batchDuration == null) { - throw new Exception("Batch duration has not been set") - } - if (batchDuration < Milliseconds(100)) { - logWarning("Batch duration of " + batchDuration + " is very low") - } - if (inputStreams.size == 0) { - throw new Exception("No input streams created, so nothing to take input from") - } - if (outputStreams.size == 0) { - throw new Exception("No output streams registered, so nothing to execute") - } - - } - - /** - * This function starts the execution of the streams. - */ - def start() { - verify() - scheduler = new Scheduler(this, inputStreams.toArray, outputStreams.toArray) - scheduler.start() - } - - /** - * This function starts the execution of the streams. - */ - def stop() { - try { - scheduler.stop() - sc.stop() - } catch { - case e: Exception => logWarning("Error while stopping", e) - } - - logInfo("SparkStreamContext stopped") - } -} - - -object SparkStreamContext { - implicit def toPairDStreamFunctions[K: ClassManifest, V: ClassManifest](stream: DStream[(K,V)]) = - new PairDStreamFunctions(stream) -} diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala new file mode 100644 index 0000000000..0ac86cbdf2 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -0,0 +1,192 @@ +package spark.streaming + +import spark.RDD +import spark.Logging +import spark.SparkEnv +import spark.SparkContext + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.Queue + +import java.io.InputStream +import java.io.IOException +import java.net.InetSocketAddress +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.io.LongWritable +import org.apache.hadoop.io.Text +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat + +class StreamingContext ( + master: String, + frameworkName: String, + val sparkHome: String = null, + val jars: Seq[String] = Nil) + extends Logging { + + initLogging() + + val sc = new SparkContext(master, frameworkName, sparkHome, jars) + val env = SparkEnv.get + + val inputStreams = new ArrayBuffer[InputDStream[_]]() + val outputStreams = new ArrayBuffer[DStream[_]]() + val nextNetworkInputStreamId = new AtomicInteger(0) + + var batchDuration: Time = null + var scheduler: Scheduler = null + var networkInputTracker: NetworkInputTracker = null + var receiverJobThread: Thread = null + + def setBatchDuration(duration: Long) { + setBatchDuration(Time(duration)) + } + + def setBatchDuration(duration: Time) { + batchDuration = duration + } + + private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() + + def createNetworkTextStream(hostname: String, port: Int): DStream[String] = { + createNetworkStream[String](hostname, port, NetworkInputReceiver.bytesToLines) + } + + def createNetworkStream[T: ClassManifest]( + hostname: String, + port: Int, + converter: (InputStream) => Iterator[T] + ): DStream[T] = { + val inputStream = new NetworkInputDStream[T](this, hostname, port, converter) + inputStreams += inputStream + inputStream + } + + /* + def createHttpTextStream(url: String): DStream[String] = { + createHttpStream(url, NetworkInputReceiver.bytesToLines) + } + + def createHttpStream[T: ClassManifest]( + url: String, + converter: (InputStream) => Iterator[T] + ): DStream[T] = { + } + */ + + /** + * This function creates a input stream that monitors a Hadoop-compatible + * for new files and executes the necessary processing on them. + */ + def createFileStream[ + K: ClassManifest, + V: ClassManifest, + F <: NewInputFormat[K, V]: ClassManifest + ](directory: String): DStream[(K, V)] = { + val inputStream = new FileInputDStream[K, V, F](this, new Path(directory)) + inputStreams += inputStream + inputStream + } + + def createTextFileStream(directory: String): DStream[String] = { + createFileStream[LongWritable, Text, TextInputFormat](directory).map(_._2.toString) + } + + /** + * This function create a input stream from an queue of RDDs. In each batch, + * it will process either one or all of the RDDs returned by the queue + */ + def createQueueStream[T: ClassManifest]( + queue: Queue[RDD[T]], + oneAtATime: Boolean = true, + defaultRDD: RDD[T] = null + ): DStream[T] = { + val inputStream = new QueueInputDStream(this, queue, oneAtATime, defaultRDD) + inputStreams += inputStream + inputStream + } + + def createQueueStream[T: ClassManifest](iterator: Iterator[RDD[T]]): DStream[T] = { + val queue = new Queue[RDD[T]] + val inputStream = createQueueStream(queue, true, null) + queue ++= iterator + inputStream + } + + + /** + * This function registers a DStream as an output stream that will be + * computed every interval. + */ + def registerOutputStream (outputStream: DStream[_]) { + outputStreams += outputStream + } + + /** + * This function verify whether the stream computation is eligible to be executed. + */ + private def verify() { + if (batchDuration == null) { + throw new Exception("Batch duration has not been set") + } + if (batchDuration < Milliseconds(100)) { + logWarning("Batch duration of " + batchDuration + " is very low") + } + if (inputStreams.size == 0) { + throw new Exception("No input streams created, so nothing to take input from") + } + if (outputStreams.size == 0) { + throw new Exception("No output streams registered, so nothing to execute") + } + + } + + /** + * This function starts the execution of the streams. + */ + def start() { + verify() + val networkInputStreams = inputStreams.filter(s => s match { + case n: NetworkInputDStream[_] => true + case _ => false + }).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray + + if (networkInputStreams.length > 0) { + // Start the network input tracker (must start before receivers) + networkInputTracker = new NetworkInputTracker(this, networkInputStreams) + networkInputTracker.start() + } + + Thread.sleep(1000) + // Start the scheduler + scheduler = new Scheduler(this, inputStreams.toArray, outputStreams.toArray) + scheduler.start() + } + + /** + * This function starts the execution of the streams. + */ + def stop() { + try { + if (scheduler != null) scheduler.stop() + if (networkInputTracker != null) networkInputTracker.stop() + if (receiverJobThread != null) receiverJobThread.interrupt() + sc.stop() + } catch { + case e: Exception => logWarning("Error while stopping", e) + } + + logInfo("StreamingContext stopped") + } +} + + +object StreamingContext { + implicit def toPairDStreamFunctions[K: ClassManifest, V: ClassManifest](stream: DStream[(K,V)]) = { + new PairDStreamFunctions[K, V](stream) + } +} + diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index 5c476f02c3..1db95a45e1 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -1,6 +1,6 @@ package spark.streaming -class Time(private var millis: Long) { +class Time (private var millis: Long) extends Serializable { def copy() = new Time(this.millis) diff --git a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala index 9a6617a1ee..bf11cbcec0 100644 --- a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala @@ -1,6 +1,6 @@ package spark.streaming -import spark.streaming.SparkStreamContext._ +import spark.streaming.StreamingContext._ import spark.RDD import spark.UnionRDD diff --git a/streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala b/streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala index 669f575240..2ff8790e77 100644 --- a/streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala +++ b/streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala @@ -1,8 +1,8 @@ package spark.streaming.examples import spark.RDD -import spark.streaming.SparkStreamContext -import spark.streaming.SparkStreamContext._ +import spark.streaming.StreamingContext +import spark.streaming.StreamingContext._ import spark.streaming.Seconds import scala.collection.mutable.SynchronizedQueue @@ -16,7 +16,7 @@ object ExampleOne { } // Create the context and set the batch size - val ssc = new SparkStreamContext(args(0), "ExampleOne") + val ssc = new StreamingContext(args(0), "ExampleOne") ssc.setBatchDuration(Seconds(1)) // Create the queue through which RDDs can be pushed to diff --git a/streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala b/streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala index be47e47a5a..ad563e2c75 100644 --- a/streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala +++ b/streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala @@ -1,7 +1,7 @@ package spark.streaming.examples -import spark.streaming.SparkStreamContext -import spark.streaming.SparkStreamContext._ +import spark.streaming.StreamingContext +import spark.streaming.StreamingContext._ import spark.streaming.Seconds import org.apache.hadoop.fs.Path import org.apache.hadoop.conf.Configuration @@ -15,7 +15,7 @@ object ExampleTwo { } // Create the context and set the batch size - val ssc = new SparkStreamContext(args(0), "ExampleTwo") + val ssc = new StreamingContext(args(0), "ExampleTwo") ssc.setBatchDuration(Seconds(2)) // Create the new directory diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount.scala deleted file mode 100644 index ba7bc63d6a..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount.scala +++ /dev/null @@ -1,25 +0,0 @@ -package spark.streaming.examples - -import spark.streaming.{Seconds, SparkStreamContext} -import spark.streaming.SparkStreamContext._ - -object WordCount { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: WordCount ") - System.exit(1) - } - - // Create the context and set the batch size - val ssc = new SparkStreamContext(args(0), "ExampleTwo") - ssc.setBatchDuration(Seconds(2)) - - // Create the FileInputDStream on the directory and use the - // stream to count words in new files created - val lines = ssc.createTextFileStream(args(1)) - val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - ssc.start() - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala new file mode 100644 index 0000000000..3b86948822 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala @@ -0,0 +1,26 @@ +package spark.streaming.examples + +import spark.streaming.{Seconds, StreamingContext} +import spark.streaming.StreamingContext._ + +object WordCountHdfs { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: WordCountHdfs ") + System.exit(1) + } + + // Create the context and set the batch size + val ssc = new StreamingContext(args(0), "WordCountHdfs") + ssc.setBatchDuration(Seconds(2)) + + // Create the FileInputDStream on the directory and use the + // stream to count words in new files created + val lines = ssc.createTextFileStream(args(1)) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + } +} + diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala new file mode 100644 index 0000000000..0a33a05bae --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala @@ -0,0 +1,25 @@ +package spark.streaming.examples + +import spark.streaming.{Seconds, StreamingContext} +import spark.streaming.StreamingContext._ + +object WordCountNetwork { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: WordCountNetwork ") + System.exit(1) + } + + // Create the context and set the batch size + val ssc = new StreamingContext(args(0), "WordCountNetwork") + ssc.setBatchDuration(Seconds(2)) + + // Create the FileInputDStream on the directory and use the + // stream to count words in new files created + val lines = ssc.createNetworkTextStream(args(1), args(2).toInt) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + } +} diff --git a/streaming/src/test/scala/spark/streaming/DStreamSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamSuite.scala index 2c10a03e6d..e9dc377263 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamSuite.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamSuite.scala @@ -12,7 +12,7 @@ import scala.collection.mutable.SynchronizedQueue class DStreamSuite extends FunSuite with BeforeAndAfter with Logging { - var ssc: SparkStreamContext = null + var ssc: StreamingContext = null val batchDurationMillis = 1000 System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") @@ -22,7 +22,7 @@ class DStreamSuite extends FunSuite with BeforeAndAfter with Logging { operation: DStream[U] => DStream[V], expectedOutput: Seq[Seq[V]]) { try { - ssc = new SparkStreamContext("local", "test") + ssc = new StreamingContext("local", "test") ssc.setBatchDuration(Milliseconds(batchDurationMillis)) val inputStream = ssc.createQueueStream(input.map(ssc.sc.makeRDD(_, 2)).toIterator) -- cgit v1.2.3 From 1f8085b8d097f060a3939eaff5be1f58111ec224 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 29 Aug 2012 03:20:56 +0000 Subject: Compile fixes --- streaming/src/main/scala/spark/streaming/Time.scala | 2 +- streaming/src/main/scala/spark/streaming/examples/WordCount2.scala | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index e44025138f..9ddb65249a 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -1,6 +1,6 @@ package spark.streaming -case class Time(millis: Long) extends Serializable { +case class Time(millis: Long) { def < (that: Time): Boolean = (this.millis < that.millis) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index 1afe87e723..3b2d909584 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -3,7 +3,7 @@ package spark.streaming.examples import spark.SparkContext import SparkContext._ import spark.streaming._ -import SparkStreamContext._ +import StreamingContext._ import spark.storage.StorageLevel @@ -71,7 +71,7 @@ object WordCount2 { def main (args: Array[String]) { if (args.length != 5) { - println ("Usage: SparkStreamContext ") + println ("Usage: WordCount2 ") System.exit(1) } @@ -79,7 +79,7 @@ object WordCount2 { val BATCH_DURATION = Milliseconds(batchMillis.toLong) - val ssc = new SparkStreamContext(master, "WordCount2") + val ssc = new StreamingContext(master, "WordCount2") ssc.setBatchDuration(BATCH_DURATION) val data = ssc.sc.textFile(file, mapTasks.toInt).persist(StorageLevel.MEMORY_ONLY_DESER_2) -- cgit v1.2.3 From 4db3a967669a53de4c4b79b4c0b70daa5accb682 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 29 Aug 2012 13:04:01 -0700 Subject: Made minor changes to reduce compilation errors in Eclipse. Twirl stuff still does not compile in Eclipse. --- .../src/main/scala/spark/network/ConnectionManager.scala | 16 +++++++++++++--- .../main/scala/spark/network/ConnectionManagerTest.scala | 5 ++++- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 0e764fff81..2bb5f5fc6b 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -16,6 +16,7 @@ import scala.collection.mutable.ArrayBuffer import akka.dispatch.{Await, Promise, ExecutionContext, Future} import akka.util.Duration +import akka.util.duration._ case class ConnectionManagerId(host: String, port: Int) { def toSocketAddress() = new InetSocketAddress(host, port) @@ -403,7 +404,10 @@ object ConnectionManager { (0 until count).map(i => { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => {if (!f().isDefined) println("Failed")}) + }).foreach(f => { + val g = Await.result(f, 1 second) + if (!g.isDefined) println("Failed") + }) val finishTime = System.currentTimeMillis val mb = size * count / 1024.0 / 1024.0 @@ -430,7 +434,10 @@ object ConnectionManager { (0 until count).map(i => { val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate) manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => {if (!f().isDefined) println("Failed")}) + }).foreach(f => { + val g = Await.result(f, 1 second) + if (!g.isDefined) println("Failed") + }) val finishTime = System.currentTimeMillis val ms = finishTime - startTime @@ -457,7 +464,10 @@ object ConnectionManager { (0 until count).map(i => { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => {if (!f().isDefined) println("Failed")}) + }).foreach(f => { + val g = Await.result(f, 1 second) + if (!g.isDefined) println("Failed") + }) val finishTime = System.currentTimeMillis Thread.sleep(1000) val mb = size * count / 1024.0 / 1024.0 diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala index 5d21bb793f..555b3454ee 100644 --- a/core/src/main/scala/spark/network/ConnectionManagerTest.scala +++ b/core/src/main/scala/spark/network/ConnectionManagerTest.scala @@ -8,6 +8,9 @@ import scala.io.Source import java.nio.ByteBuffer import java.net.InetAddress +import akka.dispatch.Await +import akka.util.duration._ + object ConnectionManagerTest extends Logging{ def main(args: Array[String]) { if (args.length < 2) { @@ -53,7 +56,7 @@ object ConnectionManagerTest extends Logging{ logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) }) - val results = futures.map(f => f()) + val results = futures.map(f => Await.result(f, 1.second)) val finishTime = System.currentTimeMillis Thread.sleep(5000) -- cgit v1.2.3 From c4366eb76425d1c6aeaa7df750a2681a0da75db8 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 31 Aug 2012 00:34:24 +0000 Subject: Fixes to ShuffleFetcher --- .../scala/spark/BlockStoreShuffleFetcher.scala | 41 +++++++++------------- 1 file changed, 17 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala index 45a14c8290..0bbdb4e432 100644 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala @@ -32,36 +32,29 @@ class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { (address, splits.map(i => "shuffleid_%d_%d_%d".format(shuffleId, i, reduceId))) } - try { - for ((blockId, blockOption) <- blockManager.getMultiple(blocksByAddress)) { - blockOption match { - case Some(block) => { - val values = block - for(value <- values) { - val v = value.asInstanceOf[(K, V)] - func(v._1, v._2) - } - } - case None => { - throw new BlockException(blockId, "Did not get block " + blockId) + for ((blockId, blockOption) <- blockManager.getMultiple(blocksByAddress)) { + blockOption match { + case Some(block) => { + val values = block + for(value <- values) { + val v = value.asInstanceOf[(K, V)] + func(v._1, v._2) } } - } - } catch { - // TODO: this is really ugly -- let's find a better way of throwing a FetchFailedException - case be: BlockException => { - val regex = "shuffleid_([0-9]*)_([0-9]*)_([0-9]]*)".r - be.blockId match { - case regex(sId, mId, rId) => { - val address = addresses(mId.toInt) - throw new FetchFailedException(address, sId.toInt, mId.toInt, rId.toInt, be) - } - case _ => { - throw be + case None => { + val regex = "shuffleid_([0-9]*)_([0-9]*)_([0-9]*)".r + blockId match { + case regex(shufId, mapId, reduceId) => + val addr = addresses(mapId.toInt) + throw new FetchFailedException(addr, shufId.toInt, mapId.toInt, reduceId.toInt, null) + case _ => + throw new SparkException( + "Failed to get block " + blockId + ", which is not a shuffle block") } } } } + logDebug("Fetching and merging outputs of shuffle %d, reduce %d took %d ms".format( shuffleId, reduceId, System.currentTimeMillis - startTime)) } -- cgit v1.2.3 From d4d2cb670f9a19a6ab151e51398ba135445f1bb9 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 31 Aug 2012 00:34:57 +0000 Subject: Make checkpoint interval configurable in WordCount2 --- .../main/scala/spark/streaming/examples/WordCount2.scala | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index 3b2d909584..d4b7461099 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -70,17 +70,17 @@ object WordCount2 { def main (args: Array[String]) { - if (args.length != 5) { - println ("Usage: WordCount2 ") + if (args.length != 6) { + println ("Usage: WordCount2 ") System.exit(1) } - val Array(master, file, mapTasks, reduceTasks, batchMillis) = args + val Array(master, file, mapTasks, reduceTasks, batchMillis, chkptMillis) = args - val BATCH_DURATION = Milliseconds(batchMillis.toLong) + val batchDuration = Milliseconds(batchMillis.toLong) val ssc = new StreamingContext(master, "WordCount2") - ssc.setBatchDuration(BATCH_DURATION) + ssc.setBatchDuration(batchDuration) val data = ssc.sc.textFile(file, mapTasks.toInt).persist(StorageLevel.MEMORY_ONLY_DESER_2) println("Data count: " + data.count()) @@ -94,8 +94,9 @@ object WordCount2 { val windowedCounts = sentences .mapPartitions(splitAndCountPartitions) - .reduceByKeyAndWindow(add _, subtract _, Seconds(10), BATCH_DURATION, reduceTasks.toInt) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Seconds(10)) + .reduceByKeyAndWindow(add _, subtract _, Seconds(10), batchDuration, reduceTasks.toInt) + windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, + Milliseconds(chkptMillis.toLong)) windowedCounts.print() ssc.start() -- cgit v1.2.3 From 1b3e3352ebfed40881d534cd3096d4b6428c24d4 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 30 Aug 2012 17:59:25 -0700 Subject: Deserialize multi-get results in the caller's thread. This fixes an issue with shared buffers with the KryoSerializer. --- core/src/main/scala/spark/storage/BlockManager.scala | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 45f99717bc..e9197f7169 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -272,11 +272,15 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m val totalBlocks = blocksByAddress.map(_._2.size).sum logDebug("Getting " + totalBlocks + " blocks") var startTime = System.currentTimeMillis - val results = new LinkedBlockingQueue[(String, Option[Iterator[Any]])] val localBlockIds = new ArrayBuffer[String]() val remoteBlockIds = new ArrayBuffer[String]() val remoteBlockIdsPerLocation = new HashMap[BlockManagerId, Seq[String]]() + // A queue to hold our results. Because we want all the deserializing the happen in the + // caller's thread, this will actually hold functions to produce the Iterator for each block. + // For local blocks we'll have an iterator already, while for remote ones we'll deserialize. + val results = new LinkedBlockingQueue[(String, Option[() => Iterator[Any]])] + // Split local and remote blocks for ((address, blockIds) <- blocksByAddress) { if (address == blockManagerId) { @@ -302,10 +306,8 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m throw new SparkException( "Unexpected message " + blockMessage.getType + " received from " + cmId) } - val buffer = blockMessage.getData val blockId = blockMessage.getId - val block = dataDeserialize(buffer) - results.put((blockId, Some(block))) + results.put((blockId, Some(() => dataDeserialize(blockMessage.getData)))) logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) }) } @@ -323,9 +325,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // Get the local blocks while remote blocks are being fetched startTime = System.currentTimeMillis localBlockIds.foreach(id => { - get(id) match { + getLocal(id) match { case Some(block) => { - results.put((id, Some(block))) + results.put((id, Some(() => block))) logDebug("Got local block " + id) } case None => { @@ -343,7 +345,8 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m def next(): (String, Option[Iterator[Any]]) = { resultsGotten += 1 - results.take() + val (blockId, functionOption) = results.take() + (blockId, functionOption.map(_.apply())) } } } -- cgit v1.2.3 From 101ae493e26693146114ac01d50d411f5b2e0762 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 30 Aug 2012 22:24:14 -0700 Subject: Replicate serialized blocks properly, without sharing a ByteBuffer. --- core/src/main/scala/spark/storage/BlockManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index e9197f7169..8a013230da 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -456,7 +456,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // data is already serialized and ready for sending val replicationFuture = if (level.replication > 1) { Future { - replicate(blockId, bytes, level) + replicate(blockId, bytes.duplicate(), level) } } else { null -- cgit v1.2.3 From 113277549c5ee1bcd58c7cebc365d28d92b74b4a Mon Sep 17 00:00:00 2001 From: root Date: Fri, 31 Aug 2012 05:39:35 +0000 Subject: Really fixed the replication-3 issue. The problem was a few buffers not being rewound. --- core/src/main/scala/spark/storage/BlockManager.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 8a013230da..f2d9499bad 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -455,8 +455,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // Initiate the replication before storing it locally. This is faster as // data is already serialized and ready for sending val replicationFuture = if (level.replication > 1) { + val bufferView = bytes.duplicate() // Doesn't copy the bytes, just creates a wrapper Future { - replicate(blockId, bytes.duplicate(), level) + replicate(blockId, bufferView, level) } } else { null @@ -514,15 +515,16 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m var peers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) for (peer: BlockManagerId <- peers) { val start = System.nanoTime + data.rewind() logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is " - + data.array().length + " Bytes. To node: " + peer) + + data.limit() + " Bytes. To node: " + peer) if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel), new ConnectionManagerId(peer.ip, peer.port))) { logError("Failed to call syncPutBlock to " + peer) } logDebug("Replicated BlockId " + blockId + " once used " + (System.nanoTime - start) / 1e6 + " s; The size of the data is " + - data.array().length + " bytes.") + data.limit() + " bytes.") } } -- cgit v1.2.3 From e1da274a486b6fd5903d9b3643dc07c79973b81d Mon Sep 17 00:00:00 2001 From: root Date: Fri, 31 Aug 2012 07:16:19 +0000 Subject: WordCount tweaks --- .../scala/spark/streaming/examples/WordCount2.scala | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index d4b7461099..a090dcb85d 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -59,12 +59,12 @@ object WordCount2_ExtraFunctions { object WordCount2 { - def moreWarmup(sc: SparkContext) { - (0 until 40).foreach {i => + def warmup(sc: SparkContext) { + (0 until 10).foreach {i => sc.parallelize(1 to 20000000, 1000) - .map(_ % 1331).map(_.toString) - .mapPartitions(WordCount2_ExtraFunctions.splitAndCountPartitions).reduceByKey(_ + _, 10) - .collect() + .map(x => (x % 337, x % 1331)) + .reduceByKey(_ + _) + .count() } } @@ -82,7 +82,10 @@ object WordCount2 { val ssc = new StreamingContext(master, "WordCount2") ssc.setBatchDuration(batchDuration) - val data = ssc.sc.textFile(file, mapTasks.toInt).persist(StorageLevel.MEMORY_ONLY_DESER_2) + //warmup(ssc.sc) + + val data = ssc.sc.textFile(file, mapTasks.toInt).persist( + new StorageLevel(false, true, true, 2)) // Memory only, deserialized, 2 replicas println("Data count: " + data.count()) println("Data count: " + data.count()) println("Data count: " + data.count()) @@ -94,7 +97,7 @@ object WordCount2 { val windowedCounts = sentences .mapPartitions(splitAndCountPartitions) - .reduceByKeyAndWindow(add _, subtract _, Seconds(10), batchDuration, reduceTasks.toInt) + .reduceByKeyAndWindow(add _, subtract _, Seconds(30), batchDuration, reduceTasks.toInt) windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Milliseconds(chkptMillis.toLong)) windowedCounts.print() -- cgit v1.2.3 From 2d01d38a4199590145551a108903a3ac7cffcceb Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 31 Aug 2012 03:47:34 -0700 Subject: Added StateDStream, corresponding stateful stream operations, and testcases. Also refactored few PairDStreamFunctions methods. --- .../src/main/scala/spark/streaming/DStream.scala | 31 ++-- .../spark/streaming/PairDStreamFunctions.scala | 162 +++++++++++++++++---- .../spark/streaming/ReducedWindowedDStream.scala | 13 +- .../src/main/scala/spark/streaming/Scheduler.scala | 2 +- .../main/scala/spark/streaming/StateDStream.scala | 83 +++++++++++ .../streaming/examples/WordCountNetwork.scala | 4 +- .../test/scala/spark/streaming/DStreamSuite.scala | 90 +++++++++--- 7 files changed, 303 insertions(+), 82 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/StateDStream.scala diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 3a57488f9b..8c06345933 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -3,13 +3,12 @@ package spark.streaming import spark.streaming.StreamingContext._ import spark.RDD -import spark.BlockRDD import spark.UnionRDD import spark.Logging -import spark.SparkContext import spark.SparkContext._ import spark.storage.StorageLevel - +import spark.Partitioner + import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap @@ -95,12 +94,12 @@ extends Logging with Serializable { /** This method checks whether the 'time' is valid wrt slideTime for generating RDD */ private def isTimeValid (time: Time): Boolean = { - if (!isInitialized) + if (!isInitialized) { throw new Exception (this.toString + " has not been initialized") - if ((time - zeroTime).isMultipleOf(slideTime)) { - true - } else { + } else if (time < zeroTime || ! (time - zeroTime).isMultipleOf(slideTime)) { false + } else { + true } } @@ -119,7 +118,7 @@ extends Logging with Serializable { // if RDD was not generated, and if the time is valid // (based on sliding time of this DStream), then generate the RDD - case None => + case None => { if (isTimeValid(time)) { compute(time) match { case Some(newRDD) => @@ -138,6 +137,7 @@ extends Logging with Serializable { } else { None } + } } } @@ -361,24 +361,19 @@ class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiner: (C, C) => C, - numPartitions: Int) + partitioner: Partitioner) extends DStream [(K,C)] (parent.ssc) { override def dependencies = List(parent) override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[(K,C)]] = { parent.getOrCompute(validTime) match { - case Some(rdd) => - val newrdd = { - if (numPartitions > 0) { - rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner, numPartitions) - } else { - rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner) - } - } - Some(newrdd) + case Some(rdd) => + Some(rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner, partitioner)) case None => None } } diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index d2887c3aea..13db34ac80 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -1,71 +1,169 @@ package spark.streaming import scala.collection.mutable.ArrayBuffer +import spark.Partitioner +import spark.HashPartitioner import spark.streaming.StreamingContext._ +import javax.annotation.Nullable class PairDStreamFunctions[K: ClassManifest, V: ClassManifest](stream: DStream[(K,V)]) extends Serializable { def ssc = stream.ssc + def defaultPartitioner(numPartitions: Int = stream.ssc.sc.defaultParallelism) = { + new HashPartitioner(numPartitions) + } + /* ---------------------------------- */ /* DStream operations for key-value pairs */ /* ---------------------------------- */ - - def groupByKey(numPartitions: Int = 0): ShuffledDStream[K, V, ArrayBuffer[V]] = { + + def groupByKey(): ShuffledDStream[K, V, ArrayBuffer[V]] = { + groupByKey(defaultPartitioner()) + } + + def groupByKey(numPartitions: Int): ShuffledDStream[K, V, ArrayBuffer[V]] = { + groupByKey(defaultPartitioner(numPartitions)) + } + + def groupByKey(partitioner: Partitioner): ShuffledDStream[K, V, ArrayBuffer[V]] = { def createCombiner(v: V) = ArrayBuffer[V](v) def mergeValue(c: ArrayBuffer[V], v: V) = (c += v) def mergeCombiner(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = (c1 ++ c2) - combineByKey[ArrayBuffer[V]](createCombiner, mergeValue, mergeCombiner, numPartitions) + combineByKey[ArrayBuffer[V]](createCombiner _, mergeValue _, mergeCombiner _, partitioner) } - - def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int = 0): ShuffledDStream[K, V, V] = { + + def reduceByKey(reduceFunc: (V, V) => V): ShuffledDStream[K, V, V] = { + reduceByKey(reduceFunc, defaultPartitioner()) + } + + def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int): ShuffledDStream[K, V, V] = { + reduceByKey(reduceFunc, defaultPartitioner(numPartitions)) + } + + def reduceByKey(reduceFunc: (V, V) => V, partitioner: Partitioner): ShuffledDStream[K, V, V] = { val cleanedReduceFunc = ssc.sc.clean(reduceFunc) - combineByKey[V]((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, numPartitions) + combineByKey[V]((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, partitioner) } private def combineByKey[C: ClassManifest]( createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiner: (C, C) => C, - numPartitions: Int) : ShuffledDStream[K, V, C] = { - new ShuffledDStream[K, V, C](stream, createCombiner, mergeValue, mergeCombiner, numPartitions) + partitioner: Partitioner) : ShuffledDStream[K, V, C] = { + new ShuffledDStream[K, V, C](stream, createCombiner, mergeValue, mergeCombiner, partitioner) + } + + def groupByKeyAndWindow(windowTime: Time, slideTime: Time): ShuffledDStream[K, V, ArrayBuffer[V]] = { + groupByKeyAndWindow(windowTime, slideTime, defaultPartitioner()) + } + + def groupByKeyAndWindow(windowTime: Time, slideTime: Time, numPartitions: Int): ShuffledDStream[K, V, ArrayBuffer[V]] = { + groupByKeyAndWindow(windowTime, slideTime, defaultPartitioner(numPartitions)) } def groupByKeyAndWindow( - windowTime: Time, - slideTime: Time, - numPartitions: Int = 0): ShuffledDStream[K, V, ArrayBuffer[V]] = { - stream.window(windowTime, slideTime).groupByKey(numPartitions) + windowTime: Time, + slideTime: Time, + partitioner: Partitioner + ): ShuffledDStream[K, V, ArrayBuffer[V]] = { + stream.window(windowTime, slideTime).groupByKey(partitioner) + } + + def reduceByKeyAndWindow(reduceFunc: (V, V) => V, windowTime: Time, slideTime: Time): ShuffledDStream[K, V, V] = { + reduceByKeyAndWindow(reduceFunc, windowTime, slideTime, defaultPartitioner()) + } + + def reduceByKeyAndWindow(reduceFunc: (V, V) => V, windowTime: Time, slideTime: Time, numPartitions: Int): ShuffledDStream[K, V, V] = { + reduceByKeyAndWindow(reduceFunc, windowTime, slideTime, defaultPartitioner(numPartitions)) } def reduceByKeyAndWindow( - reduceFunc: (V, V) => V, - windowTime: Time, - slideTime: Time, - numPartitions: Int = 0): ShuffledDStream[K, V, V] = { - stream.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), numPartitions) + reduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time, + partitioner: Partitioner + ): ShuffledDStream[K, V, V] = { + stream.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), partitioner) } - // This method is the efficient sliding window reduce operation, - // which requires the specification of an inverse reduce function, - // so that new elements introduced in the window can be "added" using - // reduceFunc to the previous window's result and old elements can be + // This method is the efficient sliding window reduce operation, + // which requires the specification of an inverse reduce function, + // so that new elements introduced in the window can be "added" using + // reduceFunc to the previous window's result and old elements can be // "subtracted using invReduceFunc. def reduceByKeyAndWindow( - reduceFunc: (V, V) => V, - invReduceFunc: (V, V) => V, - windowTime: Time, - slideTime: Time, - numPartitions: Int): ReducedWindowedDStream[K, V] = { + reduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time + ): ReducedWindowedDStream[K, V] = { + + reduceByKeyAndWindow( + reduceFunc, invReduceFunc, windowTime, slideTime, defaultPartitioner()) + } + + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time, + numPartitions: Int + ): ReducedWindowedDStream[K, V] = { + + reduceByKeyAndWindow( + reduceFunc, invReduceFunc, windowTime, slideTime, defaultPartitioner(numPartitions)) + } + + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time, + partitioner: Partitioner + ): ReducedWindowedDStream[K, V] = { + val cleanedReduceFunc = ssc.sc.clean(reduceFunc) + val cleanedInvReduceFunc = ssc.sc.clean(invReduceFunc) new ReducedWindowedDStream[K, V]( - stream, - ssc.sc.clean(reduceFunc), - ssc.sc.clean(invReduceFunc), - windowTime, - slideTime, - numPartitions) + stream, cleanedReduceFunc, cleanedInvReduceFunc, windowTime, slideTime, partitioner) + } + + // TODO: + // + // + // + // + def updateStateByKey[S <: AnyRef : ClassManifest]( + updateFunc: (Seq[V], S) => S + ): StateDStream[K, V, S] = { + updateStateByKey(updateFunc, defaultPartitioner()) + } + + def updateStateByKey[S <: AnyRef : ClassManifest]( + updateFunc: (Seq[V], S) => S, + numPartitions: Int + ): StateDStream[K, V, S] = { + updateStateByKey(updateFunc, defaultPartitioner(numPartitions)) + } + + def updateStateByKey[S <: AnyRef : ClassManifest]( + updateFunc: (Seq[V], S) => S, + partitioner: Partitioner + ): StateDStream[K, V, S] = { + val func = (iterator: Iterator[(K, Seq[V], S)]) => { + iterator.map(tuple => (tuple._1, updateFunc(tuple._2, tuple._3))) + } + updateStateByKey(func, partitioner, true) + } + + def updateStateByKey[S <: AnyRef : ClassManifest]( + updateFunc: (Iterator[(K, Seq[V], S)]) => Iterator[(K, S)], + partitioner: Partitioner, + rememberPartitioner: Boolean + ): StateDStream[K, V, S] = { + new StateDStream(stream, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner) } } diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala index 896e7dbafb..191d264b2b 100644 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -5,7 +5,7 @@ import spark.streaming.StreamingContext._ import spark.RDD import spark.UnionRDD import spark.CoGroupedRDD -import spark.HashPartitioner +import spark.Partitioner import spark.SparkContext._ import spark.storage.StorageLevel @@ -17,8 +17,8 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( invReduceFunc: (V, V) => V, _windowTime: Time, _slideTime: Time, - numPartitions: Int) -extends DStream[(K,V)](parent.ssc) { + partitioner: Partitioner + ) extends DStream[(K,V)](parent.ssc) { if (!_windowTime.isMultipleOf(parent.slideTime)) throw new Exception("The window duration of ReducedWindowedDStream (" + _slideTime + ") " + @@ -28,7 +28,7 @@ extends DStream[(K,V)](parent.ssc) { throw new Exception("The slide duration of ReducedWindowedDStream (" + _slideTime + ") " + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") - val reducedStream = parent.reduceByKey(reduceFunc, numPartitions) + val reducedStream = parent.reduceByKey(reduceFunc, partitioner) val allowPartialWindows = true //reducedStream.persist(StorageLevel.MEMORY_ONLY_DESER_2) @@ -104,7 +104,7 @@ extends DStream[(K,V)](parent.ssc) { if (reducedRDDs.size == 0) { throw new Exception("Could not generate the first RDD for time " + validTime) } - return Some(new UnionRDD(ssc.sc, reducedRDDs).reduceByKey(reduceFunc, numPartitions)) + return Some(new UnionRDD(ssc.sc, reducedRDDs).reduceByKey(partitioner, reduceFunc)) } } } @@ -137,8 +137,7 @@ extends DStream[(K,V)](parent.ssc) { } t -= reducedStream.slideTime } - - val partitioner = new HashPartitioner(numPartitions) + val allRDDs = new ArrayBuffer[RDD[(_, _)]]() allRDDs += previousWindowRDD allRDDs ++= oldRDDs diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index da50e22719..12e52bf56c 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -41,7 +41,7 @@ extends Logging { def generateRDDs (time: Time) { println("\n-----------------------------------------------------\n") - logInfo("Generating RDDs for time " + time) + logInfo("Generating RDDs for time " + time) outputStreams.foreach(outputStream => { outputStream.generateJob(time) match { case Some(job) => submitJob(job) diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala new file mode 100644 index 0000000000..eabb33d89e --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala @@ -0,0 +1,83 @@ +package spark.streaming + +import spark.RDD +import spark.Partitioner +import spark.MapPartitionsRDD +import spark.SparkContext._ +import javax.annotation.Nullable + + +class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( + parent: DStream[(K, V)], + updateFunc: (Iterator[(K, Seq[V], S)]) => Iterator[(K, S)], + partitioner: Partitioner, + rememberPartitioner: Boolean + ) extends DStream[(K, S)](parent.ssc) { + + class SpecialMapPartitionsRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: Iterator[T] => Iterator[U]) + extends MapPartitionsRDD(prev, f) { + override val partitioner = if (rememberPartitioner) prev.partitioner else None + } + + override def dependencies = List(parent) + + override def slideTime = parent.slideTime + + override def compute(validTime: Time): Option[RDD[(K, S)]] = { + + // Try to get the previous state RDD + getOrCompute(validTime - slideTime) match { + + case Some(prevStateRDD) => { // If previous state RDD exists + + // Define the function for the mapPartition operation on cogrouped RDD; + // first map the cogrouped tuple to tuples of required type, + // and then apply the update function + val func = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { + val i = iterator.map(t => { + (t._1, t._2._1, t._2._2.headOption.getOrElse(null.asInstanceOf[S])) + }) + updateFunc(i) + } + + // Try to get the parent RDD + parent.getOrCompute(validTime) match { + case Some(parentRDD) => { // If parent RDD exists, then compute as usual + val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) + val stateRDD = new SpecialMapPartitionsRDD(cogroupedRDD, func) + logDebug("Generating state RDD for time " + validTime) + return Some(stateRDD) + } + case None => { // If parent RDD does not exist, then return old state RDD + logDebug("Generating state RDD for time " + validTime + " (no change)") + return Some(prevStateRDD) + } + } + } + + case None => { // If previous session RDD does not exist (first input data) + + // Define the function for the mapPartition operation on grouped RDD; + // first map the grouped tuple to tuples of required type, + // and then apply the update function + val func = (iterator: Iterator[(K, Seq[V])]) => { + updateFunc(iterator.map(tuple => (tuple._1, tuple._2, null.asInstanceOf[S]))) + } + + // Try to get the parent RDD + parent.getOrCompute(validTime) match { + case Some(parentRDD) => { // If parent RDD exists, then compute as usual + val groupedRDD = parentRDD.groupByKey(partitioner) + val sessionRDD = new SpecialMapPartitionsRDD(groupedRDD, func) + logDebug("Generating state RDD for time " + validTime + " (first)") + return Some(sessionRDD) + } + case None => { // If parent RDD does not exist, then nothing to do! + logDebug("Not generating state RDD (no previous state, no parent)") + return None + } + } + } + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala index 0a33a05bae..0aa5294a17 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala @@ -14,8 +14,8 @@ object WordCountNetwork { val ssc = new StreamingContext(args(0), "WordCountNetwork") ssc.setBatchDuration(Seconds(2)) - // Create the FileInputDStream on the directory and use the - // stream to count words in new files created + // Create a NetworkInputDStream on target ip:port and count the + // words in input stream of \n delimited test (eg. generated by 'nc') val lines = ssc.createNetworkTextStream(args(1), args(2).toInt) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) diff --git a/streaming/src/test/scala/spark/streaming/DStreamSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamSuite.scala index e9dc377263..d5eb20b37e 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamSuite.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamSuite.scala @@ -1,14 +1,14 @@ package spark.streaming import spark.Logging -import spark.RDD +import spark.streaming.StreamingContext._ import spark.streaming.util.ManualClock import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.SynchronizedQueue +import scala.runtime.RichInt class DStreamSuite extends FunSuite with BeforeAndAfter with Logging { @@ -20,7 +20,9 @@ class DStreamSuite extends FunSuite with BeforeAndAfter with Logging { def testOp[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], operation: DStream[U] => DStream[V], - expectedOutput: Seq[Seq[V]]) { + expectedOutput: Seq[Seq[V]], + useSet: Boolean = false + ) { try { ssc = new StreamingContext("local", "test") ssc.setBatchDuration(Milliseconds(batchDurationMillis)) @@ -33,45 +35,89 @@ class DStreamSuite extends FunSuite with BeforeAndAfter with Logging { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] clock.addToTime(input.size * batchDurationMillis) - Thread.sleep(100) + Thread.sleep(1000) val output = new ArrayBuffer[Seq[V]]() while(outputQueue.size > 0) { val rdd = outputQueue.take() output += (rdd.collect()) } + assert(output.size === expectedOutput.size) for (i <- 0 until output.size) { - assert(output(i).toList === expectedOutput(i).toList) + if (useSet) { + assert(output(i).toSet === expectedOutput(i).toSet) + } else { + assert(output(i).toList === expectedOutput(i).toList) + } } } finally { ssc.stop() } } - - test("basic operations") { - val inputData = Array(1 to 4, 5 to 8, 9 to 12) + + test("map-like operations") { + val inputData = Seq(1 to 4, 5 to 8, 9 to 12) // map testOp(inputData, (r: DStream[Int]) => r.map(_.toString), inputData.map(_.map(_.toString))) // flatMap - testOp(inputData, (r: DStream[Int]) => r.flatMap(x => Array(x, x * 2)), - inputData.map(_.flatMap(x => Array(x, x * 2))) + testOp( + inputData, + (r: DStream[Int]) => r.flatMap(x => Seq(x, x * 2)), + inputData.map(_.flatMap(x => Array(x, x * 2))) ) } -} -object DStreamSuite { - def main(args: Array[String]) { - try { - val r = new DStreamSuite() - val inputData = Array(1 to 4, 5 to 8, 9 to 12) - r.testOp(inputData, (r: DStream[Int]) => r.map(_.toString), inputData.map(_.map(_.toString))) - - } catch { - case e: Exception => e.printStackTrace() + test("shuffle-based operations") { + // reduceByKey + testOp( + Seq(Seq("a", "a", "b"), Seq("", ""), Seq()), + (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), + Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + true + ) + + // reduce + testOp( + Seq(1 to 4, 5 to 8, 9 to 12), + (s: DStream[Int]) => s.reduce(_ + _), + Seq(Seq(10), Seq(26), Seq(42)) + ) + } + + test("window-based operations") { + + } + + + test("stateful operations") { + val inputData = + Seq( + Seq("a", "b", "c"), + Seq("a", "b", "c"), + Seq("a", "b", "c") + ) + + val outputData = + Seq( + Seq(("a", 1), ("b", 1), ("c", 1)), + Seq(("a", 2), ("b", 2), ("c", 2)), + Seq(("a", 3), ("b", 3), ("c", 3)) + )//.map(array => array.toSeq.map(x => (x._1, new RichInt(x._2)))) + + val updateStateOp =(s: DStream[String]) => { + val updateFunc = (values: Seq[Int], state: RichInt) => { + var newState = 0 + if (values != null) newState += values.reduce(_ + _) + if (state != null) newState += state.self + //println("values = " + values + ", state = " + state + ", " + " new state = " + newState) + new RichInt(newState) + } + s.map(x => (x, 1)).updateStateByKey[RichInt](updateFunc).map(t => (t._1, t._2.self)) } - System.exit(0) + + testOp(inputData, updateStateOp, outputData, true) } -} \ No newline at end of file +} -- cgit v1.2.3 From f92d4a6ac1f349fdfc88d4ef3b122ae32385b17e Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 31 Aug 2012 15:33:48 -0700 Subject: Better output messages for streaming job duration --- streaming/src/main/scala/spark/streaming/Job.scala | 2 +- streaming/src/main/scala/spark/streaming/JobManager.scala | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/Job.scala b/streaming/src/main/scala/spark/streaming/Job.scala index 0bd8343b9a..0bcb6fd8dc 100644 --- a/streaming/src/main/scala/spark/streaming/Job.scala +++ b/streaming/src/main/scala/spark/streaming/Job.scala @@ -8,7 +8,7 @@ class Job(val time: Time, func: () => _) { val startTime = System.currentTimeMillis func() val stopTime = System.currentTimeMillis - (startTime - stopTime) + (stopTime - startTime) } override def toString = "streaming job " + id + " @ " + time diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index 2a4fe3dd11..8a652cbfca 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -12,10 +12,8 @@ class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { SparkEnv.set(ssc.env) try { val timeTaken = job.run() - logInfo( - "Running " + job + " took " + timeTaken + " ms, " + - "total delay was " + (System.currentTimeMillis - job.time) + " ms" - ) + logInfo("Total delay: %.4f s for job %s; execution was %.4f s".format( + System.currentTimeMillis() - job.time, timeTaken)) } catch { case e: Exception => logError("Running " + job + " failed", e) -- cgit v1.2.3 From ce42a463750f37d9c1c6a2c982d3b947039233cd Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 31 Aug 2012 15:35:35 -0700 Subject: Bug fix --- streaming/src/main/scala/spark/streaming/JobManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index 8a652cbfca..953493352c 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -13,7 +13,7 @@ class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { try { val timeTaken = job.run() logInfo("Total delay: %.4f s for job %s; execution was %.4f s".format( - System.currentTimeMillis() - job.time, timeTaken)) + (System.currentTimeMillis() - job.time) / 1000.0, timeTaken / 1000.0)) } catch { case e: Exception => logError("Running " + job + " failed", e) -- cgit v1.2.3 From 51fb13dd164c29e6bf97c2e3642b07a7f416ddaa Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 31 Aug 2012 15:36:11 -0700 Subject: Bug fix --- streaming/src/main/scala/spark/streaming/JobManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index 953493352c..40e614b4ed 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -13,7 +13,7 @@ class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { try { val timeTaken = job.run() logInfo("Total delay: %.4f s for job %s; execution was %.4f s".format( - (System.currentTimeMillis() - job.time) / 1000.0, timeTaken / 1000.0)) + (System.currentTimeMillis() - job.time) / 1000.0, job.id, timeTaken / 1000.0)) } catch { case e: Exception => logError("Running " + job + " failed", e) -- cgit v1.2.3 From c42e7ac2822f697a355650a70379d9e4ce2022c0 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 1 Sep 2012 04:31:11 +0000 Subject: More block manager fixes --- .../scala/spark/storage/BlockManagerWorker.scala | 2 +- core/src/main/scala/spark/storage/BlockStore.scala | 30 ++++++++++------------ 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala index d74cdb38a8..0658a57187 100644 --- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala @@ -73,7 +73,7 @@ class BlockManagerWorker(val blockManager: BlockManager) extends Logging { logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes) blockManager.putBytes(id, bytes, level) logDebug("PutBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) - + " with data size: " + bytes.array().length) + + " with data size: " + bytes.limit) } private def getBlock(id: String): ByteBuffer = { diff --git a/core/src/main/scala/spark/storage/BlockStore.scala b/core/src/main/scala/spark/storage/BlockStore.scala index 17f4f51aa8..77e0ed84c5 100644 --- a/core/src/main/scala/spark/storage/BlockStore.scala +++ b/core/src/main/scala/spark/storage/BlockStore.scala @@ -76,11 +76,11 @@ class MemoryStore(blockManager: BlockManager, maxMemory: Long) currentMemory += sizeEstimate logDebug("Block " + blockId + " stored as values to memory") } else { - val entry = new Entry(bytes, bytes.array().length, false) - ensureFreeSpace(bytes.array.length) + val entry = new Entry(bytes, bytes.limit, false) + ensureFreeSpace(bytes.limit) memoryStore.synchronized { memoryStore.put(blockId, entry) } - currentMemory += bytes.array().length - logDebug("Block " + blockId + " stored as " + bytes.array().length + " bytes to memory") + currentMemory += bytes.limit + logDebug("Block " + blockId + " stored as " + bytes.limit + " bytes to memory") } } @@ -97,11 +97,11 @@ class MemoryStore(blockManager: BlockManager, maxMemory: Long) return Left(elements.iterator) } else { val bytes = dataSerialize(values) - ensureFreeSpace(bytes.array().length) - val entry = new Entry(bytes, bytes.array().length, false) + ensureFreeSpace(bytes.limit) + val entry = new Entry(bytes, bytes.limit, false) memoryStore.synchronized { memoryStore.put(blockId, entry) } - currentMemory += bytes.array().length - logDebug("Block " + blockId + " stored as " + bytes.array.length + " bytes to memory") + currentMemory += bytes.limit + logDebug("Block " + blockId + " stored as " + bytes.limit + " bytes to memory") return Right(bytes) } } @@ -118,7 +118,7 @@ class MemoryStore(blockManager: BlockManager, maxMemory: Long) if (entry.deserialized) { return Some(entry.value.asInstanceOf[ArrayBuffer[Any]].toIterator) } else { - return Some(dataDeserialize(entry.value.asInstanceOf[ByteBuffer])) + return Some(dataDeserialize(entry.value.asInstanceOf[ByteBuffer].duplicate())) } } @@ -199,11 +199,11 @@ class DiskStore(blockManager: BlockManager, rootDirs: String) val file = createFile(blockId) if (file != null) { val channel = new RandomAccessFile(file, "rw").getChannel() - val buffer = channel.map(MapMode.READ_WRITE, 0, bytes.array.length) - buffer.put(bytes.array) + val buffer = channel.map(MapMode.READ_WRITE, 0, bytes.limit) + buffer.put(bytes) channel.close() val finishTime = System.currentTimeMillis - logDebug("Block " + blockId + " stored to file of " + bytes.array.length + " bytes to disk in " + (finishTime - startTime) + " ms") + logDebug("Block " + blockId + " stored to file of " + bytes.limit + " bytes to disk in " + (finishTime - startTime) + " ms") } else { logError("File not created for block " + blockId) } @@ -211,7 +211,7 @@ class DiskStore(blockManager: BlockManager, rootDirs: String) def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] = { val bytes = dataSerialize(values) - logDebug("Converted block " + blockId + " to " + bytes.array.length + " bytes") + logDebug("Converted block " + blockId + " to " + bytes.limit + " bytes") putBytes(blockId, bytes, level) return Right(bytes) } @@ -220,9 +220,7 @@ class DiskStore(blockManager: BlockManager, rootDirs: String) val file = getFile(blockId) val length = file.length().toInt val channel = new RandomAccessFile(file, "r").getChannel() - val bytes = ByteBuffer.allocate(length) - bytes.put(channel.map(MapMode.READ_WRITE, 0, length)) - return Some(bytes) + Some(channel.map(MapMode.READ_WRITE, 0, length)) } def getValues(blockId: String): Option[Iterator[Any]] = { -- cgit v1.2.3 From 44758aa8e2337364610ee80fa9ec913301712078 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 1 Sep 2012 00:17:59 -0700 Subject: First work towards a RawInputDStream and a sender program for it. --- .../src/main/scala/spark/DaemonThreadFactory.scala | 12 +- .../scala/spark/util/RateLimitedOutputStream.scala | 56 +++++ .../spark/streaming/NetworkInputDStream.scala | 33 +-- .../spark/streaming/NetworkInputReceiver.scala | 248 --------------------- .../streaming/NetworkInputReceiverMessage.scala | 7 + .../spark/streaming/NetworkInputTracker.scala | 7 +- .../scala/spark/streaming/ObjectInputDStream.scala | 16 ++ .../spark/streaming/ObjectInputReceiver.scala | 244 ++++++++++++++++++++ .../scala/spark/streaming/RawInputDStream.scala | 114 ++++++++++ .../scala/spark/streaming/StreamingContext.scala | 8 +- .../scala/spark/streaming/util/RawTextSender.scala | 51 +++++ 11 files changed, 512 insertions(+), 284 deletions(-) create mode 100644 core/src/main/scala/spark/util/RateLimitedOutputStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/NetworkInputReceiver.scala create mode 100644 streaming/src/main/scala/spark/streaming/NetworkInputReceiverMessage.scala create mode 100644 streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/ObjectInputReceiver.scala create mode 100644 streaming/src/main/scala/spark/streaming/RawInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/RawTextSender.scala diff --git a/core/src/main/scala/spark/DaemonThreadFactory.scala b/core/src/main/scala/spark/DaemonThreadFactory.scala index 003880c5e8..56e59adeb7 100644 --- a/core/src/main/scala/spark/DaemonThreadFactory.scala +++ b/core/src/main/scala/spark/DaemonThreadFactory.scala @@ -6,9 +6,13 @@ import java.util.concurrent.ThreadFactory * A ThreadFactory that creates daemon threads */ private object DaemonThreadFactory extends ThreadFactory { - override def newThread(r: Runnable): Thread = { - val t = new Thread(r) - t.setDaemon(true) - return t + override def newThread(r: Runnable): Thread = new DaemonThread(r) +} + +private class DaemonThread(r: Runnable = null) extends Thread { + override def run() { + if (r != null) { + r.run() + } } } \ No newline at end of file diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala new file mode 100644 index 0000000000..10f2272707 --- /dev/null +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -0,0 +1,56 @@ +package spark.util + +import java.io.OutputStream + +class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream { + var lastSyncTime = System.nanoTime() + var bytesWrittenSinceSync: Long = 0 + + override def write(b: Int) { + waitToWrite(1) + out.write(b) + } + + override def write(bytes: Array[Byte]) { + write(bytes, 0, bytes.length) + } + + override def write(bytes: Array[Byte], offset: Int, length: Int) { + val CHUNK_SIZE = 8192 + var pos = 0 + while (pos < length) { + val writeSize = math.min(length - pos, CHUNK_SIZE) + waitToWrite(writeSize) + out.write(bytes, offset + pos, length - pos) + pos += writeSize + } + } + + def waitToWrite(numBytes: Int) { + while (true) { + val now = System.nanoTime() + val elapsed = math.max(now - lastSyncTime, 1) + val rate = bytesWrittenSinceSync.toDouble / (elapsed / 1.0e9) + if (rate < bytesPerSec) { + // It's okay to write; just update some variables and return + bytesWrittenSinceSync += numBytes + if (now > lastSyncTime + (1e10).toLong) { + // Ten seconds have passed since lastSyncTime; let's resync + lastSyncTime = now + bytesWrittenSinceSync = numBytes + } + return + } else { + Thread.sleep(5) + } + } + } + + override def flush() { + out.flush() + } + + override def close() { + out.close() + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala index ee09324c8c..bf83f98ec4 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala @@ -1,36 +1,23 @@ package spark.streaming -import akka.actor._ -import akka.pattern.ask -import akka.util.duration._ -import akka.dispatch._ - import spark.RDD import spark.BlockRDD -import spark.Logging - -import java.io.InputStream +abstract class NetworkInputDStream[T: ClassManifest](@transient ssc: StreamingContext) + extends InputDStream[T](ssc) { -class NetworkInputDStream[T: ClassManifest]( - @transient ssc: StreamingContext, - val host: String, - val port: Int, - val bytesToObjects: InputStream => Iterator[T] - ) extends InputDStream[T](ssc) with Logging { - val id = ssc.getNewNetworkStreamId() - def start() { } + def start() {} - def stop() { } + def stop() {} override def compute(validTime: Time): Option[RDD[T]] = { val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) - return Some(new BlockRDD[T](ssc.sc, blockIds)) + Some(new BlockRDD[T](ssc.sc, blockIds)) } - - def createReceiver(): NetworkInputReceiver[T] = { - new NetworkInputReceiver(id, host, port, bytesToObjects) - } -} \ No newline at end of file + + /** Called on workers to run a receiver for the stream. */ + def runReceiver(): Unit +} + diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputReceiver.scala b/streaming/src/main/scala/spark/streaming/NetworkInputReceiver.scala deleted file mode 100644 index 7add6246b7..0000000000 --- a/streaming/src/main/scala/spark/streaming/NetworkInputReceiver.scala +++ /dev/null @@ -1,248 +0,0 @@ -package spark.streaming - -import spark.Logging -import spark.storage.BlockManager -import spark.storage.StorageLevel -import spark.SparkEnv -import spark.streaming.util.SystemClock -import spark.streaming.util.RecurringTimer - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.Queue -import scala.collection.mutable.SynchronizedPriorityQueue -import scala.math.Ordering - -import java.net.InetSocketAddress -import java.net.Socket -import java.io.InputStream -import java.io.BufferedInputStream -import java.io.DataInputStream -import java.io.EOFException -import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.ArrayBlockingQueue - -import akka.actor._ -import akka.pattern.ask -import akka.util.duration._ -import akka.dispatch._ - -trait NetworkInputReceiverMessage -case class GetBlockIds(time: Long) extends NetworkInputReceiverMessage -case class GotBlockIds(streamId: Int, blocksIds: Array[String]) extends NetworkInputReceiverMessage -case class StopReceiver() extends NetworkInputReceiverMessage - -class NetworkInputReceiver[T: ClassManifest](streamId: Int, host: String, port: Int, bytesToObjects: InputStream => Iterator[T]) -extends Logging { - - class ReceiverActor extends Actor { - override def preStart() = { - logInfo("Attempting to register") - val ip = System.getProperty("spark.master.host", "localhost") - val port = System.getProperty("spark.master.port", "7077").toInt - val actorName: String = "NetworkInputTracker" - val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) - val trackerActor = env.actorSystem.actorFor(url) - val timeout = 100.milliseconds - val future = trackerActor.ask(RegisterReceiver(streamId, self))(timeout) - Await.result(future, timeout) - } - - def receive = { - case GetBlockIds(time) => { - logInfo("Got request for block ids for " + time) - sender ! GotBlockIds(streamId, dataHandler.getPushedBlocks()) - } - - case StopReceiver() => { - if (receivingThread != null) { - receivingThread.interrupt() - } - sender ! true - } - } - } - - class DataHandler { - - class Block(val time: Long, val iterator: Iterator[T]) { - val blockId = "input-" + streamId + "-" + time - var pushed = true - override def toString() = "input block " + blockId - } - - val clock = new SystemClock() - val blockInterval = 200L - val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) - val blockOrdering = new Ordering[Block] { - def compare(b1: Block, b2: Block) = (b1.time - b2.time).toInt - } - val blockStorageLevel = StorageLevel.DISK_AND_MEMORY - val blocksForPushing = new ArrayBlockingQueue[Block](1000) - val blocksForReporting = new SynchronizedPriorityQueue[Block]()(blockOrdering) - val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } - - var currentBuffer = new ArrayBuffer[T] - - def start() { - blockIntervalTimer.start() - blockPushingThread.start() - logInfo("Data handler started") - } - - def stop() { - blockIntervalTimer.stop() - blockPushingThread.interrupt() - } - - def += (obj: T) { - currentBuffer += obj - } - - def updateCurrentBuffer(time: Long) { - val newBlockBuffer = currentBuffer - currentBuffer = new ArrayBuffer[T] - if (newBlockBuffer.size > 0) { - val newBlock = new Block(time - blockInterval, newBlockBuffer.toIterator) - blocksForPushing.add(newBlock) - blocksForReporting.enqueue(newBlock) - } - } - - def keepPushingBlocks() { - logInfo("Block pushing thread started") - try { - while(true) { - val block = blocksForPushing.take() - if (blockManager != null) { - blockManager.put(block.blockId, block.iterator, blockStorageLevel) - block.pushed = true - } else { - logWarning(block + " not put as block manager is null") - } - } - } catch { - case ie: InterruptedException => println("Block pushing thread interrupted") - case e: Exception => e.printStackTrace() - } - } - - def getPushedBlocks(): Array[String] = { - val pushedBlocks = new ArrayBuffer[String]() - var loop = true - while(loop && !blocksForReporting.isEmpty) { - val block = blocksForReporting.dequeue() - if (block == null) { - loop = false - } else if (!block.pushed) { - blocksForReporting.enqueue(block) - } else { - pushedBlocks += block.blockId - } - } - logInfo("Got " + pushedBlocks.size + " blocks") - pushedBlocks.toArray - } - } - - val blockManager = if (SparkEnv.get != null) SparkEnv.get.blockManager else null - val dataHandler = new DataHandler() - val env = SparkEnv.get - - var receiverActor: ActorRef = null - var receivingThread: Thread = null - - def run() { - initLogging() - var socket: Socket = null - try { - if (SparkEnv.get != null) { - receiverActor = SparkEnv.get.actorSystem.actorOf(Props(new ReceiverActor), "ReceiverActor-" + streamId) - } - dataHandler.start() - socket = connect() - receivingThread = Thread.currentThread() - receive(socket) - } catch { - case ie: InterruptedException => logInfo("Receiver interrupted") - } finally { - receivingThread = null - if (socket != null) socket.close() - dataHandler.stop() - } - } - - def connect(): Socket = { - logInfo("Connecting to " + host + ":" + port) - val socket = new Socket(host, port) - logInfo("Connected to " + host + ":" + port) - socket - } - - def receive(socket: Socket) { - val iterator = bytesToObjects(socket.getInputStream()) - while(iterator.hasNext) { - val obj = iterator.next - dataHandler += obj - } - } -} - - -object NetworkInputReceiver { - - def bytesToLines(inputStream: InputStream): Iterator[String] = { - val bufferedInputStream = new BufferedInputStream(inputStream) - val dataInputStream = new DataInputStream(bufferedInputStream) - - val iterator = new Iterator[String] { - var gotNext = false - var finished = false - var nextValue: String = null - - private def getNext() { - try { - nextValue = dataInputStream.readLine() - println("[" + nextValue + "]") - } catch { - case eof: EOFException => - finished = true - } - gotNext = true - } - - override def hasNext: Boolean = { - if (!gotNext) { - getNext() - } - if (finished) { - dataInputStream.close() - } - !finished - } - - - override def next(): String = { - if (!gotNext) { - getNext() - } - if (finished) { - throw new NoSuchElementException("End of stream") - } - gotNext = false - nextValue - } - } - iterator - } - - def main(args: Array[String]) { - if (args.length < 2) { - println("NetworkReceiver ") - System.exit(1) - } - val host = args(0) - val port = args(1).toInt - val receiver = new NetworkInputReceiver(0, host, port, bytesToLines) - receiver.run() - } -} diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputReceiverMessage.scala b/streaming/src/main/scala/spark/streaming/NetworkInputReceiverMessage.scala new file mode 100644 index 0000000000..deaffe98c8 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/NetworkInputReceiverMessage.scala @@ -0,0 +1,7 @@ +package spark.streaming + +sealed trait NetworkInputReceiverMessage + +case class GetBlockIds(time: Long) extends NetworkInputReceiverMessage +case class GotBlockIds(streamId: Int, blocksIds: Array[String]) extends NetworkInputReceiverMessage +case object StopReceiver extends NetworkInputReceiverMessage diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index 07758665c9..acf97c1883 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -52,9 +52,7 @@ extends Logging { if (!iterator.hasNext) { throw new Exception("Could not start receiver as details not found.") } - val stream = iterator.next - val receiver = stream.createReceiver() - receiver.run() + iterator.next().runReceiver() } ssc.sc.runJob(tempRDD, startReceiver) @@ -62,8 +60,7 @@ extends Logging { def stopReceivers() { implicit val ec = env.actorSystem.dispatcher - val message = new StopReceiver() - val listOfFutures = receiverInfo.values.map(_.ask(message)(timeout)).toList + val listOfFutures = receiverInfo.values.map(_.ask(StopReceiver)(timeout)).toList val futureOfList = Future.sequence(listOfFutures) Await.result(futureOfList, timeout) } diff --git a/streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala b/streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala new file mode 100644 index 0000000000..2396b374a0 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala @@ -0,0 +1,16 @@ +package spark.streaming + +import java.io.InputStream + +class ObjectInputDStream[T: ClassManifest]( + @transient ssc: StreamingContext, + val host: String, + val port: Int, + val bytesToObjects: InputStream => Iterator[T]) + extends NetworkInputDStream[T](ssc) { + + override def runReceiver() { + new ObjectInputReceiver(id, host, port, bytesToObjects).run() + } +} + diff --git a/streaming/src/main/scala/spark/streaming/ObjectInputReceiver.scala b/streaming/src/main/scala/spark/streaming/ObjectInputReceiver.scala new file mode 100644 index 0000000000..70fa2cdf07 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/ObjectInputReceiver.scala @@ -0,0 +1,244 @@ +package spark.streaming + +import spark.Logging +import spark.storage.BlockManager +import spark.storage.StorageLevel +import spark.SparkEnv +import spark.streaming.util.SystemClock +import spark.streaming.util.RecurringTimer + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.Queue +import scala.collection.mutable.SynchronizedPriorityQueue +import scala.math.Ordering + +import java.net.InetSocketAddress +import java.net.Socket +import java.io.InputStream +import java.io.BufferedInputStream +import java.io.DataInputStream +import java.io.EOFException +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.ArrayBlockingQueue + +import akka.actor._ +import akka.pattern.ask +import akka.util.duration._ +import akka.dispatch._ + +class ObjectInputReceiver[T: ClassManifest]( + streamId: Int, + host: String, + port: Int, + bytesToObjects: InputStream => Iterator[T]) + extends Logging { + + class ReceiverActor extends Actor { + override def preStart() { + logInfo("Attempting to register") + val ip = System.getProperty("spark.master.host", "localhost") + val port = System.getProperty("spark.master.port", "7077").toInt + val actorName: String = "NetworkInputTracker" + val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) + val trackerActor = env.actorSystem.actorFor(url) + val timeout = 1.seconds + val future = trackerActor.ask(RegisterReceiver(streamId, self))(timeout) + Await.result(future, timeout) + } + + def receive = { + case GetBlockIds(time) => { + logInfo("Got request for block ids for " + time) + sender ! GotBlockIds(streamId, dataHandler.getPushedBlocks()) + } + + case StopReceiver => { + if (receivingThread != null) { + receivingThread.interrupt() + } + sender ! true + } + } + } + + class DataHandler { + class Block(val time: Long, val iterator: Iterator[T]) { + val blockId = "input-" + streamId + "-" + time + var pushed = true + override def toString = "input block " + blockId + } + + val clock = new SystemClock() + val blockInterval = 200L + val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) + val blockOrdering = new Ordering[Block] { + def compare(b1: Block, b2: Block) = (b1.time - b2.time).toInt + } + val blockStorageLevel = StorageLevel.DISK_AND_MEMORY + val blocksForPushing = new ArrayBlockingQueue[Block](1000) + val blocksForReporting = new SynchronizedPriorityQueue[Block]()(blockOrdering) + val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } + + var currentBuffer = new ArrayBuffer[T] + + def start() { + blockIntervalTimer.start() + blockPushingThread.start() + logInfo("Data handler started") + } + + def stop() { + blockIntervalTimer.stop() + blockPushingThread.interrupt() + } + + def += (obj: T) { + currentBuffer += obj + } + + def updateCurrentBuffer(time: Long) { + val newBlockBuffer = currentBuffer + currentBuffer = new ArrayBuffer[T] + if (newBlockBuffer.size > 0) { + val newBlock = new Block(time - blockInterval, newBlockBuffer.toIterator) + blocksForPushing.add(newBlock) + blocksForReporting.enqueue(newBlock) + } + } + + def keepPushingBlocks() { + logInfo("Block pushing thread started") + try { + while(true) { + val block = blocksForPushing.take() + if (blockManager != null) { + blockManager.put(block.blockId, block.iterator, blockStorageLevel) + block.pushed = true + } else { + logWarning(block + " not put as block manager is null") + } + } + } catch { + case ie: InterruptedException => println("Block pushing thread interrupted") + case e: Exception => e.printStackTrace() + } + } + + def getPushedBlocks(): Array[String] = { + val pushedBlocks = new ArrayBuffer[String]() + var loop = true + while(loop && !blocksForReporting.isEmpty) { + val block = blocksForReporting.dequeue() + if (block == null) { + loop = false + } else if (!block.pushed) { + blocksForReporting.enqueue(block) + } else { + pushedBlocks += block.blockId + } + } + logInfo("Got " + pushedBlocks.size + " blocks") + pushedBlocks.toArray + } + } + + val blockManager = if (SparkEnv.get != null) SparkEnv.get.blockManager else null + val dataHandler = new DataHandler() + val env = SparkEnv.get + + var receiverActor: ActorRef = null + var receivingThread: Thread = null + + def run() { + initLogging() + var socket: Socket = null + try { + if (SparkEnv.get != null) { + receiverActor = SparkEnv.get.actorSystem.actorOf(Props(new ReceiverActor), "ReceiverActor-" + streamId) + } + dataHandler.start() + socket = connect() + receivingThread = Thread.currentThread() + receive(socket) + } catch { + case ie: InterruptedException => logInfo("Receiver interrupted") + } finally { + receivingThread = null + if (socket != null) socket.close() + dataHandler.stop() + } + } + + def connect(): Socket = { + logInfo("Connecting to " + host + ":" + port) + val socket = new Socket(host, port) + logInfo("Connected to " + host + ":" + port) + socket + } + + def receive(socket: Socket) { + val iterator = bytesToObjects(socket.getInputStream()) + while(iterator.hasNext) { + val obj = iterator.next + dataHandler += obj + } + } +} + + +object ObjectInputReceiver { + def bytesToLines(inputStream: InputStream): Iterator[String] = { + val bufferedInputStream = new BufferedInputStream(inputStream) + val dataInputStream = new DataInputStream(bufferedInputStream) + + val iterator = new Iterator[String] { + var gotNext = false + var finished = false + var nextValue: String = null + + private def getNext() { + try { + nextValue = dataInputStream.readLine() + println("[" + nextValue + "]") + } catch { + case eof: EOFException => + finished = true + } + gotNext = true + } + + override def hasNext: Boolean = { + if (!gotNext) { + getNext() + } + if (finished) { + dataInputStream.close() + } + !finished + } + + override def next(): String = { + if (!gotNext) { + getNext() + } + if (finished) { + throw new NoSuchElementException("End of stream") + } + gotNext = false + nextValue + } + } + iterator + } + + def main(args: Array[String]) { + if (args.length < 2) { + println("ObjectInputReceiver ") + System.exit(1) + } + val host = args(0) + val port = args(1).toInt + val receiver = new ObjectInputReceiver(0, host, port, bytesToLines) + receiver.run() + } +} diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala new file mode 100644 index 0000000000..49e4781e75 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala @@ -0,0 +1,114 @@ +package spark.streaming + +import akka.actor._ +import akka.pattern.ask +import akka.util.duration._ +import akka.dispatch._ +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.nio.channels.{ReadableByteChannel, SocketChannel} +import java.io.EOFException +import java.util.concurrent.ArrayBlockingQueue +import scala.collection.mutable.ArrayBuffer +import spark.{DaemonThread, Logging, SparkEnv} +import spark.storage.StorageLevel + +/** + * An input stream that reads blocks of serialized objects from a given network address. + * The blocks will be inserted directly into the block store. This is the fastest way to get + * data into Spark Streaming, though it requires the sender to batch data and serialize it + * in the format that the system is configured with. + */ +class RawInputDStream[T: ClassManifest]( + @transient ssc: StreamingContext, + host: String, + port: Int) + extends NetworkInputDStream[T](ssc) with Logging { + + val streamId = id + + /** Called on workers to run a receiver for the stream. */ + def runReceiver() { + val env = SparkEnv.get + val actor = env.actorSystem.actorOf( + Props(new ReceiverActor(env, Thread.currentThread)), "ReceiverActor-" + streamId) + + // Open a socket to the target address and keep reading from it + logInfo("Connecting to " + host + ":" + port) + val channel = SocketChannel.open() + channel.configureBlocking(true) + channel.connect(new InetSocketAddress(host, port)) + logInfo("Connected to " + host + ":" + port) + + val queue = new ArrayBlockingQueue[ByteBuffer](2) + + new DaemonThread { + override def run() { + var nextBlockNumber = 0 + while (true) { + val buffer = queue.take() + val blockId = "input-" + streamId + "-" + nextBlockNumber + nextBlockNumber += 1 + env.blockManager.putBytes(blockId, buffer, StorageLevel.MEMORY_ONLY_2) + actor ! BlockPublished(blockId) + } + } + }.start() + + val lengthBuffer = ByteBuffer.allocate(4) + while (true) { + lengthBuffer.clear() + readFully(channel, lengthBuffer) + lengthBuffer.flip() + val length = lengthBuffer.getInt() + val dataBuffer = ByteBuffer.allocate(length) + readFully(channel, dataBuffer) + dataBuffer.flip() + logInfo("Read a block with " + length + " bytes") + queue.put(dataBuffer) + } + } + + /** Read a buffer fully from a given Channel */ + private def readFully(channel: ReadableByteChannel, dest: ByteBuffer) { + while (dest.position < dest.limit) { + if (channel.read(dest) == -1) { + throw new EOFException("End of channel") + } + } + } + + /** Message sent to ReceiverActor to tell it that a block was published */ + case class BlockPublished(blockId: String) {} + + /** A helper actor that communicates with the NetworkInputTracker */ + private class ReceiverActor(env: SparkEnv, receivingThread: Thread) extends Actor { + val newBlocks = new ArrayBuffer[String] + + override def preStart() { + logInfo("Attempting to register with tracker") + val ip = System.getProperty("spark.master.host", "localhost") + val port = System.getProperty("spark.master.port", "7077").toInt + val actorName: String = "NetworkInputTracker" + val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) + val trackerActor = env.actorSystem.actorFor(url) + val timeout = 1.seconds + val future = trackerActor.ask(RegisterReceiver(streamId, self))(timeout) + Await.result(future, timeout) + } + + override def receive = { + case BlockPublished(blockId) => + newBlocks += blockId + + case GetBlockIds(time) => + logInfo("Got request for block IDs for " + time) + sender ! GotBlockIds(streamId, newBlocks.toArray) + newBlocks.clear() + + case StopReceiver => + receivingThread.interrupt() + sender ! true + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 0ac86cbdf2..feb769e036 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -52,22 +52,22 @@ class StreamingContext ( private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() def createNetworkTextStream(hostname: String, port: Int): DStream[String] = { - createNetworkStream[String](hostname, port, NetworkInputReceiver.bytesToLines) + createNetworkObjectStream[String](hostname, port, ObjectInputReceiver.bytesToLines) } - def createNetworkStream[T: ClassManifest]( + def createNetworkObjectStream[T: ClassManifest]( hostname: String, port: Int, converter: (InputStream) => Iterator[T] ): DStream[T] = { - val inputStream = new NetworkInputDStream[T](this, hostname, port, converter) + val inputStream = new ObjectInputDStream[T](this, hostname, port, converter) inputStreams += inputStream inputStream } /* def createHttpTextStream(url: String): DStream[String] = { - createHttpStream(url, NetworkInputReceiver.bytesToLines) + createHttpStream(url, ObjectInputReceiver.bytesToLines) } def createHttpStream[T: ClassManifest]( diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala new file mode 100644 index 0000000000..60d5849d71 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala @@ -0,0 +1,51 @@ +package spark.streaming.util + +import spark.util.{RateLimitedOutputStream, IntParam} +import java.net.ServerSocket +import spark.{Logging, KryoSerializer} +import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream +import io.Source +import java.io.IOException + +/** + * A helper program that sends blocks of Kryo-serialized text strings out on a socket at a + * specified rate. Used to feed data into RawInputDStream. + */ +object RawTextSender extends Logging { + def main(args: Array[String]) { + if (args.length != 4) { + System.err.println("Usage: RawTextSender ") + } + // Parse the arguments using a pattern match + val Array(IntParam(port), file, IntParam(blockSize), IntParam(bytesPerSec)) = args + + // Repeat the input data multiple times to fill in a buffer + val lines = Source.fromFile(file).getLines().toArray + val bufferStream = new FastByteArrayOutputStream(blockSize + 1000) + val ser = new KryoSerializer().newInstance() + val serStream = ser.serializeStream(bufferStream) + var i = 0 + while (bufferStream.position < blockSize) { + serStream.writeObject(lines(i)) + i = (i + 1) % lines.length + } + bufferStream.trim() + val array = bufferStream.array + + val serverSocket = new ServerSocket(port) + + while (true) { + val socket = serverSocket.accept() + val out = new RateLimitedOutputStream(socket.getOutputStream, bytesPerSec) + try { + while (true) { + out.write(array) + } + } catch { + case e: IOException => + logError("Socket closed: ", e) + socket.close() + } + } + } +} -- cgit v1.2.3 From f84d2bbe55aaf3ef7a6631b9018a573aa5729ff7 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 1 Sep 2012 00:31:15 -0700 Subject: Bug fixes to RateLimitedOutputStream --- core/src/main/scala/spark/util/RateLimitedOutputStream.scala | 2 +- streaming/src/main/scala/spark/streaming/util/RawTextSender.scala | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala index 10f2272707..d11ed163ce 100644 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -21,7 +21,7 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu while (pos < length) { val writeSize = math.min(length - pos, CHUNK_SIZE) waitToWrite(writeSize) - out.write(bytes, offset + pos, length - pos) + out.write(bytes, offset + pos, writeSize) pos += writeSize } } diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala index 60d5849d71..85927c02ec 100644 --- a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala @@ -15,6 +15,7 @@ object RawTextSender extends Logging { def main(args: Array[String]) { if (args.length != 4) { System.err.println("Usage: RawTextSender ") + System.exit(1) } // Parse the arguments using a pattern match val Array(IntParam(port), file, IntParam(blockSize), IntParam(bytesPerSec)) = args @@ -36,6 +37,7 @@ object RawTextSender extends Logging { while (true) { val socket = serverSocket.accept() + logInfo("Got a new connection") val out = new RateLimitedOutputStream(socket.getOutputStream, bytesPerSec) try { while (true) { @@ -43,7 +45,7 @@ object RawTextSender extends Logging { } } catch { case e: IOException => - logError("Socket closed: ", e) + logError("Socket closed", e) socket.close() } } -- cgit v1.2.3 From 83dad56334e73c477e9b62715df14b0c798220e3 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 1 Sep 2012 19:45:25 +0000 Subject: Further fixes to raw text sender, plus an app that uses it --- .../src/main/scala/spark/streaming/DStream.scala | 2 +- .../main/scala/spark/streaming/JobManager.scala | 2 +- .../scala/spark/streaming/RawInputDStream.scala | 5 ++-- .../scala/spark/streaming/StreamingContext.scala | 11 ++++++++ .../scala/spark/streaming/examples/CountRaw.scala | 32 ++++++++++++++++++++++ .../spark/streaming/examples/WordCount2.scala | 2 +- .../scala/spark/streaming/util/RawTextSender.scala | 7 +++++ 7 files changed, 56 insertions(+), 5 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/examples/CountRaw.scala diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 3a57488f9b..74140ab2b8 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -176,7 +176,7 @@ extends Logging with Serializable { def mapPartitions[U: ClassManifest](mapPartFunc: Iterator[T] => Iterator[U]) = new MapPartitionedDStream(this, ssc.sc.clean(mapPartFunc)) - def reduce(reduceFunc: (T, T) => T) = this.map(x => (1, x)).reduceByKey(reduceFunc, 1).map(_._2) + def reduce(reduceFunc: (T, T) => T) = this.map(x => (null, x)).reduceByKey(reduceFunc, 1).map(_._2) def count() = this.map(_ => 1).reduce(_ + _) diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index 40e614b4ed..9bf9251519 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -12,7 +12,7 @@ class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { SparkEnv.set(ssc.env) try { val timeTaken = job.run() - logInfo("Total delay: %.4f s for job %s; execution was %.4f s".format( + logInfo("Total delay: %.5f s for job %s (execution: %.5f s)".format( (System.currentTimeMillis() - job.time) / 1000.0, job.id, timeTaken / 1000.0)) } catch { case e: Exception => diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala index 49e4781e75..d59c245a23 100644 --- a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala @@ -22,7 +22,8 @@ import spark.storage.StorageLevel class RawInputDStream[T: ClassManifest]( @transient ssc: StreamingContext, host: String, - port: Int) + port: Int, + storageLevel: StorageLevel) extends NetworkInputDStream[T](ssc) with Logging { val streamId = id @@ -49,7 +50,7 @@ class RawInputDStream[T: ClassManifest]( val buffer = queue.take() val blockId = "input-" + streamId + "-" + nextBlockNumber nextBlockNumber += 1 - env.blockManager.putBytes(blockId, buffer, StorageLevel.MEMORY_ONLY_2) + env.blockManager.putBytes(blockId, buffer, storageLevel) actor ! BlockPublished(blockId) } } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index feb769e036..cb0f9ceb15 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -4,6 +4,7 @@ import spark.RDD import spark.Logging import spark.SparkEnv import spark.SparkContext +import spark.storage.StorageLevel import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.Queue @@ -64,6 +65,16 @@ class StreamingContext ( inputStreams += inputStream inputStream } + + def createRawNetworkStream[T: ClassManifest]( + hostname: String, + port: Int, + storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_2 + ): DStream[T] = { + val inputStream = new RawInputDStream[T](this, hostname, port, storageLevel) + inputStreams += inputStream + inputStream + } /* def createHttpTextStream(url: String): DStream[String] = { diff --git a/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala new file mode 100644 index 0000000000..17d1ce3602 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala @@ -0,0 +1,32 @@ +package spark.streaming.examples + +import spark.util.IntParam +import spark.storage.StorageLevel +import spark.streaming._ +import spark.streaming.StreamingContext._ + +object CountRaw { + def main(args: Array[String]) { + if (args.length < 4) { + System.err.println("Usage: WordCountNetwork ") + System.exit(1) + } + + val Array(master, IntParam(numStreams), hostname, IntParam(port)) = args + + // Create the context and set the batch size + val ssc = new StreamingContext(master, "CountRaw") + ssc.setBatchDuration(Seconds(1)) + + // Make sure some tasks have started on each node + ssc.sc.parallelize(1 to 1000, 1000).count() + ssc.sc.parallelize(1 to 1000, 1000).count() + ssc.sc.parallelize(1 to 1000, 1000).count() + + val rawStreams = (1 to numStreams).map(_ => + ssc.createRawNetworkStream[String](hostname, port, StorageLevel.MEMORY_ONLY_2)).toArray + val union = new UnifiedDStream(rawStreams) + union.map(_.length).reduce(_ + _).foreachRDD(r => println("Byte count: " + r.collect().mkString)) + ssc.start() + } +} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index a090dcb85d..ce553758a7 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -85,7 +85,7 @@ object WordCount2 { //warmup(ssc.sc) val data = ssc.sc.textFile(file, mapTasks.toInt).persist( - new StorageLevel(false, true, true, 2)) // Memory only, deserialized, 2 replicas + new StorageLevel(false, true, false, 2)) // Memory only, serialized, 2 replicas println("Data count: " + data.count()) println("Data count: " + data.count()) println("Data count: " + data.count()) diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala index 85927c02ec..8db651ba19 100644 --- a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala @@ -1,5 +1,6 @@ package spark.streaming.util +import java.nio.ByteBuffer import spark.util.{RateLimitedOutputStream, IntParam} import java.net.ServerSocket import spark.{Logging, KryoSerializer} @@ -33,7 +34,12 @@ object RawTextSender extends Logging { bufferStream.trim() val array = bufferStream.array + val countBuf = ByteBuffer.wrap(new Array[Byte](4)) + countBuf.putInt(array.length) + countBuf.flip() + val serverSocket = new ServerSocket(port) + logInfo("Listening on port " + port) while (true) { val socket = serverSocket.accept() @@ -41,6 +47,7 @@ object RawTextSender extends Logging { val out = new RateLimitedOutputStream(socket.getOutputStream, bytesPerSec) try { while (true) { + out.write(countBuf.array) out.write(array) } } catch { -- cgit v1.2.3 From bf993cda632e4a9fa41035845491ae466d1a4431 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 1 Sep 2012 19:59:23 +0000 Subject: Make batch size configurable in RawCount --- .../src/main/scala/spark/streaming/examples/CountRaw.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala index 17d1ce3602..c78c1e9660 100644 --- a/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala @@ -7,16 +7,16 @@ import spark.streaming.StreamingContext._ object CountRaw { def main(args: Array[String]) { - if (args.length < 4) { - System.err.println("Usage: WordCountNetwork ") + if (args.length != 5) { + System.err.println("Usage: CountRaw ") System.exit(1) } - val Array(master, IntParam(numStreams), hostname, IntParam(port)) = args + val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args // Create the context and set the batch size val ssc = new StreamingContext(master, "CountRaw") - ssc.setBatchDuration(Seconds(1)) + ssc.setBatchDuration(Milliseconds(batchMillis)) // Make sure some tasks have started on each node ssc.sc.parallelize(1 to 1000, 1000).count() @@ -24,9 +24,9 @@ object CountRaw { ssc.sc.parallelize(1 to 1000, 1000).count() val rawStreams = (1 to numStreams).map(_ => - ssc.createRawNetworkStream[String](hostname, port, StorageLevel.MEMORY_ONLY_2)).toArray + ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray val union = new UnifiedDStream(rawStreams) - union.map(_.length).reduce(_ + _).foreachRDD(r => println("Byte count: " + r.collect().mkString)) + union.map(_.length + 2).reduce(_ + _).foreachRDD(r => println("Byte count: " + r.collect().mkString)) ssc.start() } } -- cgit v1.2.3 From 6025889be0ecf1c9849c5c940a7171c6d82be0b5 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 1 Sep 2012 20:51:07 +0000 Subject: More raw network receiver programs --- .../mesos/CoarseMesosSchedulerBackend.scala | 4 ++- .../scala/spark/streaming/examples/GrepRaw.scala | 33 +++++++++++++++++ .../spark/streaming/examples/WordCount2.scala | 2 +- .../spark/streaming/examples/WordCountRaw.scala | 42 ++++++++++++++++++++++ 4 files changed, 79 insertions(+), 2 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala index 31784985dc..fdf007ffb2 100644 --- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala @@ -80,6 +80,8 @@ class CoarseMesosSchedulerBackend( "property, the SPARK_HOME environment variable or the SparkContext constructor") } + val extraCoresPerSlave = System.getProperty("spark.mesos.extra.cores", "0").toInt + var nextMesosTaskId = 0 def newMesosTaskId(): Int = { @@ -177,7 +179,7 @@ class CoarseMesosSchedulerBackend( val task = MesosTaskInfo.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) .setSlaveId(offer.getSlaveId) - .setCommand(createCommand(offer, cpusToUse)) + .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave)) .setName("Task " + taskId) .addResources(createResource("cpus", cpusToUse)) .addResources(createResource("mem", executorMemory)) diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala new file mode 100644 index 0000000000..cc52da7bd4 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala @@ -0,0 +1,33 @@ +package spark.streaming.examples + +import spark.util.IntParam +import spark.storage.StorageLevel +import spark.streaming._ +import spark.streaming.StreamingContext._ + +object GrepRaw { + def main(args: Array[String]) { + if (args.length != 5) { + System.err.println("Usage: GrepRaw ") + System.exit(1) + } + + val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args + + // Create the context and set the batch size + val ssc = new StreamingContext(master, "GrepRaw") + ssc.setBatchDuration(Milliseconds(batchMillis)) + + // Make sure some tasks have started on each node + ssc.sc.parallelize(1 to 1000, 1000).count() + ssc.sc.parallelize(1 to 1000, 1000).count() + ssc.sc.parallelize(1 to 1000, 1000).count() + + val rawStreams = (1 to numStreams).map(_ => + ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray + val union = new UnifiedDStream(rawStreams) + union.filter(_.contains("Culpepper")).count().foreachRDD(r => + println("Grep count: " + r.collect().mkString)) + ssc.start() + } +} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index ce553758a7..8c2724e97c 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -100,7 +100,7 @@ object WordCount2 { .reduceByKeyAndWindow(add _, subtract _, Seconds(30), batchDuration, reduceTasks.toInt) windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, Milliseconds(chkptMillis.toLong)) - windowedCounts.print() + windowedCounts.foreachRDD(r => println("Element count: " + r.count())) ssc.start() diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala new file mode 100644 index 0000000000..298d9ef381 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -0,0 +1,42 @@ +package spark.streaming.examples + +import spark.util.IntParam +import spark.storage.StorageLevel +import spark.streaming._ +import spark.streaming.StreamingContext._ + +object WordCountRaw { + def main(args: Array[String]) { + if (args.length != 7) { + System.err.println("Usage: WordCountRaw ") + System.exit(1) + } + + val Array(master, IntParam(streams), host, IntParam(port), IntParam(batchMs), + IntParam(chkptMs), IntParam(reduces)) = args + + // Create the context and set the batch size + val ssc = new StreamingContext(master, "WordCountRaw") + ssc.setBatchDuration(Milliseconds(batchMs)) + + // Make sure some tasks have started on each node + ssc.sc.parallelize(1 to 1000, 1000).count() + ssc.sc.parallelize(1 to 1000, 1000).count() + ssc.sc.parallelize(1 to 1000, 1000).count() + + val rawStreams = (1 to streams).map(_ => + ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray + val union = new UnifiedDStream(rawStreams) + + import WordCount2_ExtraFunctions._ + + val windowedCounts = union.mapPartitions(splitAndCountPartitions) + .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) + windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, + Milliseconds(chkptMs)) + //windowedCounts.print() // TODO: something else? + windowedCounts.foreachRDD(r => println("Element count: " + r.count())) + + ssc.start() + } +} -- cgit v1.2.3 From ceabf71257631c9e46f82897e540369b99a6bb57 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 1 Sep 2012 21:52:42 +0000 Subject: tweaks --- core/src/main/scala/spark/storage/StorageLevel.scala | 1 + streaming/src/main/scala/spark/streaming/util/RawTextSender.scala | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index f067a2a6c5..a64393eba7 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -66,6 +66,7 @@ class StorageLevel( object StorageLevel { val NONE = new StorageLevel(false, false, false) val DISK_ONLY = new StorageLevel(true, false, false) + val DISK_ONLY_2 = new StorageLevel(true, false, false, 2) val MEMORY_ONLY = new StorageLevel(false, true, false) val MEMORY_ONLY_2 = new StorageLevel(false, true, false, 2) val MEMORY_ONLY_DESER = new StorageLevel(false, true, true) diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala index 8db651ba19..d8b987ec86 100644 --- a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala @@ -52,7 +52,7 @@ object RawTextSender extends Logging { } } catch { case e: IOException => - logError("Socket closed", e) + logError("Client disconnected") socket.close() } } -- cgit v1.2.3 From 7419d2c7ea1be9dcc0079dbe6dfd7046f0c549e0 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 2 Sep 2012 02:35:17 -0700 Subject: Added transformRDD DStream operation and TransformedDStream. Added sbt assembly option for streaming project. --- project/SparkBuild.scala | 4 +- .../src/main/scala/spark/streaming/DStream.scala | 62 ++++++++++++++++------ .../main/scala/spark/streaming/StateDStream.scala | 1 - .../test/scala/spark/streaming/DStreamSuite.scala | 2 +- 4 files changed, 49 insertions(+), 20 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 6a60f10be4..358213fe64 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -84,7 +84,9 @@ object SparkBuild extends Build { def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel") - def streamingSettings = sharedSettings ++ Seq(name := "spark-streaming") + def streamingSettings = sharedSettings ++ Seq( + name := "spark-streaming" + ) ++ assemblySettings ++ extraAssemblySettings def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq( mergeStrategy in assembly := { diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 8c06345933..08eda056c9 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -12,7 +12,7 @@ import spark.Partitioner import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap -import java.util.concurrent.ArrayBlockingQueue +import java.util.concurrent.ArrayBlockingQueue abstract class DStream[T: ClassManifest] (@transient val ssc: StreamingContext) extends Logging with Serializable { @@ -166,15 +166,17 @@ extends Logging with Serializable { def map[U: ClassManifest](mapFunc: T => U) = new MappedDStream(this, ssc.sc.clean(mapFunc)) - def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]) = + def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]) = { new FlatMappedDStream(this, ssc.sc.clean(flatMapFunc)) + } def filter(filterFunc: T => Boolean) = new FilteredDStream(this, filterFunc) def glom() = new GlommedDStream(this) - def mapPartitions[U: ClassManifest](mapPartFunc: Iterator[T] => Iterator[U]) = + def mapPartitions[U: ClassManifest](mapPartFunc: Iterator[T] => Iterator[U]) = { new MapPartitionedDStream(this, ssc.sc.clean(mapPartFunc)) + } def reduce(reduceFunc: (T, T) => T) = this.map(x => (1, x)).reduceByKey(reduceFunc, 1).map(_._2) @@ -182,18 +184,30 @@ extends Logging with Serializable { def collect() = this.map(x => (1, x)).groupByKey(1).map(_._2) - def foreach(foreachFunc: T => Unit) = { + def foreach(foreachFunc: T => Unit) { val newStream = new PerElementForEachDStream(this, ssc.sc.clean(foreachFunc)) ssc.registerOutputStream(newStream) newStream } - def foreachRDD(foreachFunc: RDD[T] => Unit) = { + def foreachRDD(foreachFunc: RDD[T] => Unit) { + foreachRDD((r: RDD[T], t: Time) => foreachFunc(r)) + } + + def foreachRDD(foreachFunc: (RDD[T], Time) => Unit) { val newStream = new PerRDDForEachDStream(this, ssc.sc.clean(foreachFunc)) ssc.registerOutputStream(newStream) newStream } + def transformRDD[U: ClassManifest](transformFunc: RDD[T] => RDD[U]): DStream[U] = { + transformRDD((r: RDD[T], t: Time) => transformFunc(r)) + } + + def transformRDD[U: ClassManifest](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = { + new TransformedDStream(this, ssc.sc.clean(transformFunc)) + } + private[streaming] def toQueue = { val queue = new ArrayBlockingQueue[RDD[T]](10000) this.foreachRDD(rdd => { @@ -361,15 +375,13 @@ class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiner: (C, C) => C, - partitioner: Partitioner) - extends DStream [(K,C)] (parent.ssc) { + partitioner: Partitioner + ) extends DStream [(K,C)] (parent.ssc) { override def dependencies = List(parent) override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[(K,C)]] = { parent.getOrCompute(validTime) match { case Some(rdd) => @@ -385,7 +397,7 @@ class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( */ class UnifiedDStream[T: ClassManifest](parents: Array[DStream[T]]) -extends DStream[T](parents(0).ssc) { + extends DStream[T](parents(0).ssc) { if (parents.length == 0) { throw new IllegalArgumentException("Empty array of parents") @@ -424,8 +436,8 @@ extends DStream[T](parents(0).ssc) { class PerElementForEachDStream[T: ClassManifest] ( parent: DStream[T], - foreachFunc: T => Unit) -extends DStream[Unit](parent.ssc) { + foreachFunc: T => Unit + ) extends DStream[Unit](parent.ssc) { override def dependencies = List(parent) @@ -455,11 +467,8 @@ extends DStream[Unit](parent.ssc) { class PerRDDForEachDStream[T: ClassManifest] ( parent: DStream[T], - foreachFunc: (RDD[T], Time) => Unit) -extends DStream[Unit](parent.ssc) { - - def this(parent: DStream[T], altForeachFunc: (RDD[T]) => Unit) = - this(parent, (rdd: RDD[T], time: Time) => altForeachFunc(rdd)) + foreachFunc: (RDD[T], Time) => Unit + ) extends DStream[Unit](parent.ssc) { override def dependencies = List(parent) @@ -478,3 +487,22 @@ extends DStream[Unit](parent.ssc) { } } } + + +/** + * TODO + */ + +class TransformedDStream[T: ClassManifest, U: ClassManifest] ( + parent: DStream[T], + transformFunc: (RDD[T], Time) => RDD[U] + ) extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(transformFunc(_, validTime)) + } + } diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala index eabb33d89e..f313d8c162 100644 --- a/streaming/src/main/scala/spark/streaming/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala @@ -4,7 +4,6 @@ import spark.RDD import spark.Partitioner import spark.MapPartitionsRDD import spark.SparkContext._ -import javax.annotation.Nullable class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( diff --git a/streaming/src/test/scala/spark/streaming/DStreamSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamSuite.scala index d5eb20b37e..030f351080 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamSuite.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamSuite.scala @@ -105,7 +105,7 @@ class DStreamSuite extends FunSuite with BeforeAndAfter with Logging { Seq(("a", 1), ("b", 1), ("c", 1)), Seq(("a", 2), ("b", 2), ("c", 2)), Seq(("a", 3), ("b", 3), ("c", 3)) - )//.map(array => array.toSeq.map(x => (x._1, new RichInt(x._2)))) + ) val updateStateOp =(s: DStream[String]) => { val updateFunc = (values: Seq[Int], state: RichInt) => { -- cgit v1.2.3 From 1878731671d89a2e2230f1cae62648fcd69c12ab Mon Sep 17 00:00:00 2001 From: root Date: Tue, 4 Sep 2012 04:26:53 +0000 Subject: Various test programs --- .../scala/spark/streaming/examples/Grep2.scala | 64 +++++++++++++++++++ .../spark/streaming/examples/WordCount2.scala | 8 ++- .../scala/spark/streaming/examples/WordMax2.scala | 73 ++++++++++++++++++++++ 3 files changed, 143 insertions(+), 2 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/examples/Grep2.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/WordMax2.scala diff --git a/streaming/src/main/scala/spark/streaming/examples/Grep2.scala b/streaming/src/main/scala/spark/streaming/examples/Grep2.scala new file mode 100644 index 0000000000..7237142c7c --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/Grep2.scala @@ -0,0 +1,64 @@ +package spark.streaming.examples + +import spark.SparkContext +import SparkContext._ +import spark.streaming._ +import StreamingContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.Queue +import scala.collection.JavaConversions.mapAsScalaMap + +import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} + + +object Grep2 { + + def warmup(sc: SparkContext) { + (0 until 10).foreach {i => + sc.parallelize(1 to 20000000, 1000) + .map(x => (x % 337, x % 1331)) + .reduceByKey(_ + _) + .count() + } + } + + def main (args: Array[String]) { + + if (args.length != 6) { + println ("Usage: Grep2 ") + System.exit(1) + } + + val Array(master, file, mapTasks, reduceTasks, batchMillis, chkptMillis) = args + + val batchDuration = Milliseconds(batchMillis.toLong) + + val ssc = new StreamingContext(master, "Grep2") + ssc.setBatchDuration(batchDuration) + + //warmup(ssc.sc) + + val data = ssc.sc.textFile(file, mapTasks.toInt).persist( + new StorageLevel(false, true, false, 3)) // Memory only, serialized, 3 replicas + println("Data count: " + data.count()) + println("Data count: " + data.count()) + println("Data count: " + data.count()) + + val sentences = new ConstantInputDStream(ssc, data) + ssc.inputStreams += sentences + + sentences.filter(_.contains("Culpepper")).count().foreachRDD(r => + println("Grep count: " + r.collect().mkString)) + + ssc.start() + + while(true) { Thread.sleep(1000) } + } +} + + diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index 8c2724e97c..aa542ba07d 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -22,6 +22,8 @@ object WordCount2_ExtraFunctions { def subtract(v1: Long, v2: Long) = (v1 - v2) + def max(v1: Long, v2: Long) = math.max(v1, v2) + def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, Long)] = { //val map = new java.util.HashMap[String, Long] val map = new OLMap[String] @@ -85,7 +87,7 @@ object WordCount2 { //warmup(ssc.sc) val data = ssc.sc.textFile(file, mapTasks.toInt).persist( - new StorageLevel(false, true, false, 2)) // Memory only, serialized, 2 replicas + new StorageLevel(false, true, false, 3)) // Memory only, serialized, 3 replicas println("Data count: " + data.count()) println("Data count: " + data.count()) println("Data count: " + data.count()) @@ -98,7 +100,9 @@ object WordCount2 { val windowedCounts = sentences .mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), batchDuration, reduceTasks.toInt) - windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, + windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, + StorageLevel.MEMORY_ONLY_DESER_2, + //new StorageLevel(false, true, true, 3), Milliseconds(chkptMillis.toLong)) windowedCounts.foreachRDD(r => println("Element count: " + r.count())) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala new file mode 100644 index 0000000000..3658cb302d --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala @@ -0,0 +1,73 @@ +package spark.streaming.examples + +import spark.SparkContext +import SparkContext._ +import spark.streaming._ +import StreamingContext._ + +import spark.storage.StorageLevel + +import scala.util.Sorting +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.Queue +import scala.collection.JavaConversions.mapAsScalaMap + +import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} + + +object WordMax2 { + + def warmup(sc: SparkContext) { + (0 until 10).foreach {i => + sc.parallelize(1 to 20000000, 1000) + .map(x => (x % 337, x % 1331)) + .reduceByKey(_ + _) + .count() + } + } + + def main (args: Array[String]) { + + if (args.length != 6) { + println ("Usage: WordMax2 ") + System.exit(1) + } + + val Array(master, file, mapTasks, reduceTasks, batchMillis, chkptMillis) = args + + val batchDuration = Milliseconds(batchMillis.toLong) + + val ssc = new StreamingContext(master, "WordMax2") + ssc.setBatchDuration(batchDuration) + + //warmup(ssc.sc) + + val data = ssc.sc.textFile(file, mapTasks.toInt).persist( + new StorageLevel(false, true, false, 3)) // Memory only, serialized, 3 replicas + println("Data count: " + data.count()) + println("Data count: " + data.count()) + println("Data count: " + data.count()) + + val sentences = new ConstantInputDStream(ssc, data) + ssc.inputStreams += sentences + + import WordCount2_ExtraFunctions._ + + val windowedCounts = sentences + .mapPartitions(splitAndCountPartitions) + .reduceByKey(add _, reduceTasks.toInt) + .persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, + Milliseconds(chkptMillis.toLong)) + .reduceByKeyAndWindow(max _, Seconds(10), batchDuration, reduceTasks.toInt) + //.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, + // Milliseconds(chkptMillis.toLong)) + windowedCounts.foreachRDD(r => println("Element count: " + r.count())) + + ssc.start() + + while(true) { Thread.sleep(1000) } + } +} + + -- cgit v1.2.3 From 389a78722cabe9f964ac29edcc0c3d47db4ba021 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 4 Sep 2012 15:37:46 -0700 Subject: Updated the return types of PairDStreamFunctions to return DStreams instead of ShuffleDStreams for cleaner abstraction. --- .../spark/streaming/PairDStreamFunctions.scala | 56 ++++++++++++++-------- 1 file changed, 35 insertions(+), 21 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index 13db34ac80..3fd0a16bf0 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -19,32 +19,32 @@ extends Serializable { /* DStream operations for key-value pairs */ /* ---------------------------------- */ - def groupByKey(): ShuffledDStream[K, V, ArrayBuffer[V]] = { + def groupByKey(): DStream[(K, Seq[V])] = { groupByKey(defaultPartitioner()) } - def groupByKey(numPartitions: Int): ShuffledDStream[K, V, ArrayBuffer[V]] = { + def groupByKey(numPartitions: Int): DStream[(K, Seq[V])] = { groupByKey(defaultPartitioner(numPartitions)) } - def groupByKey(partitioner: Partitioner): ShuffledDStream[K, V, ArrayBuffer[V]] = { + def groupByKey(partitioner: Partitioner): DStream[(K, Seq[V])] = { def createCombiner(v: V) = ArrayBuffer[V](v) def mergeValue(c: ArrayBuffer[V], v: V) = (c += v) def mergeCombiner(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = (c1 ++ c2) - combineByKey[ArrayBuffer[V]](createCombiner _, mergeValue _, mergeCombiner _, partitioner) + combineByKey(createCombiner _, mergeValue _, mergeCombiner _, partitioner).asInstanceOf[DStream[(K, Seq[V])]] } - def reduceByKey(reduceFunc: (V, V) => V): ShuffledDStream[K, V, V] = { + def reduceByKey(reduceFunc: (V, V) => V): DStream[(K, V)] = { reduceByKey(reduceFunc, defaultPartitioner()) } - def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int): ShuffledDStream[K, V, V] = { + def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int): DStream[(K, V)] = { reduceByKey(reduceFunc, defaultPartitioner(numPartitions)) } - def reduceByKey(reduceFunc: (V, V) => V, partitioner: Partitioner): ShuffledDStream[K, V, V] = { + def reduceByKey(reduceFunc: (V, V) => V, partitioner: Partitioner): DStream[(K, V)] = { val cleanedReduceFunc = ssc.sc.clean(reduceFunc) - combineByKey[V]((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, partitioner) + combineByKey((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, partitioner) } private def combineByKey[C: ClassManifest]( @@ -55,11 +55,15 @@ extends Serializable { new ShuffledDStream[K, V, C](stream, createCombiner, mergeValue, mergeCombiner, partitioner) } - def groupByKeyAndWindow(windowTime: Time, slideTime: Time): ShuffledDStream[K, V, ArrayBuffer[V]] = { + def groupByKeyAndWindow(windowTime: Time, slideTime: Time): DStream[(K, Seq[V])] = { groupByKeyAndWindow(windowTime, slideTime, defaultPartitioner()) } - def groupByKeyAndWindow(windowTime: Time, slideTime: Time, numPartitions: Int): ShuffledDStream[K, V, ArrayBuffer[V]] = { + def groupByKeyAndWindow( + windowTime: Time, + slideTime: Time, + numPartitions: Int + ): DStream[(K, Seq[V])] = { groupByKeyAndWindow(windowTime, slideTime, defaultPartitioner(numPartitions)) } @@ -67,15 +71,24 @@ extends Serializable { windowTime: Time, slideTime: Time, partitioner: Partitioner - ): ShuffledDStream[K, V, ArrayBuffer[V]] = { + ): DStream[(K, Seq[V])] = { stream.window(windowTime, slideTime).groupByKey(partitioner) } - def reduceByKeyAndWindow(reduceFunc: (V, V) => V, windowTime: Time, slideTime: Time): ShuffledDStream[K, V, V] = { + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time + ): DStream[(K, V)] = { reduceByKeyAndWindow(reduceFunc, windowTime, slideTime, defaultPartitioner()) } - def reduceByKeyAndWindow(reduceFunc: (V, V) => V, windowTime: Time, slideTime: Time, numPartitions: Int): ShuffledDStream[K, V, V] = { + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + windowTime: Time, + slideTime: Time, + numPartitions: Int + ): DStream[(K, V)] = { reduceByKeyAndWindow(reduceFunc, windowTime, slideTime, defaultPartitioner(numPartitions)) } @@ -84,7 +97,7 @@ extends Serializable { windowTime: Time, slideTime: Time, partitioner: Partitioner - ): ShuffledDStream[K, V, V] = { + ): DStream[(K, V)] = { stream.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), partitioner) } @@ -93,12 +106,13 @@ extends Serializable { // so that new elements introduced in the window can be "added" using // reduceFunc to the previous window's result and old elements can be // "subtracted using invReduceFunc. + def reduceByKeyAndWindow( reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, windowTime: Time, slideTime: Time - ): ReducedWindowedDStream[K, V] = { + ): DStream[(K, V)] = { reduceByKeyAndWindow( reduceFunc, invReduceFunc, windowTime, slideTime, defaultPartitioner()) @@ -110,7 +124,7 @@ extends Serializable { windowTime: Time, slideTime: Time, numPartitions: Int - ): ReducedWindowedDStream[K, V] = { + ): DStream[(K, V)] = { reduceByKeyAndWindow( reduceFunc, invReduceFunc, windowTime, slideTime, defaultPartitioner(numPartitions)) @@ -122,7 +136,7 @@ extends Serializable { windowTime: Time, slideTime: Time, partitioner: Partitioner - ): ReducedWindowedDStream[K, V] = { + ): DStream[(K, V)] = { val cleanedReduceFunc = ssc.sc.clean(reduceFunc) val cleanedInvReduceFunc = ssc.sc.clean(invReduceFunc) @@ -137,21 +151,21 @@ extends Serializable { // def updateStateByKey[S <: AnyRef : ClassManifest]( updateFunc: (Seq[V], S) => S - ): StateDStream[K, V, S] = { + ): DStream[(K, S)] = { updateStateByKey(updateFunc, defaultPartitioner()) } def updateStateByKey[S <: AnyRef : ClassManifest]( updateFunc: (Seq[V], S) => S, numPartitions: Int - ): StateDStream[K, V, S] = { + ): DStream[(K, S)] = { updateStateByKey(updateFunc, defaultPartitioner(numPartitions)) } def updateStateByKey[S <: AnyRef : ClassManifest]( updateFunc: (Seq[V], S) => S, partitioner: Partitioner - ): StateDStream[K, V, S] = { + ): DStream[(K, S)] = { val func = (iterator: Iterator[(K, Seq[V], S)]) => { iterator.map(tuple => (tuple._1, updateFunc(tuple._2, tuple._3))) } @@ -162,7 +176,7 @@ extends Serializable { updateFunc: (Iterator[(K, Seq[V], S)]) => Iterator[(K, S)], partitioner: Partitioner, rememberPartitioner: Boolean - ): StateDStream[K, V, S] = { + ): DStream[(K, S)] = { new StateDStream(stream, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner) } } -- cgit v1.2.3 From 2ff72f60ac6e42e9511fcf4effd1df89da8dc410 Mon Sep 17 00:00:00 2001 From: haoyuan Date: Tue, 4 Sep 2012 17:55:55 -0700 Subject: add TopKWordCountRaw.scala --- .../streaming/examples/TopKWordCountRaw.scala | 86 ++++++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala new file mode 100644 index 0000000000..7c01435f85 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala @@ -0,0 +1,86 @@ +package spark.streaming.examples + +import spark.util.IntParam +import spark.storage.StorageLevel +import spark.streaming._ +import spark.streaming.StreamingContext._ + +object WordCountRaw { + def main(args: Array[String]) { + if (args.length != 7) { + System.err.println("Usage: WordCountRaw ") + System.exit(1) + } + + val Array(master, IntParam(streams), host, IntParam(port), IntParam(batchMs), + IntParam(chkptMs), IntParam(reduces)) = args + + // Create the context and set the batch size + val ssc = new StreamingContext(master, "TopKWordCountRaw") + ssc.setBatchDuration(Milliseconds(batchMs)) + + // Make sure some tasks have started on each node + ssc.sc.parallelize(1 to 1000, 1000).count() + ssc.sc.parallelize(1 to 1000, 1000).count() + ssc.sc.parallelize(1 to 1000, 1000).count() + + val rawStreams = (1 to streams).map(_ => + ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray + val union = new UnifiedDStream(rawStreams) + + import WordCount2_ExtraFunctions._ + + val windowedCounts = union.mapPartitions(splitAndCountPartitions) + .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) + windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, + Milliseconds(chkptMs)) + //windowedCounts.print() // TODO: something else? + + def topK(data: Iterator[(String, JLong)], k: Int): Iterator[(String, JLong)] = { + val taken = new Array[(String, JLong)](k) + + var i = 0 + var len = 0 + var done = false + var value: (String, JLong) = null + var swap: (String, JLong) = null + var count = 0 + + while(data.hasNext) { + value = data.next + count += 1 + println("count = " + count) + if (len == 0) { + taken(0) = value + len = 1 + } else if (len < k || value._2 > taken(len - 1)._2) { + if (len < k) { + len += 1 + } + taken(len - 1) = value + i = len - 1 + while(i > 0 && taken(i - 1)._2 < taken(i)._2) { + swap = taken(i) + taken(i) = taken(i-1) + taken(i - 1) = swap + i -= 1 + } + } + } + println("Took " + len + " out of " + count + " items") + return taken.toIterator + } + + val k = 50 + val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) + partialTopKWindowedCounts.foreachRDD(rdd => { + val collectedCounts = rdd.collect + println("Collected " + collectedCounts.size + " items") + topK(collectedCounts.toIterator, k).foreach(println) + }) + +// windowedCounts.foreachRDD(r => println("Element count: " + r.count())) + + ssc.start() + } +} -- cgit v1.2.3 From 96a1f2277d1c7e0ff970eafcb3ca53cd20b1b89c Mon Sep 17 00:00:00 2001 From: haoyuan Date: Tue, 4 Sep 2012 18:03:34 -0700 Subject: fix the compile error in TopKWordCountRaw.scala --- .../scala/spark/streaming/examples/TopKWordCountRaw.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala index 7c01435f85..be3188c5ed 100644 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala @@ -5,10 +5,10 @@ import spark.storage.StorageLevel import spark.streaming._ import spark.streaming.StreamingContext._ -object WordCountRaw { +object TopKWordCountRaw { def main(args: Array[String]) { if (args.length != 7) { - System.err.println("Usage: WordCountRaw ") + System.err.println("Usage: TopKWordCountRaw ") System.exit(1) } @@ -36,14 +36,14 @@ object WordCountRaw { Milliseconds(chkptMs)) //windowedCounts.print() // TODO: something else? - def topK(data: Iterator[(String, JLong)], k: Int): Iterator[(String, JLong)] = { - val taken = new Array[(String, JLong)](k) + def topK(data: Iterator[(String, Long)], k: Int): Iterator[(String, Long)] = { + val taken = new Array[(String, Long)](k) var i = 0 var len = 0 var done = false - var value: (String, JLong) = null - var swap: (String, JLong) = null + var value: (String, Long) = null + var swap: (String, Long) = null var count = 0 while(data.hasNext) { -- cgit v1.2.3 From 7c09ad0e04639040864236cf13a9fedff6736b5d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 4 Sep 2012 19:11:49 -0700 Subject: Changed DStream member access permissions from private to protected. Updated StateDStream to checkpoint RDDs and forget lineage. --- core/src/main/scala/spark/RDD.scala | 2 +- .../src/main/scala/spark/streaming/DStream.scala | 16 ++-- .../scala/spark/streaming/QueueInputDStream.scala | 2 +- .../main/scala/spark/streaming/StateDStream.scala | 93 ++++++++++++++++------ .../test/scala/spark/streaming/DStreamSuite.scala | 4 +- 5 files changed, 81 insertions(+), 36 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 3fe8e8a4bf..d28f3593fe 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -94,7 +94,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial def getStorageLevel = storageLevel - def checkpoint(level: StorageLevel = StorageLevel.DISK_AND_MEMORY_DESER): RDD[T] = { + def checkpoint(level: StorageLevel = StorageLevel.DISK_AND_MEMORY_DESER_2): RDD[T] = { if (!level.useDisk && level.replication < 2) { throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")") } diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 9b0115eef6..20f1c4db20 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -41,17 +41,17 @@ extends Logging with Serializable { */ // Variable to store the RDDs generated earlier in time - @transient private val generatedRDDs = new HashMap[Time, RDD[T]] () + @transient protected val generatedRDDs = new HashMap[Time, RDD[T]] () // Variable to be set to the first time seen by the DStream (effective time zero) - private[streaming] var zeroTime: Time = null + protected[streaming] var zeroTime: Time = null // Variable to specify storage level - private var storageLevel: StorageLevel = StorageLevel.NONE + protected var storageLevel: StorageLevel = StorageLevel.NONE // Checkpoint level and checkpoint interval - private var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint - private var checkpointInterval: Time = null + protected var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint + protected var checkpointInterval: Time = null // Change this RDD's storage level def persist( @@ -84,7 +84,7 @@ extends Logging with Serializable { * the validity of future times is calculated. This method also recursively initializes * its parent DStreams. */ - def initialize(time: Time) { + protected[streaming] def initialize(time: Time) { if (zeroTime == null) { zeroTime = time } @@ -93,7 +93,7 @@ extends Logging with Serializable { } /** This method checks whether the 'time' is valid wrt slideTime for generating RDD */ - private def isTimeValid (time: Time): Boolean = { + protected def isTimeValid (time: Time): Boolean = { if (!isInitialized) { throw new Exception (this.toString + " has not been initialized") } else if (time < zeroTime || ! (time - zeroTime).isMultipleOf(slideTime)) { @@ -208,7 +208,7 @@ extends Logging with Serializable { new TransformedDStream(this, ssc.sc.clean(transformFunc)) } - private[streaming] def toQueue = { + def toQueue = { val queue = new ArrayBlockingQueue[RDD[T]](10000) this.foreachRDD(rdd => { queue.add(rdd) diff --git a/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala b/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala index bab48ff954..de30297c7d 100644 --- a/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala @@ -7,7 +7,7 @@ import scala.collection.mutable.Queue import scala.collection.mutable.ArrayBuffer class QueueInputDStream[T: ClassManifest]( - ssc: StreamingContext, + @transient ssc: StreamingContext, val queue: Queue[RDD[T]], oneAtATime: Boolean, defaultRDD: RDD[T] diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala index f313d8c162..4cb780c006 100644 --- a/streaming/src/main/scala/spark/streaming/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala @@ -1,10 +1,11 @@ package spark.streaming import spark.RDD +import spark.BlockRDD import spark.Partitioner import spark.MapPartitionsRDD import spark.SparkContext._ - +import spark.storage.StorageLevel class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( parent: DStream[(K, V)], @@ -22,6 +23,47 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife override def slideTime = parent.slideTime + override def getOrCompute(time: Time): Option[RDD[(K, S)]] = { + generatedRDDs.get(time) match { + case Some(oldRDD) => { + if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval) && oldRDD.dependencies.size > 0) { + val r = oldRDD + val oldRDDBlockIds = oldRDD.splits.map(s => "rdd:" + r.id + ":" + s.index) + val checkpointedRDD = new BlockRDD[(K, S)](ssc.sc, oldRDDBlockIds) { + override val partitioner = oldRDD.partitioner + } + generatedRDDs.update(time, checkpointedRDD) + logInfo("Updated RDD of time " + time + " with its checkpointed version") + Some(checkpointedRDD) + } else { + Some(oldRDD) + } + } + case None => { + if (isTimeValid(time)) { + compute(time) match { + case Some(newRDD) => { + if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { + newRDD.persist(checkpointLevel) + logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) + } else if (storageLevel != StorageLevel.NONE) { + newRDD.persist(storageLevel) + logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) + } + generatedRDDs.put(time, newRDD) + Some(newRDD) + } + case None => { + None + } + } + } else { + None + } + } + } + } + override def compute(validTime: Time): Option[RDD[(K, S)]] = { // Try to get the previous state RDD @@ -29,26 +71,27 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife case Some(prevStateRDD) => { // If previous state RDD exists - // Define the function for the mapPartition operation on cogrouped RDD; - // first map the cogrouped tuple to tuples of required type, - // and then apply the update function - val func = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { - val i = iterator.map(t => { - (t._1, t._2._1, t._2._2.headOption.getOrElse(null.asInstanceOf[S])) - }) - updateFunc(i) - } - // Try to get the parent RDD parent.getOrCompute(validTime) match { case Some(parentRDD) => { // If parent RDD exists, then compute as usual + + // Define the function for the mapPartition operation on cogrouped RDD; + // first map the cogrouped tuple to tuples of required type, + // and then apply the update function + val updateFuncLocal = updateFunc + val mapPartitionFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { + val i = iterator.map(t => { + (t._1, t._2._1, t._2._2.headOption.getOrElse(null.asInstanceOf[S])) + }) + updateFuncLocal(i) + } val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) - val stateRDD = new SpecialMapPartitionsRDD(cogroupedRDD, func) - logDebug("Generating state RDD for time " + validTime) + val stateRDD = new SpecialMapPartitionsRDD(cogroupedRDD, mapPartitionFunc) + //logDebug("Generating state RDD for time " + validTime) return Some(stateRDD) } case None => { // If parent RDD does not exist, then return old state RDD - logDebug("Generating state RDD for time " + validTime + " (no change)") + //logDebug("Generating state RDD for time " + validTime + " (no change)") return Some(prevStateRDD) } } @@ -56,23 +99,25 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife case None => { // If previous session RDD does not exist (first input data) - // Define the function for the mapPartition operation on grouped RDD; - // first map the grouped tuple to tuples of required type, - // and then apply the update function - val func = (iterator: Iterator[(K, Seq[V])]) => { - updateFunc(iterator.map(tuple => (tuple._1, tuple._2, null.asInstanceOf[S]))) - } - // Try to get the parent RDD parent.getOrCompute(validTime) match { case Some(parentRDD) => { // If parent RDD exists, then compute as usual + + // Define the function for the mapPartition operation on grouped RDD; + // first map the grouped tuple to tuples of required type, + // and then apply the update function + val updateFuncLocal = updateFunc + val mapPartitionFunc = (iterator: Iterator[(K, Seq[V])]) => { + updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, null.asInstanceOf[S]))) + } + val groupedRDD = parentRDD.groupByKey(partitioner) - val sessionRDD = new SpecialMapPartitionsRDD(groupedRDD, func) - logDebug("Generating state RDD for time " + validTime + " (first)") + val sessionRDD = new SpecialMapPartitionsRDD(groupedRDD, mapPartitionFunc) + //logDebug("Generating state RDD for time " + validTime + " (first)") return Some(sessionRDD) } case None => { // If parent RDD does not exist, then nothing to do! - logDebug("Not generating state RDD (no previous state, no parent)") + //logDebug("Not generating state RDD (no previous state, no parent)") return None } } diff --git a/streaming/src/test/scala/spark/streaming/DStreamSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamSuite.scala index 030f351080..fc00952afe 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamSuite.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamSuite.scala @@ -107,12 +107,12 @@ class DStreamSuite extends FunSuite with BeforeAndAfter with Logging { Seq(("a", 3), ("b", 3), ("c", 3)) ) - val updateStateOp =(s: DStream[String]) => { + val updateStateOp = (s: DStream[String]) => { val updateFunc = (values: Seq[Int], state: RichInt) => { var newState = 0 if (values != null) newState += values.reduce(_ + _) if (state != null) newState += state.self - //println("values = " + values + ", state = " + state + ", " + " new state = " + newState) + println("values = " + values + ", state = " + state + ", " + " new state = " + newState) new RichInt(newState) } s.map(x => (x, 1)).updateStateByKey[RichInt](updateFunc).map(t => (t._1, t._2.self)) -- cgit v1.2.3 From 4ea032a142ab7fb44f92b145cc8d850164419ab5 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 05:53:07 +0000 Subject: Some changes to make important log output visible even if we set the logging to WARNING --- .../main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala | 2 +- streaming/src/main/scala/spark/streaming/JobManager.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 83e7c6e036..978b4f2676 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -99,7 +99,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor // Remove a disconnected slave from the cluster def removeSlave(slaveId: String) { - logInfo("Slave " + slaveId + " disconnected, so removing it") + logWarning("Slave " + slaveId + " disconnected, so removing it") val numCores = freeCores(slaveId) actorToSlaveId -= slaveActor(slaveId) addressToSlaveId -= slaveAddress(slaveId) diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index 9bf9251519..230d806a89 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -12,7 +12,7 @@ class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { SparkEnv.set(ssc.env) try { val timeTaken = job.run() - logInfo("Total delay: %.5f s for job %s (execution: %.5f s)".format( + println("Total delay: %.5f s for job %s (execution: %.5f s)".format( (System.currentTimeMillis() - job.time) / 1000.0, job.id, timeTaken / 1000.0)) } catch { case e: Exception => -- cgit v1.2.3 From b7ad291ac52896af6cb1d882392f3d6fa0cf3b49 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 07:08:07 +0000 Subject: Tuning Akka for more connections --- core/src/main/scala/spark/util/AkkaUtils.scala | 1 + .../src/main/scala/spark/streaming/examples/WordCount2.scala | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index 57d212e4ca..fd64e224d7 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -31,6 +31,7 @@ object AkkaUtils { akka.remote.netty.hostname = "%s" akka.remote.netty.port = %d akka.remote.netty.connection-timeout = 1s + akka.remote.netty.execution-pool-size = 10 """.format(host, port)) val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index aa542ba07d..8561e7f079 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -62,10 +62,10 @@ object WordCount2_ExtraFunctions { object WordCount2 { def warmup(sc: SparkContext) { - (0 until 10).foreach {i => - sc.parallelize(1 to 20000000, 1000) + (0 until 3).foreach {i => + sc.parallelize(1 to 20000000, 500) .map(x => (x % 337, x % 1331)) - .reduceByKey(_ + _) + .reduceByKey(_ + _, 100) .count() } } @@ -84,11 +84,11 @@ object WordCount2 { val ssc = new StreamingContext(master, "WordCount2") ssc.setBatchDuration(batchDuration) - //warmup(ssc.sc) + warmup(ssc.sc) val data = ssc.sc.textFile(file, mapTasks.toInt).persist( new StorageLevel(false, true, false, 3)) // Memory only, serialized, 3 replicas - println("Data count: " + data.count()) + println("Data count: " + data.map(x => if (x == "") 1 else x.split(" ").size / x.split(" ").size).count()) println("Data count: " + data.count()) println("Data count: " + data.count()) -- cgit v1.2.3 From 75487b2f5a6abedd322520f759b814ec643aea01 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 08:14:50 +0000 Subject: Broadcast the JobConf in HadoopRDD to reduce task sizes --- core/src/main/scala/spark/HadoopRDD.scala | 5 +++-- core/src/main/scala/spark/KryoSerializer.scala | 4 ++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/HadoopRDD.scala b/core/src/main/scala/spark/HadoopRDD.scala index f282a4023b..0befca582d 100644 --- a/core/src/main/scala/spark/HadoopRDD.scala +++ b/core/src/main/scala/spark/HadoopRDD.scala @@ -42,7 +42,8 @@ class HadoopRDD[K, V]( minSplits: Int) extends RDD[(K, V)](sc) { - val serializableConf = new SerializableWritable(conf) + // A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it + val confBroadcast = sc.broadcast(new SerializableWritable(conf)) @transient val splits_ : Array[Split] = { @@ -66,7 +67,7 @@ class HadoopRDD[K, V]( val split = theSplit.asInstanceOf[HadoopSplit] var reader: RecordReader[K, V] = null - val conf = serializableConf.value + val conf = confBroadcast.value.value val fmt = createInputFormat(conf) reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL) diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 65d0532bd5..3d042b2f11 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -10,6 +10,7 @@ import scala.collection.mutable import com.esotericsoftware.kryo._ import com.esotericsoftware.kryo.{Serializer => KSerializer} import com.esotericsoftware.kryo.serialize.ClassSerializer +import com.esotericsoftware.kryo.serialize.SerializableSerializer import de.javakaffee.kryoserializers.KryoReflectionFactorySupport import spark.storage._ @@ -203,6 +204,9 @@ class KryoSerializer extends Serializer with Logging { kryo.register(classOf[Class[_]], new ClassSerializer(kryo)) kryo.setRegistrationOptional(true) + // Allow sending SerializableWritable + kryo.register(classOf[SerializableWritable[_]], new SerializableSerializer()) + // Register some commonly used Scala singleton objects. Because these // are singletons, we must return the exact same local object when we // deserialize rather than returning a clone as FieldSerializer would. -- cgit v1.2.3 From efc7668d16b2a58f8d074c1cdaeae4b37dae1c9c Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 08:22:57 +0000 Subject: Allow serializing HttpBroadcast through Kryo --- core/src/main/scala/spark/KryoSerializer.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 3d042b2f11..8a3f565071 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -13,6 +13,7 @@ import com.esotericsoftware.kryo.serialize.ClassSerializer import com.esotericsoftware.kryo.serialize.SerializableSerializer import de.javakaffee.kryoserializers.KryoReflectionFactorySupport +import spark.broadcast._ import spark.storage._ /** @@ -206,6 +207,7 @@ class KryoSerializer extends Serializer with Logging { // Allow sending SerializableWritable kryo.register(classOf[SerializableWritable[_]], new SerializableSerializer()) + kryo.register(classOf[HttpBroadcast[_]], new SerializableSerializer()) // Register some commonly used Scala singleton objects. Because these // are singletons, we must return the exact same local object when we -- cgit v1.2.3 From 3fa0d7f0c9883ab77e89b7bcf70b7b11df9a4184 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 08:28:15 +0000 Subject: Serialize BlockRDD more efficiently --- core/src/main/scala/spark/BlockRDD.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/BlockRDD.scala b/core/src/main/scala/spark/BlockRDD.scala index ea009f0f4f..daabc0d566 100644 --- a/core/src/main/scala/spark/BlockRDD.scala +++ b/core/src/main/scala/spark/BlockRDD.scala @@ -7,7 +7,8 @@ class BlockRDDSplit(val blockId: String, idx: Int) extends Split { } -class BlockRDD[T: ClassManifest](sc: SparkContext, blockIds: Array[String]) extends RDD[T](sc) { +class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String]) + extends RDD[T](sc) { @transient val splits_ = (0 until blockIds.size).map(i => { -- cgit v1.2.3 From 1d6b36d3c3698090b35d8e7c4f88cac410f9ea01 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 22:26:37 +0000 Subject: Further tuning for network performance --- core/src/main/scala/spark/storage/BlockMessage.scala | 14 +------------- core/src/main/scala/spark/util/AkkaUtils.scala | 3 ++- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala index 0b2ed69e07..607633c6df 100644 --- a/core/src/main/scala/spark/storage/BlockMessage.scala +++ b/core/src/main/scala/spark/storage/BlockMessage.scala @@ -12,7 +12,7 @@ case class GetBlock(id: String) case class GotBlock(id: String, data: ByteBuffer) case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel) -class BlockMessage() extends Logging{ +class BlockMessage() { // Un-initialized: typ = 0 // GetBlock: typ = 1 // GotBlock: typ = 2 @@ -22,8 +22,6 @@ class BlockMessage() extends Logging{ private var data: ByteBuffer = null private var level: StorageLevel = null - initLogging() - def set(getBlock: GetBlock) { typ = BlockMessage.TYPE_GET_BLOCK id = getBlock.id @@ -62,8 +60,6 @@ class BlockMessage() extends Logging{ } id = idBuilder.toString() - logDebug("Set from buffer Result: " + typ + " " + id) - logDebug("Buffer position is " + buffer.position) if (typ == BlockMessage.TYPE_PUT_BLOCK) { val booleanInt = buffer.getInt() @@ -77,23 +73,18 @@ class BlockMessage() extends Logging{ } data.put(buffer) data.flip() - logDebug("Set from buffer Result 2: " + level + " " + data) } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { val dataLength = buffer.getInt() - logDebug("Data length is "+ dataLength) - logDebug("Buffer position is " + buffer.position) data = ByteBuffer.allocate(dataLength) if (dataLength != buffer.remaining) { throw new Exception("Error parsing buffer") } data.put(buffer) data.flip() - logDebug("Set from buffer Result 3: " + data) } val finishTime = System.currentTimeMillis - logDebug("Converted " + id + " from bytebuffer in " + (finishTime - startTime) / 1000.0 + " s") } def set(bufferMsg: BufferMessage) { @@ -145,8 +136,6 @@ class BlockMessage() extends Logging{ buffers += data } - logDebug("Start to log buffers.") - buffers.foreach((x: ByteBuffer) => logDebug("" + x)) /* println() println("BlockMessage: ") @@ -160,7 +149,6 @@ class BlockMessage() extends Logging{ println() */ val finishTime = System.currentTimeMillis - logDebug("Converted " + id + " to buffer message in " + (finishTime - startTime) / 1000.0 + " s") return Message.createBufferMessage(buffers) } diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index fd64e224d7..330bb42e59 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -31,7 +31,8 @@ object AkkaUtils { akka.remote.netty.hostname = "%s" akka.remote.netty.port = %d akka.remote.netty.connection-timeout = 1s - akka.remote.netty.execution-pool-size = 10 + akka.remote.netty.execution-pool-size = 4 + akka.actor.default-dispatcher.throughput = 20 """.format(host, port)) val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader) -- cgit v1.2.3 From dc68febdce53efea43ee1ab91c05b14b7a5eae30 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 23:06:59 +0000 Subject: User Spark's closure serializer for the ShuffleMapTask cache --- core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 9 ++++----- .../main/scala/spark/scheduler/cluster/ClusterScheduler.scala | 1 + 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 73479bff01..f1eae9bc88 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -26,7 +26,8 @@ object ShuffleMapTask { return old } else { val out = new ByteArrayOutputStream - val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) + val ser = SparkEnv.get.closureSerializer.newInstance + val objOut = ser.serializeStream(new GZIPOutputStream(out)) objOut.writeObject(rdd) objOut.writeObject(dep) objOut.close() @@ -45,10 +46,8 @@ object ShuffleMapTask { } else { val loader = Thread.currentThread.getContextClassLoader val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val objIn = new ObjectInputStream(in) { - override def resolveClass(desc: ObjectStreamClass) = - Class.forName(desc.getName, false, loader) - } + val ser = SparkEnv.get.closureSerializer.newInstance + val objIn = ser.deserializeStream(in) val rdd = objIn.readObject().asInstanceOf[RDD[_]] val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_,_]] val tuple = (rdd, dep) diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 5b59479682..20c82ad0fa 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -115,6 +115,7 @@ class ClusterScheduler(sc: SparkContext) */ def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = { synchronized { + SparkEnv.set(sc.env) // Mark each slave as alive and remember its hostname for (o <- offers) { slaveIdToHost(o.slaveId) = o.hostname -- cgit v1.2.3 From 215544820fe70274d9dce1410f61e2052b8bc406 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Sep 2012 23:54:04 +0000 Subject: Serialize map output locations more efficiently, and only once, in MapOutputTracker --- core/src/main/scala/spark/MapOutputTracker.scala | 88 +++++++++++++++++++++--- 1 file changed, 80 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index de23eb6f48..cee2391c71 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -1,5 +1,6 @@ package spark +import java.io.{DataInputStream, DataOutputStream, ByteArrayOutputStream, ByteArrayInputStream} import java.util.concurrent.ConcurrentHashMap import akka.actor._ @@ -10,6 +11,7 @@ import akka.util.Duration import akka.util.Timeout import akka.util.duration._ +import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet import spark.storage.BlockManagerId @@ -18,12 +20,11 @@ sealed trait MapOutputTrackerMessage case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage case object StopMapOutputTracker extends MapOutputTrackerMessage -class MapOutputTrackerActor(bmAddresses: ConcurrentHashMap[Int, Array[BlockManagerId]]) -extends Actor with Logging { +class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging { def receive = { case GetMapOutputLocations(shuffleId: Int) => logInfo("Asked to get map output locations for shuffle " + shuffleId) - sender ! bmAddresses.get(shuffleId) + sender ! tracker.getSerializedLocations(shuffleId) case StopMapOutputTracker => logInfo("MapOutputTrackerActor stopped!") @@ -39,15 +40,19 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg val timeout = 10.seconds - private var bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]] + var bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]] // Incremented every time a fetch fails so that client nodes know to clear // their cache of map output locations if this happens. private var generation: Long = 0 private var generationLock = new java.lang.Object + // Cache a serialized version of the output locations for each shuffle to send them out faster + var cacheGeneration = generation + val cachedSerializedLocs = new HashMap[Int, Array[Byte]] + var trackerActor: ActorRef = if (isMaster) { - val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(bmAddresses)), name = actorName) + val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName) logInfo("Registered MapOutputTrackerActor actor") actor } else { @@ -134,15 +139,16 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg } // We won the race to fetch the output locs; do so logInfo("Doing the fetch; tracker actor = " + trackerActor) - val fetched = askTracker(GetMapOutputLocations(shuffleId)).asInstanceOf[Array[BlockManagerId]] + val fetchedBytes = askTracker(GetMapOutputLocations(shuffleId)).asInstanceOf[Array[Byte]] + val fetchedLocs = deserializeLocations(fetchedBytes) logInfo("Got the output locations") - bmAddresses.put(shuffleId, fetched) + bmAddresses.put(shuffleId, fetchedLocs) fetching.synchronized { fetching -= shuffleId fetching.notifyAll() } - return fetched + return fetchedLocs } else { return locs } @@ -181,4 +187,70 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg } } } + + def getSerializedLocations(shuffleId: Int): Array[Byte] = { + var locs: Array[BlockManagerId] = null + var generationGotten: Long = -1 + generationLock.synchronized { + if (generation > cacheGeneration) { + cachedSerializedLocs.clear() + cacheGeneration = generation + } + cachedSerializedLocs.get(shuffleId) match { + case Some(bytes) => + return bytes + case None => + locs = bmAddresses.get(shuffleId) + generationGotten = generation + } + } + // If we got here, we failed to find the serialized locations in the cache, so we pulled + // out a snapshot of the locations as "locs"; let's serialize and return that + val bytes = serializeLocations(locs) + // Add them into the table only if the generation hasn't changed while we were working + generationLock.synchronized { + if (generation == generationGotten) { + cachedSerializedLocs(shuffleId) = bytes + } + } + return bytes + } + + // Serialize an array of map output locations into an efficient byte format so that we can send + // it to reduce tasks. We do this by grouping together the locations by block manager ID. + def serializeLocations(locs: Array[BlockManagerId]): Array[Byte] = { + val out = new ByteArrayOutputStream + val dataOut = new DataOutputStream(out) + dataOut.writeInt(locs.length) + val grouped = locs.zipWithIndex.groupBy(_._1) + dataOut.writeInt(grouped.size) + for ((id, pairs) <- grouped) { + dataOut.writeUTF(id.ip) + dataOut.writeInt(id.port) + dataOut.writeInt(pairs.length) + for ((_, blockIndex) <- pairs) { + dataOut.writeInt(blockIndex) + } + } + dataOut.close() + out.toByteArray + } + + // Opposite of serializeLocations. + def deserializeLocations(bytes: Array[Byte]): Array[BlockManagerId] = { + val dataIn = new DataInputStream(new ByteArrayInputStream(bytes)) + val length = dataIn.readInt() + val array = new Array[BlockManagerId](length) + val numGroups = dataIn.readInt() + for (i <- 0 until numGroups) { + val ip = dataIn.readUTF() + val port = dataIn.readInt() + val id = new BlockManagerId(ip, port) + val numBlocks = dataIn.readInt() + for (j <- 0 until numBlocks) { + array(dataIn.readInt()) = id + } + } + array + } } -- cgit v1.2.3 From 2fa6d999fd92cb7ce828278edcd09eecd1f458c1 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 6 Sep 2012 00:16:39 +0000 Subject: Tuning Akka more --- core/src/main/scala/spark/util/AkkaUtils.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index 330bb42e59..df4e23bfd6 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -31,8 +31,8 @@ object AkkaUtils { akka.remote.netty.hostname = "%s" akka.remote.netty.port = %d akka.remote.netty.connection-timeout = 1s - akka.remote.netty.execution-pool-size = 4 - akka.actor.default-dispatcher.throughput = 20 + akka.remote.netty.execution-pool-size = 8 + akka.actor.default-dispatcher.throughput = 30 """.format(host, port)) val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader) -- cgit v1.2.3 From 9ef90c95f4947e47f7c44f952ff8d294e0932a73 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 6 Sep 2012 00:43:46 +0000 Subject: Bug fix --- core/src/main/scala/spark/MapOutputTracker.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index cee2391c71..82c1391345 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -224,7 +224,7 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg dataOut.writeInt(locs.length) val grouped = locs.zipWithIndex.groupBy(_._1) dataOut.writeInt(grouped.size) - for ((id, pairs) <- grouped) { + for ((id, pairs) <- grouped if id != null) { dataOut.writeUTF(id.ip) dataOut.writeInt(id.port) dataOut.writeInt(pairs.length) -- cgit v1.2.3 From 019de4562c3c68ac36e6ab6a5577f5369336046b Mon Sep 17 00:00:00 2001 From: root Date: Thu, 6 Sep 2012 02:50:41 +0000 Subject: Less warmup in word count --- streaming/src/main/scala/spark/streaming/examples/WordCount2.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index 8561e7f079..c22949d7b9 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -84,7 +84,7 @@ object WordCount2 { val ssc = new StreamingContext(master, "WordCount2") ssc.setBatchDuration(batchDuration) - warmup(ssc.sc) + //warmup(ssc.sc) val data = ssc.sc.textFile(file, mapTasks.toInt).persist( new StorageLevel(false, true, false, 3)) // Memory only, serialized, 3 replicas -- cgit v1.2.3 From babb7e3ce2a5eda793f87b42839cc20d14cb94cf Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 6 Sep 2012 05:28:29 -0700 Subject: Re-implemented ReducedWindowedDSteam to simplify and fix bugs. Added slice operator to DStream. Also, refactored DStream testsuites and added tests for reduceByKeyAndWindow. --- .../src/main/scala/spark/streaming/DStream.scala | 20 +- .../src/main/scala/spark/streaming/Interval.scala | 4 + .../spark/streaming/ReducedWindowedDStream.scala | 221 +++++++-------------- .../scala/spark/streaming/DStreamBasicSuite.scala | 67 +++++++ .../test/scala/spark/streaming/DStreamSuite.scala | 123 ------------ .../scala/spark/streaming/DStreamSuiteBase.scala | 68 +++++++ .../scala/spark/streaming/DStreamWindowSuite.scala | 107 ++++++++++ 7 files changed, 334 insertions(+), 276 deletions(-) create mode 100644 streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala delete mode 100644 streaming/src/test/scala/spark/streaming/DStreamSuite.scala create mode 100644 streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala create mode 100644 streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 20f1c4db20..50b9458fae 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -143,7 +143,7 @@ extends Logging with Serializable { /** * This method generates a SparkStreaming job for the given time - * and may require to be overriden by subclasses + * and may required to be overriden by subclasses */ def generateJob(time: Time): Option[Job] = { getOrCompute(time) match { @@ -208,7 +208,7 @@ extends Logging with Serializable { new TransformedDStream(this, ssc.sc.clean(transformFunc)) } - def toQueue = { + def toBlockingQueue = { val queue = new ArrayBlockingQueue[RDD[T]](10000) this.foreachRDD(rdd => { queue.add(rdd) @@ -256,6 +256,22 @@ extends Logging with Serializable { def union(that: DStream[T]) = new UnifiedDStream(Array(this, that)) + def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = { + + val rdds = new ArrayBuffer[RDD[T]]() + var time = toTime.floor(slideTime) + + while (time >= zeroTime && time >= fromTime) { + getOrCompute(time) match { + case Some(rdd) => rdds += rdd + case None => throw new Exception("Could not get old reduced RDD for time " + time) + } + time -= slideTime + } + + rdds.toSeq + } + def register() { ssc.registerOutputStream(this) } diff --git a/streaming/src/main/scala/spark/streaming/Interval.scala b/streaming/src/main/scala/spark/streaming/Interval.scala index 87b8437b3d..ffb7725ac9 100644 --- a/streaming/src/main/scala/spark/streaming/Interval.scala +++ b/streaming/src/main/scala/spark/streaming/Interval.scala @@ -9,6 +9,10 @@ case class Interval(beginTime: Time, endTime: Time) { new Interval(beginTime + time, endTime + time) } + def - (time: Time): Interval = { + new Interval(beginTime - time, endTime - time) + } + def < (that: Interval): Boolean = { if (this.duration != that.duration) { throw new Exception("Comparing two intervals with different durations [" + this + ", " + that + "]") diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala index 191d264b2b..b0beaba94d 100644 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -12,7 +12,7 @@ import spark.storage.StorageLevel import scala.collection.mutable.ArrayBuffer class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( - parent: DStream[(K, V)], + @transient parent: DStream[(K, V)], reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, _windowTime: Time, @@ -28,9 +28,7 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( throw new Exception("The slide duration of ReducedWindowedDStream (" + _slideTime + ") " + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") - val reducedStream = parent.reduceByKey(reduceFunc, partitioner) - val allowPartialWindows = true - //reducedStream.persist(StorageLevel.MEMORY_ONLY_DESER_2) + @transient val reducedStream = parent.reduceByKey(reduceFunc, partitioner) override def dependencies = List(reducedStream) @@ -44,174 +42,95 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( checkpointInterval: Time): DStream[(K,V)] = { super.persist(storageLevel, checkpointLevel, checkpointInterval) reducedStream.persist(storageLevel, checkpointLevel, checkpointInterval) + this } - + override def compute(validTime: Time): Option[RDD[(K, V)]] = { - - // Notation: + val currentTime = validTime + val currentWindow = Interval(currentTime - windowTime + parent.slideTime, currentTime) + val previousWindow = currentWindow - slideTime + + logDebug("Window time = " + windowTime) + logDebug("Slide time = " + slideTime) + logDebug("ZeroTime = " + zeroTime) + logDebug("Current window = " + currentWindow) + logDebug("Previous window = " + previousWindow) + // _____________________________ - // | previous window _________|___________________ - // |___________________| current window | --------------> Time + // | previous window _________|___________________ + // |___________________| current window | --------------> Time // |_____________________________| - // + // // |________ _________| |________ _________| // | | // V V - // old time steps new time steps + // old RDDs new RDDs // - def getAdjustedWindow(endTime: Time, windowTime: Time): Interval = { - val beginTime = - if (allowPartialWindows && endTime - windowTime < parent.zeroTime) { - parent.zeroTime - } else { - endTime - windowTime - } - Interval(beginTime, endTime) - } - - val currentTime = validTime - val currentWindow = getAdjustedWindow(currentTime, windowTime) - val previousWindow = getAdjustedWindow(currentTime - slideTime, windowTime) - - logInfo("Current window = " + currentWindow) - logInfo("Slide time = " + slideTime) - logInfo("Previous window = " + previousWindow) - logInfo("Parent.zeroTime = " + parent.zeroTime) - - if (allowPartialWindows) { - if (currentTime - slideTime <= parent.zeroTime) { - reducedStream.getOrCompute(currentTime) match { - case Some(rdd) => return Some(rdd) - case None => throw new Exception("Could not get first reduced RDD for time " + currentTime) - } - } - } else { - if (previousWindow.beginTime < parent.zeroTime) { - if (currentWindow.beginTime < parent.zeroTime) { - return None - } else { - // If this is the first feasible window, then generate reduced value in the naive manner - val reducedRDDs = new ArrayBuffer[RDD[(K, V)]]() - var t = currentWindow.endTime - while (t > currentWindow.beginTime) { - reducedStream.getOrCompute(t) match { - case Some(rdd) => reducedRDDs += rdd - case None => throw new Exception("Could not get reduced RDD for time " + t) - } - t -= reducedStream.slideTime - } - if (reducedRDDs.size == 0) { - throw new Exception("Could not generate the first RDD for time " + validTime) - } - return Some(new UnionRDD(ssc.sc, reducedRDDs).reduceByKey(partitioner, reduceFunc)) - } - } - } - - // Get the RDD of the reduced value of the previous window - val previousWindowRDD = getOrCompute(previousWindow.endTime) match { - case Some(rdd) => rdd.asInstanceOf[RDD[(_, _)]] - case None => throw new Exception("Could not get previous RDD for time " + previousWindow.endTime) - } - val oldRDDs = new ArrayBuffer[RDD[(_, _)]]() - val newRDDs = new ArrayBuffer[RDD[(_, _)]]() - // Get the RDDs of the reduced values in "old time steps" - var t = currentWindow.beginTime - while (t > previousWindow.beginTime) { - reducedStream.getOrCompute(t) match { - case Some(rdd) => oldRDDs += rdd.asInstanceOf[RDD[(_, _)]] - case None => throw new Exception("Could not get old reduced RDD for time " + t) - } - t -= reducedStream.slideTime - } + val oldRDDs = reducedStream.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideTime) + logDebug("# old RDDs = " + oldRDDs.size) // Get the RDDs of the reduced values in "new time steps" - t = currentWindow.endTime - while (t > previousWindow.endTime) { - reducedStream.getOrCompute(t) match { - case Some(rdd) => newRDDs += rdd.asInstanceOf[RDD[(_, _)]] - case None => throw new Exception("Could not get new reduced RDD for time " + t) - } - t -= reducedStream.slideTime + val newRDDs = reducedStream.slice(previousWindow.endTime + parent.slideTime, currentWindow.endTime) + logDebug("# new RDDs = " + newRDDs.size) + + // Get the RDD of the reduced value of the previous window + val previousWindowRDD = getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]())) + + // Make the list of RDDs that needs to cogrouped together for reducing their reduced values + val allRDDs = new ArrayBuffer[RDD[(K, V)]]() += previousWindowRDD ++= oldRDDs ++= newRDDs + + // Cogroup the reduced RDDs and merge the reduced values + val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(_, _)]]], partitioner) + val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _ + val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValuesFunc) + + Some(mergedValuesRDD) + } + + def mergeValues(numOldValues: Int, numNewValues: Int)(seqOfValues: Seq[Seq[V]]): V = { + + if (seqOfValues.size != 1 + numOldValues + numNewValues) { + throw new Exception("Unexpected number of sequences of reduced values") } - val allRDDs = new ArrayBuffer[RDD[(_, _)]]() - allRDDs += previousWindowRDD - allRDDs ++= oldRDDs - allRDDs ++= newRDDs - - - val numOldRDDs = oldRDDs.size - val numNewRDDs = newRDDs.size - logInfo("Generated numOldRDDs = " + numOldRDDs + ", numNewRDDs = " + numNewRDDs) - logInfo("Generating CoGroupedRDD with " + allRDDs.size + " RDDs") - val newRDD = new CoGroupedRDD[K](allRDDs.toSeq, partitioner).asInstanceOf[RDD[(K,Seq[Seq[V]])]].map(x => { - val (key, value) = x - logDebug("value.size = " + value.size + ", numOldRDDs = " + numOldRDDs + ", numNewRDDs = " + numNewRDDs) - if (value.size != 1 + numOldRDDs + numNewRDDs) { - throw new Exception("Number of groups not odd!") - } + // Getting reduced values "old time steps" that will be removed from current window + val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head) + + // Getting reduced values "new time steps" + val newValues = (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) - // old values = reduced values of the "old time steps" that are eliminated from current window - // new values = reduced values of the "new time steps" that are introduced to the current window - // previous value = reduced value of the previous window - - /*val numOldValues = (value.size - 1) / 2*/ - // Getting reduced values "old time steps" - val oldValues = - (0 until numOldRDDs).map(i => value(1 + i)).filter(_.size > 0).map(x => x(0)) - // Getting reduced values "new time steps" - val newValues = - (0 until numNewRDDs).map(i => value(1 + numOldRDDs + i)).filter(_.size > 0).map(x => x(0)) - - // If reduced value for the key does not exist in previous window, it should not exist in "old time steps" - if (value(0).size == 0 && oldValues.size != 0) { - throw new Exception("Unexpected: Key exists in old reduced values but not in previous reduced values") + if (seqOfValues(0).isEmpty) { + + // If previous window's reduce value does not exist, then at least new values should exist + if (newValues.isEmpty) { + throw new Exception("Neither previous window has value for key, nor new values found") } - // For the key, at least one of "old time steps", "new time steps" and previous window should have reduced values - if (value(0).size == 0 && oldValues.size == 0 && newValues.size == 0) { - throw new Exception("Unexpected: Key does not exist in any of old, new, or previour reduced values") + // Reduce the new values + // println("new values = " + newValues.map(_.toString).reduce(_ + " " + _)) + return newValues.reduce(reduceFunc) + } else { + + // Get the previous window's reduced value + var tempValue = seqOfValues(0).head + + // If old values exists, then inverse reduce then from previous value + if (!oldValues.isEmpty) { + // println("old values = " + oldValues.map(_.toString).reduce(_ + " " + _)) + tempValue = invReduceFunc(tempValue, oldValues.reduce(reduceFunc)) } - // Logic to generate the final reduced value for current window: - // - // If previous window did not have reduced value for the key - // Then, return reduced value of "new time steps" as the final value - // Else, reduced value exists in previous window - // If "old" time steps did not have reduced value for the key - // Then, reduce previous window's reduced value with that of "new time steps" for final value - // Else, reduced values exists in "old time steps" - // If "new values" did not have reduced value for the key - // Then, inverse-reduce "old values" from previous window's reduced value for final value - // Else, all 3 values exist, combine all of them together - // - logDebug("# old values = " + oldValues.size + ", # new values = " + newValues) - val finalValue = { - if (value(0).size == 0) { - newValues.reduce(reduceFunc) - } else { - val prevValue = value(0)(0) - logDebug("prev value = " + prevValue) - if (oldValues.size == 0) { - // assuming newValue.size > 0 (all 3 cannot be zero, as checked earlier) - val temp = newValues.reduce(reduceFunc) - reduceFunc(prevValue, temp) - } else if (newValues.size == 0) { - invReduceFunc(prevValue, oldValues.reduce(reduceFunc)) - } else { - val tempValue = invReduceFunc(prevValue, oldValues.reduce(reduceFunc)) - reduceFunc(tempValue, newValues.reduce(reduceFunc)) - } - } + // If new values exists, then reduce them with previous value + if (!newValues.isEmpty) { + // println("new values = " + newValues.map(_.toString).reduce(_ + " " + _)) + tempValue = reduceFunc(tempValue, newValues.reduce(reduceFunc)) } - (key, finalValue) - }) - //newRDD.persist(StorageLevel.MEMORY_ONLY_DESER_2) - Some(newRDD) + // println("final value = " + tempValue) + return tempValue + } } } diff --git a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala new file mode 100644 index 0000000000..2634c9b405 --- /dev/null +++ b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala @@ -0,0 +1,67 @@ +package spark.streaming + +import spark.streaming.StreamingContext._ +import scala.runtime.RichInt + +class DStreamBasicSuite extends DStreamSuiteBase { + + test("map-like operations") { + val input = Seq(1 to 4, 5 to 8, 9 to 12) + + // map + testOperation(input, (r: DStream[Int]) => r.map(_.toString), input.map(_.map(_.toString))) + + // flatMap + testOperation( + input, + (r: DStream[Int]) => r.flatMap(x => Seq(x, x * 2)), + input.map(_.flatMap(x => Array(x, x * 2))) + ) + } + + test("shuffle-based operations") { + // reduceByKey + testOperation( + Seq(Seq("a", "a", "b"), Seq("", ""), Seq()), + (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), + Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + true + ) + + // reduce + testOperation( + Seq(1 to 4, 5 to 8, 9 to 12), + (s: DStream[Int]) => s.reduce(_ + _), + Seq(Seq(10), Seq(26), Seq(42)) + ) + } + + test("stateful operations") { + val inputData = + Seq( + Seq("a", "b", "c"), + Seq("a", "b", "c"), + Seq("a", "b", "c") + ) + + val outputData = + Seq( + Seq(("a", 1), ("b", 1), ("c", 1)), + Seq(("a", 2), ("b", 2), ("c", 2)), + Seq(("a", 3), ("b", 3), ("c", 3)) + ) + + val updateStateOp = (s: DStream[String]) => { + val updateFunc = (values: Seq[Int], state: RichInt) => { + var newState = 0 + if (values != null) newState += values.reduce(_ + _) + if (state != null) newState += state.self + println("values = " + values + ", state = " + state + ", " + " new state = " + newState) + new RichInt(newState) + } + s.map(x => (x, 1)).updateStateByKey[RichInt](updateFunc).map(t => (t._1, t._2.self)) + } + + testOperation(inputData, updateStateOp, outputData, true) + } +} diff --git a/streaming/src/test/scala/spark/streaming/DStreamSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamSuite.scala deleted file mode 100644 index fc00952afe..0000000000 --- a/streaming/src/test/scala/spark/streaming/DStreamSuite.scala +++ /dev/null @@ -1,123 +0,0 @@ -package spark.streaming - -import spark.Logging -import spark.streaming.StreamingContext._ -import spark.streaming.util.ManualClock - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter - -import scala.collection.mutable.ArrayBuffer -import scala.runtime.RichInt - -class DStreamSuite extends FunSuite with BeforeAndAfter with Logging { - - var ssc: StreamingContext = null - val batchDurationMillis = 1000 - - System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") - - def testOp[U: ClassManifest, V: ClassManifest]( - input: Seq[Seq[U]], - operation: DStream[U] => DStream[V], - expectedOutput: Seq[Seq[V]], - useSet: Boolean = false - ) { - try { - ssc = new StreamingContext("local", "test") - ssc.setBatchDuration(Milliseconds(batchDurationMillis)) - - val inputStream = ssc.createQueueStream(input.map(ssc.sc.makeRDD(_, 2)).toIterator) - val outputStream = operation(inputStream) - val outputQueue = outputStream.toQueue - - ssc.start() - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - clock.addToTime(input.size * batchDurationMillis) - - Thread.sleep(1000) - - val output = new ArrayBuffer[Seq[V]]() - while(outputQueue.size > 0) { - val rdd = outputQueue.take() - output += (rdd.collect()) - } - - assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { - if (useSet) { - assert(output(i).toSet === expectedOutput(i).toSet) - } else { - assert(output(i).toList === expectedOutput(i).toList) - } - } - } finally { - ssc.stop() - } - } - - test("map-like operations") { - val inputData = Seq(1 to 4, 5 to 8, 9 to 12) - - // map - testOp(inputData, (r: DStream[Int]) => r.map(_.toString), inputData.map(_.map(_.toString))) - - // flatMap - testOp( - inputData, - (r: DStream[Int]) => r.flatMap(x => Seq(x, x * 2)), - inputData.map(_.flatMap(x => Array(x, x * 2))) - ) - } - - test("shuffle-based operations") { - // reduceByKey - testOp( - Seq(Seq("a", "a", "b"), Seq("", ""), Seq()), - (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), - true - ) - - // reduce - testOp( - Seq(1 to 4, 5 to 8, 9 to 12), - (s: DStream[Int]) => s.reduce(_ + _), - Seq(Seq(10), Seq(26), Seq(42)) - ) - } - - test("window-based operations") { - - } - - - test("stateful operations") { - val inputData = - Seq( - Seq("a", "b", "c"), - Seq("a", "b", "c"), - Seq("a", "b", "c") - ) - - val outputData = - Seq( - Seq(("a", 1), ("b", 1), ("c", 1)), - Seq(("a", 2), ("b", 2), ("c", 2)), - Seq(("a", 3), ("b", 3), ("c", 3)) - ) - - val updateStateOp = (s: DStream[String]) => { - val updateFunc = (values: Seq[Int], state: RichInt) => { - var newState = 0 - if (values != null) newState += values.reduce(_ + _) - if (state != null) newState += state.self - println("values = " + values + ", state = " + state + ", " + " new state = " + newState) - new RichInt(newState) - } - s.map(x => (x, 1)).updateStateByKey[RichInt](updateFunc).map(t => (t._1, t._2.self)) - } - - testOp(inputData, updateStateOp, outputData, true) - } -} diff --git a/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala b/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala new file mode 100644 index 0000000000..1c4ea14b1d --- /dev/null +++ b/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala @@ -0,0 +1,68 @@ +package spark.streaming + +import spark.{RDD, Logging} +import util.ManualClock +import collection.mutable.ArrayBuffer +import org.scalatest.FunSuite +import scala.collection.mutable.Queue + + +trait DStreamSuiteBase extends FunSuite with Logging { + + def batchDuration() = Seconds(1) + + def maxWaitTimeMillis() = 10000 + + def testOperation[U: ClassManifest, V: ClassManifest]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]], + useSet: Boolean = false + ) { + + val manualClock = true + + if (manualClock) { + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + } + + val ssc = new StreamingContext("local", "test") + + try { + ssc.setBatchDuration(Milliseconds(batchDuration)) + + val inputQueue = new Queue[RDD[U]]() + inputQueue ++= input.map(ssc.sc.makeRDD(_, 2)) + val emptyRDD = ssc.sc.makeRDD(Seq[U](), 2) + + val inputStream = ssc.createQueueStream(inputQueue, true, emptyRDD) + val outputStream = operation(inputStream) + + val output = new ArrayBuffer[Seq[V]]() + outputStream.foreachRDD(rdd => output += rdd.collect()) + + ssc.start() + + val clock = ssc.scheduler.clock + if (clock.isInstanceOf[ManualClock]) { + clock.asInstanceOf[ManualClock].addToTime(input.size * batchDuration.milliseconds) + } + + val startTime = System.currentTimeMillis() + while (output.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + Thread.sleep(500) + } + + assert(output.size === expectedOutput.size) + for (i <- 0 until output.size) { + if (useSet) { + assert(output(i).toSet === expectedOutput(i).toSet) + } else { + assert(output(i).toList === expectedOutput(i).toList) + } + } + } finally { + ssc.stop() + } + } +} diff --git a/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala new file mode 100644 index 0000000000..c0e054418c --- /dev/null +++ b/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala @@ -0,0 +1,107 @@ +package spark.streaming + +import spark.streaming.StreamingContext._ + +class DStreamWindowSuite extends DStreamSuiteBase { + + def testReduceByKeyAndWindow( + name: String, + input: Seq[Seq[(String, Int)]], + expectedOutput: Seq[Seq[(String, Int)]], + windowTime: Time = Seconds(2), + slideTime: Time = Seconds(1) + ) { + test("reduceByKeyAndWindow - " + name) { + testOperation( + input, + (s: DStream[(String, Int)]) => s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist(), + expectedOutput, + true + ) + } + } + + testReduceByKeyAndWindow( + "basic reduction", + Seq(Seq(("a", 1), ("a", 3)) ), + Seq(Seq(("a", 4)) ) + ) + + testReduceByKeyAndWindow( + "key already in window and new value added into window", + Seq( Seq(("a", 1)), Seq(("a", 1)) ), + Seq( Seq(("a", 1)), Seq(("a", 2)) ) + ) + + testReduceByKeyAndWindow( + "new key added into window", + Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ), + Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)) ) + ) + + testReduceByKeyAndWindow( + "key removed from window", + Seq( Seq(("a", 1)), Seq(("a", 1)), Seq(), Seq() ), + Seq( Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 1)), Seq(("a", 0)) ) + ) + + val largerSlideInput = Seq( + Seq(("a", 1)), // 1st window from here + Seq(("a", 2)), + Seq(("a", 3)), // 2nd window from here + Seq(("a", 4)), + Seq(("a", 5)), // 3rd window from here + Seq(("a", 6)), + Seq(), // 4th window from here + Seq(), + Seq() // 5th window from here + ) + + val largerSlideOutput = Seq( + Seq(("a", 1)), + Seq(("a", 6)), + Seq(("a", 14)), + Seq(("a", 15)), + Seq(("a", 6)) + ) + + testReduceByKeyAndWindow( + "larger slide time", + largerSlideInput, + largerSlideOutput, + Seconds(4), + Seconds(2) + ) + + val bigInput = Seq( + Seq(("a", 1)), + Seq(("a", 1), ("b", 1)), + Seq(("a", 1), ("b", 1), ("c", 1)), + Seq(("a", 1), ("b", 1)), + Seq(("a", 1)), + Seq(), + Seq(("a", 1)), + Seq(("a", 1), ("b", 1)), + Seq(("a", 1), ("b", 1), ("c", 1)), + Seq(("a", 1), ("b", 1)), + Seq(("a", 1)), + Seq() + ) + + val bigOutput = Seq( + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 1), ("c", 0)), + Seq(("a", 1), ("b", 0), ("c", 0)), + Seq(("a", 1), ("b", 0), ("c", 0)), + Seq(("a", 2), ("b", 1), ("c", 0)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 1), ("c", 0)), + Seq(("a", 1), ("b", 0), ("c", 0)) + ) + + testReduceByKeyAndWindow("big test", bigInput, bigOutput) +} -- cgit v1.2.3 From 4a7bde6865cf22af060f20a9619c516b811c80f2 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 6 Sep 2012 19:06:59 -0700 Subject: Fixed bugs and added testcases for naive reduceByKeyAndWindow. --- .../src/main/scala/spark/streaming/DStream.scala | 6 + .../src/main/scala/spark/streaming/Scheduler.scala | 2 +- .../scala/spark/streaming/WindowedDStream.scala | 38 +---- .../scala/spark/streaming/DStreamBasicSuite.scala | 2 +- .../scala/spark/streaming/DStreamWindowSuite.scala | 179 +++++++++++++++------ 5 files changed, 140 insertions(+), 87 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 50b9458fae..3973ca1520 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -256,11 +256,17 @@ extends Logging with Serializable { def union(that: DStream[T]) = new UnifiedDStream(Array(this, that)) + def slice(interval: Interval): Seq[RDD[T]] = { + slice(interval.beginTime, interval.endTime) + } + + // Get all the RDDs between fromTime to toTime (both included) def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = { val rdds = new ArrayBuffer[RDD[T]]() var time = toTime.floor(slideTime) + while (time >= zeroTime && time >= fromTime) { getOrCompute(time) match { case Some(rdd) => rdds += rdd diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index 12e52bf56c..00136685d5 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -40,7 +40,7 @@ extends Logging { } def generateRDDs (time: Time) { - println("\n-----------------------------------------------------\n") + logInfo("\n-----------------------------------------------------\n") logInfo("Generating RDDs for time " + time) outputStreams.foreach(outputStream => { outputStream.generateJob(time) match { diff --git a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala index 6c791fcfc1..93c1291691 100644 --- a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala @@ -1,12 +1,8 @@ package spark.streaming -import spark.streaming.StreamingContext._ - import spark.RDD import spark.UnionRDD -import spark.SparkContext._ -import scala.collection.mutable.ArrayBuffer class WindowedDStream[T: ClassManifest]( parent: DStream[T], @@ -22,8 +18,6 @@ class WindowedDStream[T: ClassManifest]( throw new Exception("The slide duration of WindowedDStream (" + _slideTime + ") " + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") - val allowPartialWindows = true - override def dependencies = List(parent) def windowTime: Time = _windowTime @@ -31,36 +25,8 @@ class WindowedDStream[T: ClassManifest]( override def slideTime: Time = _slideTime override def compute(validTime: Time): Option[RDD[T]] = { - val parentRDDs = new ArrayBuffer[RDD[T]]() - val windowEndTime = validTime - val windowStartTime = if (allowPartialWindows && windowEndTime - windowTime < parent.zeroTime) { - parent.zeroTime - } else { - windowEndTime - windowTime - } - - logInfo("Window = " + windowStartTime + " - " + windowEndTime) - logInfo("Parent.zeroTime = " + parent.zeroTime) - - if (windowStartTime >= parent.zeroTime) { - // Walk back through time, from the 'windowEndTime' to 'windowStartTime' - // and get all parent RDDs from the parent DStream - var t = windowEndTime - while (t > windowStartTime) { - parent.getOrCompute(t) match { - case Some(rdd) => parentRDDs += rdd - case None => throw new Exception("Could not generate parent RDD for time " + t) - } - t -= parent.slideTime - } - } - - // Do a union of all parent RDDs to generate the new RDD - if (parentRDDs.size > 0) { - Some(new UnionRDD(ssc.sc, parentRDDs)) - } else { - None - } + val currentWindow = Interval(validTime - windowTime + parent.slideTime, validTime) + Some(new UnionRDD(ssc.sc, parent.slice(currentWindow))) } } diff --git a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala index 2634c9b405..9b953d9dae 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala @@ -56,7 +56,7 @@ class DStreamBasicSuite extends DStreamSuiteBase { var newState = 0 if (values != null) newState += values.reduce(_ + _) if (state != null) newState += state.self - println("values = " + values + ", state = " + state + ", " + " new state = " + newState) + //println("values = " + values + ", state = " + state + ", " + " new state = " + newState) new RichInt(newState) } s.map(x => (x, 1)).updateStateByKey[RichInt](updateFunc).map(t => (t._1, t._2.self)) diff --git a/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala index c0e054418c..061cab2cbb 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala @@ -4,47 +4,6 @@ import spark.streaming.StreamingContext._ class DStreamWindowSuite extends DStreamSuiteBase { - def testReduceByKeyAndWindow( - name: String, - input: Seq[Seq[(String, Int)]], - expectedOutput: Seq[Seq[(String, Int)]], - windowTime: Time = Seconds(2), - slideTime: Time = Seconds(1) - ) { - test("reduceByKeyAndWindow - " + name) { - testOperation( - input, - (s: DStream[(String, Int)]) => s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist(), - expectedOutput, - true - ) - } - } - - testReduceByKeyAndWindow( - "basic reduction", - Seq(Seq(("a", 1), ("a", 3)) ), - Seq(Seq(("a", 4)) ) - ) - - testReduceByKeyAndWindow( - "key already in window and new value added into window", - Seq( Seq(("a", 1)), Seq(("a", 1)) ), - Seq( Seq(("a", 1)), Seq(("a", 2)) ) - ) - - testReduceByKeyAndWindow( - "new key added into window", - Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ), - Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)) ) - ) - - testReduceByKeyAndWindow( - "key removed from window", - Seq( Seq(("a", 1)), Seq(("a", 1)), Seq(), Seq() ), - Seq( Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 1)), Seq(("a", 0)) ) - ) - val largerSlideInput = Seq( Seq(("a", 1)), // 1st window from here Seq(("a", 2)), @@ -65,14 +24,6 @@ class DStreamWindowSuite extends DStreamSuiteBase { Seq(("a", 6)) ) - testReduceByKeyAndWindow( - "larger slide time", - largerSlideInput, - largerSlideOutput, - Seconds(4), - Seconds(2) - ) - val bigInput = Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)), @@ -89,6 +40,29 @@ class DStreamWindowSuite extends DStreamSuiteBase { ) val bigOutput = Seq( + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 1)), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 1)) + ) + + /* + The output of the reduceByKeyAndWindow with inverse reduce function is + difference from the naive reduceByKeyAndWindow. Even if the count of a + particular key is 0, the key does not get eliminated from the RDDs of + ReducedWindowedDStream. This causes the number of keys in these RDDs to + increase forever. A more generalized version that allows elimination of + keys should be considered. + */ + val bigOutputInv = Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)), Seq(("a", 2), ("b", 2), ("c", 1)), @@ -103,5 +77,112 @@ class DStreamWindowSuite extends DStreamSuiteBase { Seq(("a", 1), ("b", 0), ("c", 0)) ) + def testReduceByKeyAndWindow( + name: String, + input: Seq[Seq[(String, Int)]], + expectedOutput: Seq[Seq[(String, Int)]], + windowTime: Time = Seconds(2), + slideTime: Time = Seconds(1) + ) { + test("reduceByKeyAndWindow - " + name) { + testOperation( + input, + (s: DStream[(String, Int)]) => s.reduceByKeyAndWindow(_ + _, windowTime, slideTime).persist(), + expectedOutput, + true + ) + } + } + + def testReduceByKeyAndWindowInv( + name: String, + input: Seq[Seq[(String, Int)]], + expectedOutput: Seq[Seq[(String, Int)]], + windowTime: Time = Seconds(2), + slideTime: Time = Seconds(1) + ) { + test("reduceByKeyAndWindowInv - " + name) { + testOperation( + input, + (s: DStream[(String, Int)]) => s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist(), + expectedOutput, + true + ) + } + } + + + // Testing naive reduceByKeyAndWindow (without invertible function) + + testReduceByKeyAndWindow( + "basic reduction", + Seq(Seq(("a", 1), ("a", 3)) ), + Seq(Seq(("a", 4)) ) + ) + + testReduceByKeyAndWindow( + "key already in window and new value added into window", + Seq( Seq(("a", 1)), Seq(("a", 1)) ), + Seq( Seq(("a", 1)), Seq(("a", 2)) ) + ) + + + testReduceByKeyAndWindow( + "new key added into window", + Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ), + Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)) ) + ) + + testReduceByKeyAndWindow( + "key removed from window", + Seq( Seq(("a", 1)), Seq(("a", 1)), Seq(), Seq() ), + Seq( Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 1)), Seq() ) + ) + + testReduceByKeyAndWindow( + "larger slide time", + largerSlideInput, + largerSlideOutput, + Seconds(4), + Seconds(2) + ) + testReduceByKeyAndWindow("big test", bigInput, bigOutput) + + + // Testing reduceByKeyAndWindow (with invertible reduce function) + + testReduceByKeyAndWindowInv( + "basic reduction", + Seq(Seq(("a", 1), ("a", 3)) ), + Seq(Seq(("a", 4)) ) + ) + + testReduceByKeyAndWindowInv( + "key already in window and new value added into window", + Seq( Seq(("a", 1)), Seq(("a", 1)) ), + Seq( Seq(("a", 1)), Seq(("a", 2)) ) + ) + + testReduceByKeyAndWindowInv( + "new key added into window", + Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ), + Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)) ) + ) + + testReduceByKeyAndWindowInv( + "key removed from window", + Seq( Seq(("a", 1)), Seq(("a", 1)), Seq(), Seq() ), + Seq( Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 1)), Seq(("a", 0)) ) + ) + + testReduceByKeyAndWindowInv( + "larger slide time", + largerSlideInput, + largerSlideOutput, + Seconds(4), + Seconds(2) + ) + + testReduceByKeyAndWindowInv("big test", bigInput, bigOutputInv) } -- cgit v1.2.3 From db08a362aae68682f9105f9e5568bc9b9d9faaab Mon Sep 17 00:00:00 2001 From: haoyuan Date: Fri, 7 Sep 2012 02:17:52 +0000 Subject: commit opt for grep scalibility test. --- .../main/scala/spark/storage/BlockManager.scala | 7 +++- .../spark/streaming/NetworkInputTracker.scala | 40 ++++++++++++---------- .../scala/spark/streaming/RawInputDStream.scala | 17 +++++---- .../spark/streaming/examples/WordCountRaw.scala | 19 +++++++--- 4 files changed, 51 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index f2d9499bad..4cdb9710ec 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -509,10 +509,15 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m * Replicate block to another node. */ + var firstTime = true + var peers : Seq[BlockManagerId] = null private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) { val tLevel: StorageLevel = new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) - var peers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) + if (firstTime) { + peers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) + firstTime = false; + } for (peer: BlockManagerId <- peers) { val start = System.nanoTime data.rewind() diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index acf97c1883..9f9001e4d5 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -4,6 +4,7 @@ import spark.Logging import spark.SparkEnv import scala.collection.mutable.HashMap +import scala.collection.mutable.Queue import akka.actor._ import akka.pattern.ask @@ -28,6 +29,17 @@ extends Logging { logInfo("Registered receiver for network stream " + streamId) sender ! true } + case GotBlockIds(streamId, blockIds) => { + val tmp = receivedBlockIds.synchronized { + if (!receivedBlockIds.contains(streamId)) { + receivedBlockIds += ((streamId, new Queue[String])) + } + receivedBlockIds(streamId) + } + tmp.synchronized { + tmp ++= blockIds + } + } } } @@ -69,8 +81,8 @@ extends Logging { val networkInputStreamIds = networkInputStreams.map(_.id).toArray val receiverExecutor = new ReceiverExecutor() val receiverInfo = new HashMap[Int, ActorRef] - val receivedBlockIds = new HashMap[Int, Array[String]] - val timeout = 1000.milliseconds + val receivedBlockIds = new HashMap[Int, Queue[String]] + val timeout = 5000.milliseconds var currentTime: Time = null @@ -86,22 +98,12 @@ extends Logging { } def getBlockIds(receiverId: Int, time: Time): Array[String] = synchronized { - if (currentTime == null || time > currentTime) { - logInfo("Getting block ids from receivers for " + time) - implicit val ec = ssc.env.actorSystem.dispatcher - receivedBlockIds.clear() - val message = new GetBlockIds(time) - val listOfFutures = receiverInfo.values.map( - _.ask(message)(timeout).mapTo[GotBlockIds] - ).toList - val futureOfList = Future.sequence(listOfFutures) - val allBlockIds = Await.result(futureOfList, timeout) - receivedBlockIds ++= allBlockIds.map(x => (x.streamId, x.blocksIds)) - if (receivedBlockIds.size != receiverInfo.size) { - throw new Exception("Unexpected number of the Block IDs received") - } - currentTime = time + val queue = receivedBlockIds.synchronized { + receivedBlockIds.getOrElse(receiverId, new Queue[String]()) + } + val result = queue.synchronized { + queue.dequeueAll(x => true) } - receivedBlockIds.getOrElse(receiverId, Array[String]()) + result.toArray } -} \ No newline at end of file +} diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala index d59c245a23..d29aea7886 100644 --- a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala @@ -86,14 +86,15 @@ class RawInputDStream[T: ClassManifest]( private class ReceiverActor(env: SparkEnv, receivingThread: Thread) extends Actor { val newBlocks = new ArrayBuffer[String] + logInfo("Attempting to register with tracker") + val ip = System.getProperty("spark.master.host", "localhost") + val port = System.getProperty("spark.master.port", "7077").toInt + val actorName: String = "NetworkInputTracker" + val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) + val trackerActor = env.actorSystem.actorFor(url) + val timeout = 5.seconds + override def preStart() { - logInfo("Attempting to register with tracker") - val ip = System.getProperty("spark.master.host", "localhost") - val port = System.getProperty("spark.master.port", "7077").toInt - val actorName: String = "NetworkInputTracker" - val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) - val trackerActor = env.actorSystem.actorFor(url) - val timeout = 1.seconds val future = trackerActor.ask(RegisterReceiver(streamId, self))(timeout) Await.result(future, timeout) } @@ -101,6 +102,7 @@ class RawInputDStream[T: ClassManifest]( override def receive = { case BlockPublished(blockId) => newBlocks += blockId + val future = trackerActor ! GotBlockIds(streamId, Array(blockId)) case GetBlockIds(time) => logInfo("Got request for block IDs for " + time) @@ -111,5 +113,6 @@ class RawInputDStream[T: ClassManifest]( receivingThread.interrupt() sender ! true } + } } diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala index 298d9ef381..9702003805 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -1,11 +1,24 @@ package spark.streaming.examples import spark.util.IntParam +import spark.SparkContext +import spark.SparkContext._ import spark.storage.StorageLevel import spark.streaming._ import spark.streaming.StreamingContext._ +import WordCount2_ExtraFunctions._ + object WordCountRaw { + def moreWarmup(sc: SparkContext) { + (0 until 40).foreach {i => + sc.parallelize(1 to 20000000, 1000) + .map(_ % 1331).map(_.toString) + .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) + .collect() + } + } + def main(args: Array[String]) { if (args.length != 7) { System.err.println("Usage: WordCountRaw ") @@ -20,16 +33,12 @@ object WordCountRaw { ssc.setBatchDuration(Milliseconds(batchMs)) // Make sure some tasks have started on each node - ssc.sc.parallelize(1 to 1000, 1000).count() - ssc.sc.parallelize(1 to 1000, 1000).count() - ssc.sc.parallelize(1 to 1000, 1000).count() + moreWarmup(ssc.sc) val rawStreams = (1 to streams).map(_ => ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray val union = new UnifiedDStream(rawStreams) - import WordCount2_ExtraFunctions._ - val windowedCounts = union.mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, -- cgit v1.2.3 From 381e2c7ac4cba952de2bf8b0090ff0799829cf30 Mon Sep 17 00:00:00 2001 From: haoyuan Date: Thu, 6 Sep 2012 20:54:52 -0700 Subject: add warmup code for TopKWordCountRaw.scala --- .../spark/streaming/examples/TopKWordCountRaw.scala | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala index be3188c5ed..3ba07d0448 100644 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala @@ -1,11 +1,24 @@ package spark.streaming.examples import spark.util.IntParam +import spark.SparkContext +import spark.SparkContext._ import spark.storage.StorageLevel import spark.streaming._ import spark.streaming.StreamingContext._ +import WordCount2_ExtraFunctions._ + object TopKWordCountRaw { + def moreWarmup(sc: SparkContext) { + (0 until 40).foreach {i => + sc.parallelize(1 to 20000000, 1000) + .map(_ % 1331).map(_.toString) + .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) + .collect() + } + } + def main(args: Array[String]) { if (args.length != 7) { System.err.println("Usage: TopKWordCountRaw ") @@ -20,16 +33,12 @@ object TopKWordCountRaw { ssc.setBatchDuration(Milliseconds(batchMs)) // Make sure some tasks have started on each node - ssc.sc.parallelize(1 to 1000, 1000).count() - ssc.sc.parallelize(1 to 1000, 1000).count() - ssc.sc.parallelize(1 to 1000, 1000).count() + moreWarmup(ssc.sc) val rawStreams = (1 to streams).map(_ => ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray val union = new UnifiedDStream(rawStreams) - import WordCount2_ExtraFunctions._ - val windowedCounts = union.mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, -- cgit v1.2.3 From b5750726ff3306e0ea5741141f6cae0eb3449902 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 7 Sep 2012 20:16:21 +0000 Subject: Fixed bugs in streaming Scheduler and optimized QueueInputDStream. --- core/src/main/scala/spark/RDD.scala | 2 +- streaming/src/main/scala/spark/streaming/QueueInputDStream.scala | 8 ++++++-- streaming/src/main/scala/spark/streaming/Scheduler.scala | 2 +- streaming/src/main/scala/spark/streaming/StateDStream.scala | 2 +- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 3fe8e8a4bf..d28f3593fe 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -94,7 +94,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial def getStorageLevel = storageLevel - def checkpoint(level: StorageLevel = StorageLevel.DISK_AND_MEMORY_DESER): RDD[T] = { + def checkpoint(level: StorageLevel = StorageLevel.DISK_AND_MEMORY_DESER_2): RDD[T] = { if (!level.useDisk && level.replication < 2) { throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")") } diff --git a/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala b/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala index bab48ff954..f6b53fe2f2 100644 --- a/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala @@ -25,7 +25,11 @@ class QueueInputDStream[T: ClassManifest]( buffer ++= queue } if (buffer.size > 0) { - Some(new UnionRDD(ssc.sc, buffer.toSeq)) + if (oneAtATime) { + Some(buffer.first) + } else { + Some(new UnionRDD(ssc.sc, buffer.toSeq)) + } } else if (defaultRDD != null) { Some(defaultRDD) } else { @@ -33,4 +37,4 @@ class QueueInputDStream[T: ClassManifest]( } } -} \ No newline at end of file +} diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index 12e52bf56c..b4b8e34ec8 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -26,7 +26,6 @@ extends Logging { val timer = new RecurringTimer(clock, ssc.batchDuration, generateRDDs(_)) def start() { - val zeroTime = Time(timer.start()) outputStreams.foreach(_.initialize(zeroTime)) inputStreams.par.foreach(_.start()) @@ -41,6 +40,7 @@ extends Logging { def generateRDDs (time: Time) { println("\n-----------------------------------------------------\n") + SparkEnv.set(ssc.env) logInfo("Generating RDDs for time " + time) outputStreams.foreach(outputStream => { outputStream.generateJob(time) match { diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala index f313d8c162..9d3561b4a0 100644 --- a/streaming/src/main/scala/spark/streaming/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala @@ -7,7 +7,7 @@ import spark.SparkContext._ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( - parent: DStream[(K, V)], + @transient parent: DStream[(K, V)], updateFunc: (Iterator[(K, Seq[V], S)]) => Iterator[(K, S)], partitioner: Partitioner, rememberPartitioner: Boolean -- cgit v1.2.3 From c63a6064584ea19d62e0abcbd3886d7b1e429ea1 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 10 Sep 2012 19:51:27 +0000 Subject: Made NewHadoopRDD broadcast its job configuration (same as HadoopRDD). --- core/src/main/scala/spark/NewHadoopRDD.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/spark/NewHadoopRDD.scala b/core/src/main/scala/spark/NewHadoopRDD.scala index d024d38aa9..14f708a3f8 100644 --- a/core/src/main/scala/spark/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/NewHadoopRDD.scala @@ -28,7 +28,9 @@ class NewHadoopRDD[K, V]( @transient conf: Configuration) extends RDD[(K, V)](sc) { - private val serializableConf = new SerializableWritable(conf) + // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it + val confBroadcast = sc.broadcast(new SerializableWritable(conf)) + // private val serializableConf = new SerializableWritable(conf) private val jobtrackerId: String = { val formatter = new SimpleDateFormat("yyyyMMddHHmm") @@ -41,7 +43,7 @@ class NewHadoopRDD[K, V]( @transient private val splits_ : Array[Split] = { val inputFormat = inputFormatClass.newInstance - val jobContext = new JobContext(serializableConf.value, jobId) + val jobContext = new JobContext(conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Split](rawSplits.size) for (i <- 0 until rawSplits.size) { @@ -54,9 +56,9 @@ class NewHadoopRDD[K, V]( override def compute(theSplit: Split) = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[NewHadoopSplit] - val conf = serializableConf.value + val conf = confBroadcast.value.value val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0) - val context = new TaskAttemptContext(serializableConf.value, attemptId) + val context = new TaskAttemptContext(conf, attemptId) val format = inputFormatClass.newInstance val reader = format.createRecordReader(split.serializableHadoopSplit.value, context) reader.initialize(split.serializableHadoopSplit.value, context) -- cgit v1.2.3 From 3cbc72ff1dc660a835c032356ba7b57883c5df5e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 14 Sep 2012 07:00:30 +0000 Subject: Minor tweaks --- .../src/main/scala/spark/streaming/DStream.scala | 42 +++++++++++----------- .../main/scala/spark/streaming/StateDStream.scala | 4 +-- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 3973ca1520..7e8098c346 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -284,9 +284,8 @@ extends Logging with Serializable { } -abstract class InputDStream[T: ClassManifest] ( - @transient ssc: StreamingContext) -extends DStream[T](ssc) { +abstract class InputDStream[T: ClassManifest] (@transient ssc: StreamingContext) + extends DStream[T](ssc) { override def dependencies = List() @@ -303,9 +302,9 @@ extends DStream[T](ssc) { */ class MappedDStream[T: ClassManifest, U: ClassManifest] ( - parent: DStream[T], - mapFunc: T => U) -extends DStream[U](parent.ssc) { + @transient parent: DStream[T], + mapFunc: T => U + ) extends DStream[U](parent.ssc) { override def dependencies = List(parent) @@ -322,9 +321,9 @@ extends DStream[U](parent.ssc) { */ class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( - parent: DStream[T], - flatMapFunc: T => Traversable[U]) -extends DStream[U](parent.ssc) { + @transient parent: DStream[T], + flatMapFunc: T => Traversable[U] + ) extends DStream[U](parent.ssc) { override def dependencies = List(parent) @@ -340,8 +339,10 @@ extends DStream[U](parent.ssc) { * TODO */ -class FilteredDStream[T: ClassManifest](parent: DStream[T], filterFunc: T => Boolean) -extends DStream[T](parent.ssc) { +class FilteredDStream[T: ClassManifest]( + @transient parent: DStream[T], + filterFunc: T => Boolean + ) extends DStream[T](parent.ssc) { override def dependencies = List(parent) @@ -358,9 +359,9 @@ extends DStream[T](parent.ssc) { */ class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( - parent: DStream[T], - mapPartFunc: Iterator[T] => Iterator[U]) -extends DStream[U](parent.ssc) { + @transient parent: DStream[T], + mapPartFunc: Iterator[T] => Iterator[U] + ) extends DStream[U](parent.ssc) { override def dependencies = List(parent) @@ -376,7 +377,8 @@ extends DStream[U](parent.ssc) { * TODO */ -class GlommedDStream[T: ClassManifest](parent: DStream[T]) extends DStream[Array[T]](parent.ssc) { +class GlommedDStream[T: ClassManifest](@transient parent: DStream[T]) + extends DStream[Array[T]](parent.ssc) { override def dependencies = List(parent) @@ -393,7 +395,7 @@ class GlommedDStream[T: ClassManifest](parent: DStream[T]) extends DStream[Array */ class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( - parent: DStream[(K,V)], + @transient parent: DStream[(K,V)], createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiner: (C, C) => C, @@ -418,7 +420,7 @@ class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( * TODO */ -class UnifiedDStream[T: ClassManifest](parents: Array[DStream[T]]) +class UnifiedDStream[T: ClassManifest](@transient parents: Array[DStream[T]]) extends DStream[T](parents(0).ssc) { if (parents.length == 0) { @@ -457,7 +459,7 @@ class UnifiedDStream[T: ClassManifest](parents: Array[DStream[T]]) */ class PerElementForEachDStream[T: ClassManifest] ( - parent: DStream[T], + @transient parent: DStream[T], foreachFunc: T => Unit ) extends DStream[Unit](parent.ssc) { @@ -488,7 +490,7 @@ class PerElementForEachDStream[T: ClassManifest] ( */ class PerRDDForEachDStream[T: ClassManifest] ( - parent: DStream[T], + @transient parent: DStream[T], foreachFunc: (RDD[T], Time) => Unit ) extends DStream[Unit](parent.ssc) { @@ -516,7 +518,7 @@ class PerRDDForEachDStream[T: ClassManifest] ( */ class TransformedDStream[T: ClassManifest, U: ClassManifest] ( - parent: DStream[T], + @transient parent: DStream[T], transformFunc: (RDD[T], Time) => RDD[U] ) extends DStream[U](parent.ssc) { diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala index 72b71d5fab..c40f70c91d 100644 --- a/streaming/src/main/scala/spark/streaming/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala @@ -26,14 +26,14 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife override def getOrCompute(time: Time): Option[RDD[(K, S)]] = { generatedRDDs.get(time) match { case Some(oldRDD) => { - if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval) && oldRDD.dependencies.size > 0) { + if (checkpointInterval != null && time > zeroTime && (time - zeroTime).isMultipleOf(checkpointInterval) && oldRDD.dependencies.size > 0) { val r = oldRDD val oldRDDBlockIds = oldRDD.splits.map(s => "rdd:" + r.id + ":" + s.index) val checkpointedRDD = new BlockRDD[(K, S)](ssc.sc, oldRDDBlockIds) { override val partitioner = oldRDD.partitioner } generatedRDDs.update(time, checkpointedRDD) - logInfo("Updated RDD of time " + time + " with its checkpointed version") + logInfo("Checkpointed RDD " + oldRDD.id + " of time " + time + " with its new RDD " + checkpointedRDD.id) Some(checkpointedRDD) } else { Some(oldRDD) -- cgit v1.2.3 From 9abdfa663360252d2edb346e6b3df4ff94ce78d7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 17 Sep 2012 00:08:50 -0700 Subject: Fix Python 2.6 compatibility in Python API. --- pyspark/pyspark/rdd.py | 17 +++++++++++------ python/tc.py | 22 ---------------------- 2 files changed, 11 insertions(+), 28 deletions(-) delete mode 100644 python/tc.py diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 8477f6dd02..e2137fe06c 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -1,5 +1,5 @@ from base64 import standard_b64encode as b64enc -from collections import Counter +from collections import defaultdict from itertools import chain, ifilter, imap import shlex from subprocess import Popen, PIPE @@ -198,13 +198,18 @@ class RDD(object): def countByValue(self): """ - >>> sc.parallelize([1, 2, 1, 2, 2]).countByValue().most_common() - [(2, 3), (1, 2)] + >>> sorted(sc.parallelize([1, 2, 1, 2, 2], 2).countByValue().items()) + [(1, 2), (2, 3)] """ def countPartition(iterator): - yield Counter(iterator) + counts = defaultdict(int) + for obj in iterator: + counts[obj] += 1 + yield counts def mergeMaps(m1, m2): - return m1 + m2 + for (k, v) in m2.iteritems(): + m1[k] += v + return m1 return self.mapPartitions(countPartition).reduce(mergeMaps) def take(self, num): @@ -271,7 +276,7 @@ class RDD(object): def countByKey(self): """ >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> rdd.countByKey().most_common() + >>> sorted(rdd.countByKey().items()) [('a', 2), ('b', 1)] """ return self.map(lambda x: x[0]).countByValue() diff --git a/python/tc.py b/python/tc.py deleted file mode 100644 index 5dcc4317e0..0000000000 --- a/python/tc.py +++ /dev/null @@ -1,22 +0,0 @@ -from rdd import SparkContext - -sc = SparkContext("local", "PythonWordCount") -e = [(1, 2), (2, 3), (4, 1)] - -tc = sc.parallelizePairs(e) - -edges = tc.mapPairs(lambda (x, y): (y, x)) - -oldCount = 0 -nextCount = tc.count() - -def project(x): - return (x[1][1], x[1][0]) - -while nextCount != oldCount: - oldCount = nextCount - tc = tc.union(tc.join(edges).mapPairs(project)).distinct() - nextCount = tc.count() - -print "TC has %i edges" % tc.count() -print tc.collect() -- cgit v1.2.3 From 86d420478f711e0f4eccc64c238efddf030a9b0f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 17 Sep 2012 14:25:48 -0700 Subject: Allowed StreamingContext to be created from existing SparkContext --- .../src/main/scala/spark/streaming/StreamingContext.scala | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index cb0f9ceb15..12f3626680 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -21,16 +21,13 @@ import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat -class StreamingContext ( - master: String, - frameworkName: String, - val sparkHome: String = null, - val jars: Seq[String] = Nil) - extends Logging { - +class StreamingContext (@transient val sc: SparkContext) extends Logging { + + def this(master: String, frameworkName: String, sparkHome: String = null, jars: Seq[String] = Nil) = + this(new SparkContext(master, frameworkName, sparkHome, jars)) + initLogging() - val sc = new SparkContext(master, frameworkName, sparkHome, jars) val env = SparkEnv.get val inputStreams = new ArrayBuffer[InputDStream[_]]() -- cgit v1.2.3 From e95ff45b53bf995d89f1825b9581cc18a083a438 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 13 Oct 2012 20:10:49 -0700 Subject: Implemented checkpointing of StreamingContext and DStream graph. --- core/src/main/scala/spark/SparkContext.scala | 4 +- .../main/scala/spark/streaming/Checkpoint.scala | 92 +++++++++++++++ .../src/main/scala/spark/streaming/DStream.scala | 123 ++++++++++++++------- .../main/scala/spark/streaming/DStreamGraph.scala | 80 ++++++++++++++ .../scala/spark/streaming/FileInputDStream.scala | 59 ++++++---- .../spark/streaming/ReducedWindowedDStream.scala | 80 +++++++------- .../src/main/scala/spark/streaming/Scheduler.scala | 33 +++--- .../main/scala/spark/streaming/StateDStream.scala | 20 ++-- .../scala/spark/streaming/StreamingContext.scala | 109 ++++++++++++------ .../examples/FileStreamWithCheckpoint.scala | 76 +++++++++++++ .../scala/spark/streaming/examples/Grep2.scala | 2 +- .../spark/streaming/examples/WordCount2.scala | 2 +- .../scala/spark/streaming/examples/WordMax2.scala | 2 +- .../spark/streaming/util/RecurringTimer.scala | 19 +++- 14 files changed, 536 insertions(+), 165 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/Checkpoint.scala create mode 100644 streaming/src/main/scala/spark/streaming/DStreamGraph.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index bebebe8262..1d5131ad13 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -46,8 +46,8 @@ import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend import spark.storage.BlockManagerMaster class SparkContext( - master: String, - frameworkName: String, + val master: String, + val frameworkName: String, val sparkHome: String, val jars: Seq[String]) extends Logging { diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala new file mode 100644 index 0000000000..3bd8fd5a27 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -0,0 +1,92 @@ +package spark.streaming + +import spark.Utils + +import org.apache.hadoop.fs.{FileUtil, Path} +import org.apache.hadoop.conf.Configuration + +import java.io.{ObjectInputStream, ObjectOutputStream} + +class Checkpoint(@transient ssc: StreamingContext) extends Serializable { + val master = ssc.sc.master + val frameworkName = ssc.sc.frameworkName + val sparkHome = ssc.sc.sparkHome + val jars = ssc.sc.jars + val graph = ssc.graph + val batchDuration = ssc.batchDuration + val checkpointFile = ssc.checkpointFile + val checkpointInterval = ssc.checkpointInterval + + def saveToFile(file: String) { + val path = new Path(file) + val conf = new Configuration() + val fs = path.getFileSystem(conf) + if (fs.exists(path)) { + val bkPath = new Path(path.getParent, path.getName + ".bk") + FileUtil.copy(fs, path, fs, bkPath, true, true, conf) + println("Moved existing checkpoint file to " + bkPath) + } + val fos = fs.create(path) + val oos = new ObjectOutputStream(fos) + oos.writeObject(this) + oos.close() + fs.close() + } + + def toBytes(): Array[Byte] = { + val cp = new Checkpoint(ssc) + val bytes = Utils.serialize(cp) + bytes + } +} + +object Checkpoint { + + def loadFromFile(file: String): Checkpoint = { + val path = new Path(file) + val conf = new Configuration() + val fs = path.getFileSystem(conf) + if (!fs.exists(path)) { + throw new Exception("Could not read checkpoint file " + path) + } + val fis = fs.open(path) + val ois = new ObjectInputStream(fis) + val cp = ois.readObject.asInstanceOf[Checkpoint] + ois.close() + fs.close() + cp + } + + def fromBytes(bytes: Array[Byte]): Checkpoint = { + Utils.deserialize[Checkpoint](bytes) + } + + /*def toBytes(ssc: StreamingContext): Array[Byte] = { + val cp = new Checkpoint(ssc) + val bytes = Utils.serialize(cp) + bytes + } + + + def saveContext(ssc: StreamingContext, file: String) { + val cp = new Checkpoint(ssc) + val path = new Path(file) + val conf = new Configuration() + val fs = path.getFileSystem(conf) + if (fs.exists(path)) { + val bkPath = new Path(path.getParent, path.getName + ".bk") + FileUtil.copy(fs, path, fs, bkPath, true, true, conf) + println("Moved existing checkpoint file to " + bkPath) + } + val fos = fs.create(path) + val oos = new ObjectOutputStream(fos) + oos.writeObject(cp) + oos.close() + fs.close() + } + + def loadContext(file: String): StreamingContext = { + loadCheckpoint(file).createNewContext() + } + */ +} diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 7e8098c346..78e4c57647 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -2,20 +2,19 @@ package spark.streaming import spark.streaming.StreamingContext._ -import spark.RDD -import spark.UnionRDD -import spark.Logging +import spark._ import spark.SparkContext._ import spark.storage.StorageLevel -import spark.Partitioner import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import java.util.concurrent.ArrayBlockingQueue +import java.io.{ObjectInputStream, IOException, ObjectOutputStream} +import scala.Some -abstract class DStream[T: ClassManifest] (@transient val ssc: StreamingContext) -extends Logging with Serializable { +abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext) +extends Serializable with Logging { initLogging() @@ -41,10 +40,10 @@ extends Logging with Serializable { */ // Variable to store the RDDs generated earlier in time - @transient protected val generatedRDDs = new HashMap[Time, RDD[T]] () + protected val generatedRDDs = new HashMap[Time, RDD[T]] () // Variable to be set to the first time seen by the DStream (effective time zero) - protected[streaming] var zeroTime: Time = null + protected var zeroTime: Time = null // Variable to specify storage level protected var storageLevel: StorageLevel = StorageLevel.NONE @@ -53,6 +52,9 @@ extends Logging with Serializable { protected var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint protected var checkpointInterval: Time = null + // Reference to whole DStream graph, so that checkpointing process can lock it + protected var graph: DStreamGraph = null + // Change this RDD's storage level def persist( storageLevel: StorageLevel, @@ -77,7 +79,7 @@ extends Logging with Serializable { // Turn on the default caching level for this RDD def cache(): DStream[T] = persist() - def isInitialized = (zeroTime != null) + def isInitialized() = (zeroTime != null) /** * This method initializes the DStream by setting the "zero" time, based on which @@ -85,15 +87,33 @@ extends Logging with Serializable { * its parent DStreams. */ protected[streaming] def initialize(time: Time) { - if (zeroTime == null) { - zeroTime = time + if (zeroTime != null) { + throw new Exception("ZeroTime is already initialized, cannot initialize it again") } + zeroTime = time logInfo(this + " initialized") dependencies.foreach(_.initialize(zeroTime)) } + protected[streaming] def setContext(s: StreamingContext) { + if (ssc != null && ssc != s) { + throw new Exception("Context is already set, cannot set it again") + } + ssc = s + logInfo("Set context for " + this.getClass.getSimpleName) + dependencies.foreach(_.setContext(ssc)) + } + + protected[streaming] def setGraph(g: DStreamGraph) { + if (graph != null && graph != g) { + throw new Exception("Graph is already set, cannot set it again") + } + graph = g + dependencies.foreach(_.setGraph(graph)) + } + /** This method checks whether the 'time' is valid wrt slideTime for generating RDD */ - protected def isTimeValid (time: Time): Boolean = { + protected def isTimeValid(time: Time): Boolean = { if (!isInitialized) { throw new Exception (this.toString + " has not been initialized") } else if (time < zeroTime || ! (time - zeroTime).isMultipleOf(slideTime)) { @@ -158,13 +178,42 @@ extends Logging with Serializable { } } + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + println(this.getClass().getSimpleName + ".writeObject used") + if (graph != null) { + graph.synchronized { + if (graph.checkpointInProgress) { + oos.defaultWriteObject() + } else { + val msg = "Object of " + this.getClass.getName + " is being serialized " + + " possibly as a part of closure of an RDD operation. This is because " + + " the DStream object is being referred to from within the closure. " + + " Please rewrite the RDD operation inside this DStream to avoid this. " + + " This has been enforced to avoid bloating of Spark tasks " + + " with unnecessary objects." + throw new java.io.NotSerializableException(msg) + } + } + } else { + throw new java.io.NotSerializableException("Graph is unexpectedly null when DStream is being serialized.") + } + } + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + println(this.getClass().getSimpleName + ".readObject used") + ois.defaultReadObject() + } + /** * -------------- * DStream operations * -------------- */ - - def map[U: ClassManifest](mapFunc: T => U) = new MappedDStream(this, ssc.sc.clean(mapFunc)) + def map[U: ClassManifest](mapFunc: T => U) = { + new MappedDStream(this, ssc.sc.clean(mapFunc)) + } def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]) = { new FlatMappedDStream(this, ssc.sc.clean(flatMapFunc)) @@ -262,19 +311,15 @@ extends Logging with Serializable { // Get all the RDDs between fromTime to toTime (both included) def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = { - val rdds = new ArrayBuffer[RDD[T]]() var time = toTime.floor(slideTime) - - while (time >= zeroTime && time >= fromTime) { getOrCompute(time) match { case Some(rdd) => rdds += rdd - case None => throw new Exception("Could not get old reduced RDD for time " + time) + case None => //throw new Exception("Could not get RDD for time " + time) } time -= slideTime } - rdds.toSeq } @@ -284,12 +329,16 @@ extends Logging with Serializable { } -abstract class InputDStream[T: ClassManifest] (@transient ssc: StreamingContext) - extends DStream[T](ssc) { +abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext) + extends DStream[T](ssc_) { override def dependencies = List() - override def slideTime = ssc.batchDuration + override def slideTime = { + if (ssc == null) throw new Exception("ssc is null") + if (ssc.batchDuration == null) throw new Exception("ssc.batchDuration is null") + ssc.batchDuration + } def start() @@ -302,7 +351,7 @@ abstract class InputDStream[T: ClassManifest] (@transient ssc: StreamingContext) */ class MappedDStream[T: ClassManifest, U: ClassManifest] ( - @transient parent: DStream[T], + parent: DStream[T], mapFunc: T => U ) extends DStream[U](parent.ssc) { @@ -321,7 +370,7 @@ class MappedDStream[T: ClassManifest, U: ClassManifest] ( */ class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( - @transient parent: DStream[T], + parent: DStream[T], flatMapFunc: T => Traversable[U] ) extends DStream[U](parent.ssc) { @@ -340,7 +389,7 @@ class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( */ class FilteredDStream[T: ClassManifest]( - @transient parent: DStream[T], + parent: DStream[T], filterFunc: T => Boolean ) extends DStream[T](parent.ssc) { @@ -359,7 +408,7 @@ class FilteredDStream[T: ClassManifest]( */ class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( - @transient parent: DStream[T], + parent: DStream[T], mapPartFunc: Iterator[T] => Iterator[U] ) extends DStream[U](parent.ssc) { @@ -377,7 +426,7 @@ class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( * TODO */ -class GlommedDStream[T: ClassManifest](@transient parent: DStream[T]) +class GlommedDStream[T: ClassManifest](parent: DStream[T]) extends DStream[Array[T]](parent.ssc) { override def dependencies = List(parent) @@ -395,7 +444,7 @@ class GlommedDStream[T: ClassManifest](@transient parent: DStream[T]) */ class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( - @transient parent: DStream[(K,V)], + parent: DStream[(K,V)], createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiner: (C, C) => C, @@ -420,7 +469,7 @@ class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( * TODO */ -class UnifiedDStream[T: ClassManifest](@transient parents: Array[DStream[T]]) +class UnifiedDStream[T: ClassManifest](parents: Array[DStream[T]]) extends DStream[T](parents(0).ssc) { if (parents.length == 0) { @@ -459,7 +508,7 @@ class UnifiedDStream[T: ClassManifest](@transient parents: Array[DStream[T]]) */ class PerElementForEachDStream[T: ClassManifest] ( - @transient parent: DStream[T], + parent: DStream[T], foreachFunc: T => Unit ) extends DStream[Unit](parent.ssc) { @@ -490,7 +539,7 @@ class PerElementForEachDStream[T: ClassManifest] ( */ class PerRDDForEachDStream[T: ClassManifest] ( - @transient parent: DStream[T], + parent: DStream[T], foreachFunc: (RDD[T], Time) => Unit ) extends DStream[Unit](parent.ssc) { @@ -518,15 +567,15 @@ class PerRDDForEachDStream[T: ClassManifest] ( */ class TransformedDStream[T: ClassManifest, U: ClassManifest] ( - @transient parent: DStream[T], + parent: DStream[T], transformFunc: (RDD[T], Time) => RDD[U] ) extends DStream[U](parent.ssc) { - override def dependencies = List(parent) + override def dependencies = List(parent) - override def slideTime: Time = parent.slideTime + override def slideTime: Time = parent.slideTime - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(transformFunc(_, validTime)) - } + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(transformFunc(_, validTime)) } +} diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala new file mode 100644 index 0000000000..67859e0131 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -0,0 +1,80 @@ +package spark.streaming + +import java.io.{ObjectInputStream, IOException, ObjectOutputStream} +import collection.mutable.ArrayBuffer + +final class DStreamGraph extends Serializable { + + private val inputStreams = new ArrayBuffer[InputDStream[_]]() + private val outputStreams = new ArrayBuffer[DStream[_]]() + + private[streaming] var zeroTime: Time = null + private[streaming] var checkpointInProgress = false; + + def started() = (zeroTime != null) + + def start(time: Time) { + this.synchronized { + if (started) { + throw new Exception("DStream graph computation already started") + } + zeroTime = time + outputStreams.foreach(_.initialize(zeroTime)) + inputStreams.par.foreach(_.start()) + } + + } + + def stop() { + this.synchronized { + inputStreams.par.foreach(_.stop()) + } + } + + private[streaming] def setContext(ssc: StreamingContext) { + this.synchronized { + outputStreams.foreach(_.setContext(ssc)) + } + } + + def addInputStream(inputStream: InputDStream[_]) { + inputStream.setGraph(this) + inputStreams += inputStream + } + + def addOutputStream(outputStream: DStream[_]) { + outputStream.setGraph(this) + outputStreams += outputStream + } + + def getInputStreams() = inputStreams.toArray + + def getOutputStreams() = outputStreams.toArray + + def generateRDDs(time: Time): Seq[Job] = { + this.synchronized { + outputStreams.flatMap(outputStream => outputStream.generateJob(time)) + } + } + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + this.synchronized { + checkpointInProgress = true + oos.defaultWriteObject() + checkpointInProgress = false + } + println("DStreamGraph.writeObject used") + } + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + this.synchronized { + checkpointInProgress = true + ois.defaultReadObject() + checkpointInProgress = false + } + println("DStreamGraph.readObject used") + } +} + diff --git a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala index 96a64f0018..29ae89616e 100644 --- a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala @@ -1,33 +1,45 @@ package spark.streaming -import spark.SparkContext import spark.RDD -import spark.BlockRDD import spark.UnionRDD -import spark.storage.StorageLevel -import spark.streaming._ -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap - -import java.net.InetSocketAddress - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.fs.PathFilter +import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import java.io.{ObjectInputStream, IOException} class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( - ssc: StreamingContext, - directory: Path, + @transient ssc_ : StreamingContext, + directory: String, filter: PathFilter = FileInputDStream.defaultPathFilter, newFilesOnly: Boolean = true) - extends InputDStream[(K, V)](ssc) { - - val fs = directory.getFileSystem(new Configuration()) + extends InputDStream[(K, V)](ssc_) { + + @transient private var path_ : Path = null + @transient private var fs_ : FileSystem = null + + /* + @transient @noinline lazy val path = { + //if (directory == null) throw new Exception("directory is null") + //println(directory) + new Path(directory) + } + @transient lazy val fs = path.getFileSystem(new Configuration()) + */ + var lastModTime: Long = 0 - + + def path(): Path = { + if (path_ == null) path_ = new Path(directory) + path_ + } + + def fs(): FileSystem = { + if (fs_ == null) fs_ = path.getFileSystem(new Configuration()) + fs_ + } + override def start() { if (newFilesOnly) { lastModTime = System.currentTimeMillis() @@ -58,7 +70,7 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K } } - val newFiles = fs.listStatus(directory, newFilter) + val newFiles = fs.listStatus(path, newFilter) logInfo("New files: " + newFiles.map(_.getPath).mkString(", ")) if (newFiles.length > 0) { lastModTime = newFilter.latestModTime @@ -67,10 +79,19 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K file => ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString))) Some(newRDD) } + /* + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + println(this.getClass().getSimpleName + ".readObject used") + ois.defaultReadObject() + println("HERE HERE" + this.directory) + } + */ + } object FileInputDStream { - val defaultPathFilter = new PathFilter { + val defaultPathFilter = new PathFilter with Serializable { def accept(path: Path): Boolean = { val file = path.getName() if (file.startsWith(".") || file.endsWith("_tmp")) { diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala index b0beaba94d..e161b5ba92 100644 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -10,9 +10,10 @@ import spark.SparkContext._ import spark.storage.StorageLevel import scala.collection.mutable.ArrayBuffer +import collection.SeqProxy class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( - @transient parent: DStream[(K, V)], + parent: DStream[(K, V)], reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, _windowTime: Time, @@ -46,6 +47,8 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( } override def compute(validTime: Time): Option[RDD[(K, V)]] = { + val reduceF = reduceFunc + val invReduceF = invReduceFunc val currentTime = validTime val currentWindow = Interval(currentTime - windowTime + parent.slideTime, currentTime) @@ -84,54 +87,47 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( // Cogroup the reduced RDDs and merge the reduced values val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(_, _)]]], partitioner) - val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _ - val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValuesFunc) + //val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _ - Some(mergedValuesRDD) - } - - def mergeValues(numOldValues: Int, numNewValues: Int)(seqOfValues: Seq[Seq[V]]): V = { - - if (seqOfValues.size != 1 + numOldValues + numNewValues) { - throw new Exception("Unexpected number of sequences of reduced values") - } - - // Getting reduced values "old time steps" that will be removed from current window - val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head) - - // Getting reduced values "new time steps" - val newValues = (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) - - if (seqOfValues(0).isEmpty) { + val numOldValues = oldRDDs.size + val numNewValues = newRDDs.size - // If previous window's reduce value does not exist, then at least new values should exist - if (newValues.isEmpty) { - throw new Exception("Neither previous window has value for key, nor new values found") + val mergeValues = (seqOfValues: Seq[Seq[V]]) => { + if (seqOfValues.size != 1 + numOldValues + numNewValues) { + throw new Exception("Unexpected number of sequences of reduced values") } + // Getting reduced values "old time steps" that will be removed from current window + val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head) + // Getting reduced values "new time steps" + val newValues = (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) + if (seqOfValues(0).isEmpty) { + // If previous window's reduce value does not exist, then at least new values should exist + if (newValues.isEmpty) { + throw new Exception("Neither previous window has value for key, nor new values found") + } + // Reduce the new values + newValues.reduce(reduceF) // return + } else { + // Get the previous window's reduced value + var tempValue = seqOfValues(0).head + // If old values exists, then inverse reduce then from previous value + if (!oldValues.isEmpty) { + tempValue = invReduceF(tempValue, oldValues.reduce(reduceF)) + } + // If new values exists, then reduce them with previous value + if (!newValues.isEmpty) { + tempValue = reduceF(tempValue, newValues.reduce(reduceF)) + } + tempValue // return + } + } - // Reduce the new values - // println("new values = " + newValues.map(_.toString).reduce(_ + " " + _)) - return newValues.reduce(reduceFunc) - } else { + val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValues) - // Get the previous window's reduced value - var tempValue = seqOfValues(0).head + Some(mergedValuesRDD) + } - // If old values exists, then inverse reduce then from previous value - if (!oldValues.isEmpty) { - // println("old values = " + oldValues.map(_.toString).reduce(_ + " " + _)) - tempValue = invReduceFunc(tempValue, oldValues.reduce(reduceFunc)) - } - // If new values exists, then reduce them with previous value - if (!newValues.isEmpty) { - // println("new values = " + newValues.map(_.toString).reduce(_ + " " + _)) - tempValue = reduceFunc(tempValue, newValues.reduce(reduceFunc)) - } - // println("final value = " + tempValue) - return tempValue - } - } } diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index d2e907378d..d62b7e7140 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -11,45 +11,44 @@ import scala.collection.mutable.HashMap sealed trait SchedulerMessage case class InputGenerated(inputName: String, interval: Interval, reference: AnyRef = null) extends SchedulerMessage -class Scheduler( - ssc: StreamingContext, - inputStreams: Array[InputDStream[_]], - outputStreams: Array[DStream[_]]) +class Scheduler(ssc: StreamingContext) extends Logging { initLogging() + val graph = ssc.graph val concurrentJobs = System.getProperty("spark.stream.concurrentJobs", "1").toInt val jobManager = new JobManager(ssc, concurrentJobs) val clockClass = System.getProperty("spark.streaming.clock", "spark.streaming.util.SystemClock") val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock] val timer = new RecurringTimer(clock, ssc.batchDuration, generateRDDs(_)) - + + def start() { - val zeroTime = Time(timer.start()) - outputStreams.foreach(_.initialize(zeroTime)) - inputStreams.par.foreach(_.start()) + if (graph.started) { + timer.restart(graph.zeroTime.milliseconds) + } else { + val zeroTime = Time(timer.start()) + graph.start(zeroTime) + } logInfo("Scheduler started") } def stop() { timer.stop() - inputStreams.par.foreach(_.stop()) + graph.stop() logInfo("Scheduler stopped") } - def generateRDDs (time: Time) { + def generateRDDs(time: Time) { SparkEnv.set(ssc.env) logInfo("\n-----------------------------------------------------\n") logInfo("Generating RDDs for time " + time) - outputStreams.foreach(outputStream => { - outputStream.generateJob(time) match { - case Some(job) => submitJob(job) - case None => - } - } - ) + graph.generateRDDs(time).foreach(submitJob) logInfo("Generated RDDs for time " + time) + if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) { + ssc.checkpoint() + } } def submitJob(job: Job) { diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala index c40f70c91d..d223f25dfc 100644 --- a/streaming/src/main/scala/spark/streaming/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala @@ -7,6 +7,12 @@ import spark.MapPartitionsRDD import spark.SparkContext._ import spark.storage.StorageLevel + +class StateRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: Iterator[T] => Iterator[U], rememberPartitioner: Boolean) + extends MapPartitionsRDD[U, T](prev, f) { + override val partitioner = if (rememberPartitioner) prev.partitioner else None +} + class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( @transient parent: DStream[(K, V)], updateFunc: (Iterator[(K, Seq[V], S)]) => Iterator[(K, S)], @@ -14,11 +20,6 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife rememberPartitioner: Boolean ) extends DStream[(K, S)](parent.ssc) { - class SpecialMapPartitionsRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: Iterator[T] => Iterator[U]) - extends MapPartitionsRDD(prev, f) { - override val partitioner = if (rememberPartitioner) prev.partitioner else None - } - override def dependencies = List(parent) override def slideTime = parent.slideTime @@ -79,19 +80,18 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife // first map the cogrouped tuple to tuples of required type, // and then apply the update function val updateFuncLocal = updateFunc - val mapPartitionFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { + val finalFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { val i = iterator.map(t => { (t._1, t._2._1, t._2._2.headOption.getOrElse(null.asInstanceOf[S])) }) updateFuncLocal(i) } val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) - val stateRDD = new SpecialMapPartitionsRDD(cogroupedRDD, mapPartitionFunc) + val stateRDD = new StateRDD(cogroupedRDD, finalFunc, rememberPartitioner) //logDebug("Generating state RDD for time " + validTime) return Some(stateRDD) } case None => { // If parent RDD does not exist, then return old state RDD - //logDebug("Generating state RDD for time " + validTime + " (no change)") return Some(prevStateRDD) } } @@ -107,12 +107,12 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife // first map the grouped tuple to tuples of required type, // and then apply the update function val updateFuncLocal = updateFunc - val mapPartitionFunc = (iterator: Iterator[(K, Seq[V])]) => { + val finalFunc = (iterator: Iterator[(K, Seq[V])]) => { updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, null.asInstanceOf[S]))) } val groupedRDD = parentRDD.groupByKey(partitioner) - val sessionRDD = new SpecialMapPartitionsRDD(groupedRDD, mapPartitionFunc) + val sessionRDD = new StateRDD(groupedRDD, finalFunc, rememberPartitioner) //logDebug("Generating state RDD for time " + validTime + " (first)") return Some(sessionRDD) } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 12f3626680..1499ef4ea2 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -21,31 +21,70 @@ import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat -class StreamingContext (@transient val sc: SparkContext) extends Logging { +class StreamingContext ( + sc_ : SparkContext, + cp_ : Checkpoint + ) extends Logging { + + def this(sparkContext: SparkContext) = this(sparkContext, null) def this(master: String, frameworkName: String, sparkHome: String = null, jars: Seq[String] = Nil) = - this(new SparkContext(master, frameworkName, sparkHome, jars)) + this(new SparkContext(master, frameworkName, sparkHome, jars), null) + + def this(file: String) = this(null, Checkpoint.loadFromFile(file)) + + def this(cp_ : Checkpoint) = this(null, cp_) initLogging() + if (sc_ == null && cp_ == null) { + throw new Exception("Streaming Context cannot be initilalized with " + + "both SparkContext and checkpoint as null") + } + + val isCheckpointPresent = (cp_ != null) + + val sc: SparkContext = { + if (isCheckpointPresent) { + new SparkContext(cp_.master, cp_.frameworkName, cp_.sparkHome, cp_.jars) + } else { + sc_ + } + } + val env = SparkEnv.get - - val inputStreams = new ArrayBuffer[InputDStream[_]]() - val outputStreams = new ArrayBuffer[DStream[_]]() + + val graph: DStreamGraph = { + if (isCheckpointPresent) { + + cp_.graph.setContext(this) + cp_.graph + } else { + new DStreamGraph() + } + } + val nextNetworkInputStreamId = new AtomicInteger(0) - var batchDuration: Time = null - var scheduler: Scheduler = null + var batchDuration: Time = if (isCheckpointPresent) cp_.batchDuration else null + var checkpointFile: String = if (isCheckpointPresent) cp_.checkpointFile else null + var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null var networkInputTracker: NetworkInputTracker = null - var receiverJobThread: Thread = null - - def setBatchDuration(duration: Long) { - setBatchDuration(Time(duration)) - } - + var receiverJobThread: Thread = null + var scheduler: Scheduler = null + def setBatchDuration(duration: Time) { + if (batchDuration != null) { + throw new Exception("Batch duration alread set as " + batchDuration + + ". cannot set it again.") + } batchDuration = duration } + + def setCheckpointDetails(file: String, interval: Time) { + checkpointFile = file + checkpointInterval = interval + } private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() @@ -59,7 +98,7 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { converter: (InputStream) => Iterator[T] ): DStream[T] = { val inputStream = new ObjectInputDStream[T](this, hostname, port, converter) - inputStreams += inputStream + graph.addInputStream(inputStream) inputStream } @@ -69,7 +108,7 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_2 ): DStream[T] = { val inputStream = new RawInputDStream[T](this, hostname, port, storageLevel) - inputStreams += inputStream + graph.addInputStream(inputStream) inputStream } @@ -94,8 +133,8 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { V: ClassManifest, F <: NewInputFormat[K, V]: ClassManifest ](directory: String): DStream[(K, V)] = { - val inputStream = new FileInputDStream[K, V, F](this, new Path(directory)) - inputStreams += inputStream + val inputStream = new FileInputDStream[K, V, F](this, directory) + graph.addInputStream(inputStream) inputStream } @@ -113,24 +152,31 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { defaultRDD: RDD[T] = null ): DStream[T] = { val inputStream = new QueueInputDStream(this, queue, oneAtATime, defaultRDD) - inputStreams += inputStream + graph.addInputStream(inputStream) inputStream } - def createQueueStream[T: ClassManifest](iterator: Iterator[RDD[T]]): DStream[T] = { + def createQueueStream[T: ClassManifest](iterator: Array[RDD[T]]): DStream[T] = { val queue = new Queue[RDD[T]] val inputStream = createQueueStream(queue, true, null) queue ++= iterator inputStream - } + } + + /** + * This function registers a InputDStream as an input stream that will be + * started (InputDStream.start() called) to get the input data streams. + */ + def registerInputStream(inputStream: InputDStream[_]) { + graph.addInputStream(inputStream) + } - /** * This function registers a DStream as an output stream that will be * computed every interval. */ - def registerOutputStream (outputStream: DStream[_]) { - outputStreams += outputStream + def registerOutputStream(outputStream: DStream[_]) { + graph.addOutputStream(outputStream) } /** @@ -143,13 +189,9 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { if (batchDuration < Milliseconds(100)) { logWarning("Batch duration of " + batchDuration + " is very low") } - if (inputStreams.size == 0) { - throw new Exception("No input streams created, so nothing to take input from") - } - if (outputStreams.size == 0) { + if (graph.getOutputStreams().size == 0) { throw new Exception("No output streams registered, so nothing to execute") } - } /** @@ -157,7 +199,7 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { */ def start() { verify() - val networkInputStreams = inputStreams.filter(s => s match { + val networkInputStreams = graph.getInputStreams().filter(s => s match { case n: NetworkInputDStream[_] => true case _ => false }).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray @@ -169,8 +211,9 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { } Thread.sleep(1000) - // Start the scheduler - scheduler = new Scheduler(this, inputStreams.toArray, outputStreams.toArray) + + // Start the scheduler + scheduler = new Scheduler(this) scheduler.start() } @@ -189,6 +232,10 @@ class StreamingContext (@transient val sc: SparkContext) extends Logging { logInfo("StreamingContext stopped") } + + def checkpoint() { + new Checkpoint(this).saveToFile(checkpointFile) + } } diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala new file mode 100644 index 0000000000..c725035a8a --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala @@ -0,0 +1,76 @@ +package spark.streaming.examples + +import spark.streaming._ +import spark.streaming.StreamingContext._ +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration + +object FileStreamWithCheckpoint { + + def main(args: Array[String]) { + + if (args.size != 3) { + println("FileStreamWithCheckpoint ") + println("FileStreamWithCheckpoint restart ") + System.exit(-1) + } + + val directory = new Path(args(1)) + val checkpointFile = args(2) + + val ssc: StreamingContext = { + + if (args(0) == "restart") { + + // Recreated streaming context from specified checkpoint file + new StreamingContext(checkpointFile) + + } else { + + // Create directory if it does not exist + val fs = directory.getFileSystem(new Configuration()) + if (!fs.exists(directory)) fs.mkdirs(directory) + + // Create new streaming context + val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint") + ssc_.setBatchDuration(Seconds(1)) + ssc_.setCheckpointDetails(checkpointFile, Seconds(1)) + + // Setup the streaming computation + val inputStream = ssc_.createTextFileStream(directory.toString) + val words = inputStream.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + + ssc_ + } + } + + // Start the stream computation + startFileWritingThread(directory.toString) + ssc.start() + } + + def startFileWritingThread(directory: String) { + + val fs = new Path(directory).getFileSystem(new Configuration()) + + val fileWritingThread = new Thread() { + override def run() { + val r = new scala.util.Random() + val text = "This is a sample text file with a random number " + while(true) { + val number = r.nextInt() + val file = new Path(directory, number.toString) + val fos = fs.create(file) + fos.writeChars(text + number) + fos.close() + println("Created text file " + file) + Thread.sleep(1000) + } + } + } + fileWritingThread.start() + } + +} diff --git a/streaming/src/main/scala/spark/streaming/examples/Grep2.scala b/streaming/src/main/scala/spark/streaming/examples/Grep2.scala index 7237142c7c..b1faa65c17 100644 --- a/streaming/src/main/scala/spark/streaming/examples/Grep2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/Grep2.scala @@ -50,7 +50,7 @@ object Grep2 { println("Data count: " + data.count()) val sentences = new ConstantInputDStream(ssc, data) - ssc.inputStreams += sentences + ssc.registerInputStream(sentences) sentences.filter(_.contains("Culpepper")).count().foreachRDD(r => println("Grep count: " + r.collect().mkString)) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index c22949d7b9..8390f4af94 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -93,7 +93,7 @@ object WordCount2 { println("Data count: " + data.count()) val sentences = new ConstantInputDStream(ssc, data) - ssc.inputStreams += sentences + ssc.registerInputStream(sentences) import WordCount2_ExtraFunctions._ diff --git a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala index 3658cb302d..fc7567322b 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala @@ -50,7 +50,7 @@ object WordMax2 { println("Data count: " + data.count()) val sentences = new ConstantInputDStream(ssc, data) - ssc.inputStreams += sentences + ssc.registerInputStream(sentences) import WordCount2_ExtraFunctions._ diff --git a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala index 5da9fa6ecc..7f19b26a79 100644 --- a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala @@ -17,12 +17,23 @@ class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => } var nextTime = 0L - - def start(): Long = { - nextTime = (math.floor(clock.currentTime / period) + 1).toLong * period - thread.start() + + def start(startTime: Long): Long = { + nextTime = startTime + thread.start() nextTime } + + def start(): Long = { + val startTime = math.ceil(clock.currentTime / period).toLong * period + start(startTime) + } + + def restart(originalStartTime: Long): Long = { + val gap = clock.currentTime - originalStartTime + val newStartTime = math.ceil(gap / period).toLong * period + originalStartTime + start(newStartTime) + } def stop() { thread.interrupt() -- cgit v1.2.3 From b08708e6fcb59a09b36c5b8e3e7a4aa98f7ad050 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 13 Oct 2012 21:02:24 -0700 Subject: Fixed bugs in the streaming testsuites. --- .../test/scala/spark/streaming/DStreamBasicSuite.scala | 18 ++++++++++++------ .../test/scala/spark/streaming/DStreamSuiteBase.scala | 7 ++++++- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala index 9b953d9dae..965b58c03f 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala @@ -39,22 +39,28 @@ class DStreamBasicSuite extends DStreamSuiteBase { test("stateful operations") { val inputData = Seq( + Seq("a"), + Seq("a", "b"), Seq("a", "b", "c"), - Seq("a", "b", "c"), - Seq("a", "b", "c") + Seq("a", "b"), + Seq("a"), + Seq() ) val outputData = Seq( - Seq(("a", 1), ("b", 1), ("c", 1)), - Seq(("a", 2), ("b", 2), ("c", 2)), - Seq(("a", 3), ("b", 3), ("c", 3)) + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) ) val updateStateOp = (s: DStream[String]) => { val updateFunc = (values: Seq[Int], state: RichInt) => { var newState = 0 - if (values != null) newState += values.reduce(_ + _) + if (values != null && values.size > 0) newState += values.reduce(_ + _) if (state != null) newState += state.self //println("values = " + values + ", state = " + state + ", " + " new state = " + newState) new RichInt(newState) diff --git a/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala b/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala index 1c4ea14b1d..59fe36baf0 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala @@ -45,14 +45,19 @@ trait DStreamSuiteBase extends FunSuite with Logging { val clock = ssc.scheduler.clock if (clock.isInstanceOf[ManualClock]) { - clock.asInstanceOf[ManualClock].addToTime(input.size * batchDuration.milliseconds) + clock.asInstanceOf[ManualClock].addToTime((input.size - 1) * batchDuration.milliseconds) } val startTime = System.currentTimeMillis() while (output.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + println("output.size = " + output.size + ", expectedOutput.size = " + expectedOutput.size) Thread.sleep(500) } + println("output.size = " + output.size) + println("output") + output.foreach(x => println("[" + x.mkString(",") + "]")) + assert(output.size === expectedOutput.size) for (i <- 0 until output.size) { if (useSet) { -- cgit v1.2.3 From 3f1aae5c71a220564adc9039dbc0e4b22aea315d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 14 Oct 2012 21:39:30 -0700 Subject: Refactored DStreamSuiteBase to create CheckpointSuite- testsuite for testing checkpointing under different operations. --- .../main/scala/spark/streaming/Checkpoint.scala | 88 +++++------ .../spark/streaming/ConstantInputDStream.scala | 4 +- .../src/main/scala/spark/streaming/DStream.scala | 4 +- .../main/scala/spark/streaming/DStreamGraph.scala | 13 +- .../src/main/scala/spark/streaming/Scheduler.scala | 20 ++- .../scala/spark/streaming/StreamingContext.scala | 18 ++- .../main/scala/spark/streaming/util/Clock.scala | 11 +- .../spark/streaming/util/RecurringTimer.scala | 4 +- .../scala/spark/streaming/CheckpointSuite.scala | 48 ++++++ .../scala/spark/streaming/DStreamBasicSuite.scala | 5 +- .../scala/spark/streaming/DStreamSuiteBase.scala | 171 ++++++++++++++++----- .../scala/spark/streaming/DStreamWindowSuite.scala | 39 ++--- 12 files changed, 290 insertions(+), 135 deletions(-) create mode 100644 streaming/src/test/scala/spark/streaming/CheckpointSuite.scala diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index 3bd8fd5a27..b38911b646 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -5,11 +5,12 @@ import spark.Utils import org.apache.hadoop.fs.{FileUtil, Path} import org.apache.hadoop.conf.Configuration -import java.io.{ObjectInputStream, ObjectOutputStream} +import java.io.{InputStream, ObjectStreamClass, ObjectInputStream, ObjectOutputStream} -class Checkpoint(@transient ssc: StreamingContext) extends Serializable { + +class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) extends Serializable { val master = ssc.sc.master - val frameworkName = ssc.sc.frameworkName + val framework = ssc.sc.frameworkName val sparkHome = ssc.sc.sparkHome val jars = ssc.sc.jars val graph = ssc.graph @@ -17,7 +18,16 @@ class Checkpoint(@transient ssc: StreamingContext) extends Serializable { val checkpointFile = ssc.checkpointFile val checkpointInterval = ssc.checkpointInterval - def saveToFile(file: String) { + validate() + + def validate() { + assert(master != null, "Checkpoint.master is null") + assert(framework != null, "Checkpoint.framework is null") + assert(graph != null, "Checkpoint.graph is null") + assert(batchDuration != null, "Checkpoint.batchDuration is null") + } + + def saveToFile(file: String = checkpointFile) { val path = new Path(file) val conf = new Configuration() val fs = path.getFileSystem(conf) @@ -34,8 +44,7 @@ class Checkpoint(@transient ssc: StreamingContext) extends Serializable { } def toBytes(): Array[Byte] = { - val cp = new Checkpoint(ssc) - val bytes = Utils.serialize(cp) + val bytes = Utils.serialize(this) bytes } } @@ -43,50 +52,41 @@ class Checkpoint(@transient ssc: StreamingContext) extends Serializable { object Checkpoint { def loadFromFile(file: String): Checkpoint = { - val path = new Path(file) - val conf = new Configuration() - val fs = path.getFileSystem(conf) - if (!fs.exists(path)) { - throw new Exception("Could not read checkpoint file " + path) + try { + val path = new Path(file) + val conf = new Configuration() + val fs = path.getFileSystem(conf) + if (!fs.exists(path)) { + throw new Exception("Checkpoint file '" + file + "' does not exist") + } + val fis = fs.open(path) + val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader) + val cp = ois.readObject.asInstanceOf[Checkpoint] + ois.close() + fs.close() + cp.validate() + cp + } catch { + case e: Exception => + e.printStackTrace() + throw new Exception("Could not load checkpoint file '" + file + "'", e) } - val fis = fs.open(path) - val ois = new ObjectInputStream(fis) - val cp = ois.readObject.asInstanceOf[Checkpoint] - ois.close() - fs.close() - cp } def fromBytes(bytes: Array[Byte]): Checkpoint = { - Utils.deserialize[Checkpoint](bytes) - } - - /*def toBytes(ssc: StreamingContext): Array[Byte] = { - val cp = new Checkpoint(ssc) - val bytes = Utils.serialize(cp) - bytes + val cp = Utils.deserialize[Checkpoint](bytes) + cp.validate() + cp } +} - - def saveContext(ssc: StreamingContext, file: String) { - val cp = new Checkpoint(ssc) - val path = new Path(file) - val conf = new Configuration() - val fs = path.getFileSystem(conf) - if (fs.exists(path)) { - val bkPath = new Path(path.getParent, path.getName + ".bk") - FileUtil.copy(fs, path, fs, bkPath, true, true, conf) - println("Moved existing checkpoint file to " + bkPath) +class ObjectInputStreamWithLoader(inputStream_ : InputStream, loader: ClassLoader) extends ObjectInputStream(inputStream_) { + override def resolveClass(desc: ObjectStreamClass): Class[_] = { + try { + return loader.loadClass(desc.getName()) + } catch { + case e: Exception => } - val fos = fs.create(path) - val oos = new ObjectOutputStream(fos) - oos.writeObject(cp) - oos.close() - fs.close() - } - - def loadContext(file: String): StreamingContext = { - loadCheckpoint(file).createNewContext() + return super.resolveClass(desc) } - */ } diff --git a/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala b/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala index 9bc204dd09..80150708fd 100644 --- a/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala @@ -5,8 +5,8 @@ import spark.RDD /** * An input stream that always returns the same RDD on each timestep. Useful for testing. */ -class ConstantInputDStream[T: ClassManifest](ssc: StreamingContext, rdd: RDD[T]) - extends InputDStream[T](ssc) { +class ConstantInputDStream[T: ClassManifest](ssc_ : StreamingContext, rdd: RDD[T]) + extends InputDStream[T](ssc_) { override def start() {} diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 78e4c57647..0a43a042d0 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -180,7 +180,7 @@ extends Serializable with Logging { @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { - println(this.getClass().getSimpleName + ".writeObject used") + logDebug(this.getClass().getSimpleName + ".writeObject used") if (graph != null) { graph.synchronized { if (graph.checkpointInProgress) { @@ -202,7 +202,7 @@ extends Serializable with Logging { @throws(classOf[IOException]) private def readObject(ois: ObjectInputStream) { - println(this.getClass().getSimpleName + ".readObject used") + logDebug(this.getClass().getSimpleName + ".readObject used") ois.defaultReadObject() } diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index 67859e0131..bcd365e932 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -2,8 +2,10 @@ package spark.streaming import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import collection.mutable.ArrayBuffer +import spark.Logging -final class DStreamGraph extends Serializable { +final class DStreamGraph extends Serializable with Logging { + initLogging() private val inputStreams = new ArrayBuffer[InputDStream[_]]() private val outputStreams = new ArrayBuffer[DStream[_]]() @@ -11,18 +13,15 @@ final class DStreamGraph extends Serializable { private[streaming] var zeroTime: Time = null private[streaming] var checkpointInProgress = false; - def started() = (zeroTime != null) - def start(time: Time) { this.synchronized { - if (started) { + if (zeroTime != null) { throw new Exception("DStream graph computation already started") } zeroTime = time outputStreams.foreach(_.initialize(zeroTime)) inputStreams.par.foreach(_.start()) } - } def stop() { @@ -60,21 +59,21 @@ final class DStreamGraph extends Serializable { @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { this.synchronized { + logDebug("DStreamGraph.writeObject used") checkpointInProgress = true oos.defaultWriteObject() checkpointInProgress = false } - println("DStreamGraph.writeObject used") } @throws(classOf[IOException]) private def readObject(ois: ObjectInputStream) { this.synchronized { + logDebug("DStreamGraph.readObject used") checkpointInProgress = true ois.defaultReadObject() checkpointInProgress = false } - println("DStreamGraph.readObject used") } } diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index d62b7e7140..1e1425a88a 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -1,7 +1,6 @@ package spark.streaming -import spark.streaming.util.RecurringTimer -import spark.streaming.util.Clock +import util.{ManualClock, RecurringTimer, Clock} import spark.SparkEnv import spark.Logging @@ -23,13 +22,24 @@ extends Logging { val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock] val timer = new RecurringTimer(clock, ssc.batchDuration, generateRDDs(_)) - def start() { - if (graph.started) { + // If context was started from checkpoint, then restart timer such that + // this timer's triggers occur at the same time as the original timer. + // Otherwise just start the timer from scratch, and initialize graph based + // on this first trigger time of the timer. + if (ssc.isCheckpointPresent) { + // If manual clock is being used for testing, then + // set manual clock to the last checkpointed time + if (clock.isInstanceOf[ManualClock]) { + val lastTime = ssc.getInitialCheckpoint.checkpointTime.milliseconds + clock.asInstanceOf[ManualClock].setTime(lastTime) + } timer.restart(graph.zeroTime.milliseconds) + logInfo("Scheduler's timer restarted") } else { val zeroTime = Time(timer.start()) graph.start(zeroTime) + logInfo("Scheduler's timer started") } logInfo("Scheduler started") } @@ -47,7 +57,7 @@ extends Logging { graph.generateRDDs(time).foreach(submitJob) logInfo("Generated RDDs for time " + time) if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) { - ssc.checkpoint() + ssc.doCheckpoint(time) } } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 1499ef4ea2..e072f15c93 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -46,7 +46,7 @@ class StreamingContext ( val sc: SparkContext = { if (isCheckpointPresent) { - new SparkContext(cp_.master, cp_.frameworkName, cp_.sparkHome, cp_.jars) + new SparkContext(cp_.master, cp_.framework, cp_.sparkHome, cp_.jars) } else { sc_ } @@ -85,9 +85,13 @@ class StreamingContext ( checkpointFile = file checkpointInterval = interval } - + + private[streaming] def getInitialCheckpoint(): Checkpoint = { + if (isCheckpointPresent) cp_ else null + } + private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() - + def createNetworkTextStream(hostname: String, port: Int): DStream[String] = { createNetworkObjectStream[String](hostname, port, ObjectInputReceiver.bytesToLines) } @@ -156,10 +160,10 @@ class StreamingContext ( inputStream } - def createQueueStream[T: ClassManifest](iterator: Array[RDD[T]]): DStream[T] = { + def createQueueStream[T: ClassManifest](array: Array[RDD[T]]): DStream[T] = { val queue = new Queue[RDD[T]] val inputStream = createQueueStream(queue, true, null) - queue ++= iterator + queue ++= array inputStream } @@ -233,8 +237,8 @@ class StreamingContext ( logInfo("StreamingContext stopped") } - def checkpoint() { - new Checkpoint(this).saveToFile(checkpointFile) + def doCheckpoint(currentTime: Time) { + new Checkpoint(this, currentTime).saveToFile(checkpointFile) } } diff --git a/streaming/src/main/scala/spark/streaming/util/Clock.scala b/streaming/src/main/scala/spark/streaming/util/Clock.scala index 72e786e0c3..ed087e4ea8 100644 --- a/streaming/src/main/scala/spark/streaming/util/Clock.scala +++ b/streaming/src/main/scala/spark/streaming/util/Clock.scala @@ -56,10 +56,17 @@ class SystemClock() extends Clock { class ManualClock() extends Clock { - var time = 0L - + var time = 0L + def currentTime() = time + def setTime(timeToSet: Long) = { + this.synchronized { + time = timeToSet + this.notifyAll() + } + } + def addToTime(timeToAdd: Long) = { this.synchronized { time += timeToAdd diff --git a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala index 7f19b26a79..dc55fd902b 100644 --- a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala @@ -25,13 +25,13 @@ class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => } def start(): Long = { - val startTime = math.ceil(clock.currentTime / period).toLong * period + val startTime = (math.floor(clock.currentTime.toDouble / period) + 1).toLong * period start(startTime) } def restart(originalStartTime: Long): Long = { val gap = clock.currentTime - originalStartTime - val newStartTime = math.ceil(gap / period).toLong * period + originalStartTime + val newStartTime = (math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime start(newStartTime) } diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala new file mode 100644 index 0000000000..11cecf9822 --- /dev/null +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -0,0 +1,48 @@ +package spark.streaming + +import spark.streaming.StreamingContext._ + +class CheckpointSuite extends DStreamSuiteBase { + + override def framework() = "CheckpointSuite" + + override def checkpointFile() = "checkpoint" + + def testCheckpointedOperation[U: ClassManifest, V: ClassManifest]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]], + useSet: Boolean = false + ) { + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + + // Current code assumes that: + // number of inputs = number of outputs = number of batches to be run + + // Do half the computation (half the number of batches), create checkpoint file and quit + val totalNumBatches = input.size + val initialNumBatches = input.size / 2 + val nextNumBatches = totalNumBatches - initialNumBatches + val initialNumExpectedOutputs = initialNumBatches + + val ssc = setupStreams[U, V](input, operation) + val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs) + verifyOutput[V](output, expectedOutput.take(initialNumBatches), useSet) + Thread.sleep(1000) + + // Restart and complete the computation from checkpoint file + val sscNew = new StreamingContext(checkpointFile) + sscNew.setCheckpointDetails(null, null) + val outputNew = runStreams[V](sscNew, nextNumBatches, expectedOutput.size) + verifyOutput[V](outputNew, expectedOutput, useSet) + } + + test("simple per-batch operation") { + testCheckpointedOperation( + Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ), + (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), + Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ), + true + ) + } +} \ No newline at end of file diff --git a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala index 965b58c03f..f8ca7febe7 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala @@ -22,9 +22,9 @@ class DStreamBasicSuite extends DStreamSuiteBase { test("shuffle-based operations") { // reduceByKey testOperation( - Seq(Seq("a", "a", "b"), Seq("", ""), Seq()), + Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ), true ) @@ -62,7 +62,6 @@ class DStreamBasicSuite extends DStreamSuiteBase { var newState = 0 if (values != null && values.size > 0) newState += values.reduce(_ + _) if (state != null) newState += state.self - //println("values = " + values + ", state = " + state + ", " + " new state = " + newState) new RichInt(newState) } s.map(x => (x, 1)).updateStateByKey[RichInt](updateFunc).map(t => (t._1, t._2.self)) diff --git a/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala b/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala index 59fe36baf0..cb95c36782 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala @@ -4,70 +4,157 @@ import spark.{RDD, Logging} import util.ManualClock import collection.mutable.ArrayBuffer import org.scalatest.FunSuite -import scala.collection.mutable.Queue +import collection.mutable.SynchronizedBuffer +class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, val input: Seq[Seq[T]]) + extends InputDStream[T](ssc_) { + var currentIndex = 0 + + def start() {} + + def stop() {} + + def compute(validTime: Time): Option[RDD[T]] = { + logInfo("Computing RDD for time " + validTime) + val rdd = if (currentIndex < input.size) { + ssc.sc.makeRDD(input(currentIndex), 2) + } else { + ssc.sc.makeRDD(Seq[T](), 2) + } + logInfo("Created RDD " + rdd.id) + currentIndex += 1 + Some(rdd) + } +} + +class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]]) + extends PerRDDForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { + val collected = rdd.collect() + output += collected + }) trait DStreamSuiteBase extends FunSuite with Logging { - def batchDuration() = Seconds(1) + def framework() = "DStreamSuiteBase" - def maxWaitTimeMillis() = 10000 + def master() = "local[2]" - def testOperation[U: ClassManifest, V: ClassManifest]( - input: Seq[Seq[U]], - operation: DStream[U] => DStream[V], - expectedOutput: Seq[Seq[V]], - useSet: Boolean = false - ) { + def batchDuration() = Seconds(1) - val manualClock = true + def checkpointFile() = null.asInstanceOf[String] - if (manualClock) { - System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") - } + def checkpointInterval() = batchDuration - val ssc = new StreamingContext("local", "test") + def maxWaitTimeMillis() = 10000 - try { - ssc.setBatchDuration(Milliseconds(batchDuration)) + def setupStreams[U: ClassManifest, V: ClassManifest]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V] + ): StreamingContext = { + + // Create StreamingContext + val ssc = new StreamingContext(master, framework) + ssc.setBatchDuration(batchDuration) + if (checkpointFile != null) { + ssc.setCheckpointDetails(checkpointFile, checkpointInterval()) + } - val inputQueue = new Queue[RDD[U]]() - inputQueue ++= input.map(ssc.sc.makeRDD(_, 2)) - val emptyRDD = ssc.sc.makeRDD(Seq[U](), 2) + // Setup the stream computation + val inputStream = new TestInputStream(ssc, input) + ssc.registerInputStream(inputStream) + val operatedStream = operation(inputStream) + val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[V]] with SynchronizedBuffer[Seq[V]]) + ssc.registerOutputStream(outputStream) + ssc + } - val inputStream = ssc.createQueueStream(inputQueue, true, emptyRDD) - val outputStream = operation(inputStream) + def runStreams[V: ClassManifest]( + ssc: StreamingContext, + numBatches: Int, + numExpectedOutput: Int + ): Seq[Seq[V]] = { + logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput) - val output = new ArrayBuffer[Seq[V]]() - outputStream.foreachRDD(rdd => output += rdd.collect()) + // Get the output buffer + val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]] + val output = outputStream.output + try { + // Start computation ssc.start() - val clock = ssc.scheduler.clock - if (clock.isInstanceOf[ManualClock]) { - clock.asInstanceOf[ManualClock].addToTime((input.size - 1) * batchDuration.milliseconds) - } + // Advance manual clock + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + logInfo("Manual clock before advancing = " + clock.time) + clock.addToTime(numBatches * batchDuration.milliseconds) + logInfo("Manual clock after advancing = " + clock.time) + // Wait until expected number of output items have been generated val startTime = System.currentTimeMillis() - while (output.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { - println("output.size = " + output.size + ", expectedOutput.size = " + expectedOutput.size) - Thread.sleep(500) + while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput) + Thread.sleep(100) } + val timeTaken = System.currentTimeMillis() - startTime - println("output.size = " + output.size) - println("output") - output.foreach(x => println("[" + x.mkString(",") + "]")) - - assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { - if (useSet) { - assert(output(i).toSet === expectedOutput(i).toSet) - } else { - assert(output(i).toList === expectedOutput(i).toList) - } - } + assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") + assert(output.size === numExpectedOutput, "Unexpected number of outputs generated") + } catch { + case e: Exception => e.printStackTrace(); throw e; } finally { ssc.stop() } + + output + } + + def verifyOutput[V: ClassManifest]( + output: Seq[Seq[V]], + expectedOutput: Seq[Seq[V]], + useSet: Boolean + ) { + logInfo("--------------------------------") + logInfo("output.size = " + output.size) + logInfo("output") + output.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("expected output.size = " + expectedOutput.size) + logInfo("expected output") + expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("--------------------------------") + + // Match the output with the expected output + assert(output.size === expectedOutput.size, "Number of outputs do not match") + for (i <- 0 until output.size) { + if (useSet) { + assert(output(i).toSet === expectedOutput(i).toSet) + } else { + assert(output(i).toList === expectedOutput(i).toList) + } + } + logInfo("Output verified successfully") + } + + def testOperation[U: ClassManifest, V: ClassManifest]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]], + useSet: Boolean = false + ) { + testOperation[U, V](input, operation, expectedOutput, -1, useSet) + } + + def testOperation[U: ClassManifest, V: ClassManifest]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]], + numBatches: Int, + useSet: Boolean + ) { + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + + val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size + val ssc = setupStreams[U, V](input, operation) + val output = runStreams[V](ssc, numBatches_, expectedOutput.size) + verifyOutput[V](output, expectedOutput, useSet) } } diff --git a/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala index 061cab2cbb..8dd18f491a 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala @@ -4,6 +4,10 @@ import spark.streaming.StreamingContext._ class DStreamWindowSuite extends DStreamSuiteBase { + override def framework() = "DStreamWindowSuite" + + override def maxWaitTimeMillis() = 20000 + val largerSlideInput = Seq( Seq(("a", 1)), // 1st window from here Seq(("a", 2)), @@ -81,16 +85,15 @@ class DStreamWindowSuite extends DStreamSuiteBase { name: String, input: Seq[Seq[(String, Int)]], expectedOutput: Seq[Seq[(String, Int)]], - windowTime: Time = Seconds(2), - slideTime: Time = Seconds(1) + windowTime: Time = batchDuration * 2, + slideTime: Time = batchDuration ) { test("reduceByKeyAndWindow - " + name) { - testOperation( - input, - (s: DStream[(String, Int)]) => s.reduceByKeyAndWindow(_ + _, windowTime, slideTime).persist(), - expectedOutput, - true - ) + val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val operation = (s: DStream[(String, Int)]) => { + s.reduceByKeyAndWindow(_ + _, windowTime, slideTime).persist() + } + testOperation(input, operation, expectedOutput, numBatches, true) } } @@ -98,16 +101,15 @@ class DStreamWindowSuite extends DStreamSuiteBase { name: String, input: Seq[Seq[(String, Int)]], expectedOutput: Seq[Seq[(String, Int)]], - windowTime: Time = Seconds(2), - slideTime: Time = Seconds(1) + windowTime: Time = batchDuration * 2, + slideTime: Time = batchDuration ) { test("reduceByKeyAndWindowInv - " + name) { - testOperation( - input, - (s: DStream[(String, Int)]) => s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist(), - expectedOutput, - true - ) + val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val operation = (s: DStream[(String, Int)]) => { + s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist() + } + testOperation(input, operation, expectedOutput, numBatches, true) } } @@ -116,8 +118,8 @@ class DStreamWindowSuite extends DStreamSuiteBase { testReduceByKeyAndWindow( "basic reduction", - Seq(Seq(("a", 1), ("a", 3)) ), - Seq(Seq(("a", 4)) ) + Seq( Seq(("a", 1), ("a", 3)) ), + Seq( Seq(("a", 4)) ) ) testReduceByKeyAndWindow( @@ -126,7 +128,6 @@ class DStreamWindowSuite extends DStreamSuiteBase { Seq( Seq(("a", 1)), Seq(("a", 2)) ) ) - testReduceByKeyAndWindow( "new key added into window", Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ), -- cgit v1.2.3 From b760d6426a7fa2a6d115cefc786aa766b9419bd6 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 15 Oct 2012 12:26:44 -0700 Subject: Minor modifications. --- .../src/main/scala/spark/streaming/Checkpoint.scala | 6 ++++++ .../main/scala/spark/streaming/StreamingContext.scala | 19 +++++++------------ .../test/scala/spark/streaming/DStreamSuiteBase.scala | 2 +- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index b38911b646..f7936bdc5f 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -25,6 +25,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) ext assert(framework != null, "Checkpoint.framework is null") assert(graph != null, "Checkpoint.graph is null") assert(batchDuration != null, "Checkpoint.batchDuration is null") + assert(checkpointTime != null, "Checkpoint.checkpointTime is null") } def saveToFile(file: String = checkpointFile) { @@ -60,6 +61,11 @@ object Checkpoint { throw new Exception("Checkpoint file '" + file + "' does not exist") } val fis = fs.open(path) + // ObjectInputStream uses the last defined user-defined class loader in the stack + // to find classes, which maybe the wrong class loader. Hence, a inherited version + // of ObjectInputStream is used to explicitly use the current thread's default class + // loader to find and load classes. This is a well know Java issue and has popped up + // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader) val cp = ois.readObject.asInstanceOf[Checkpoint] ois.close() diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index e072f15c93..62d21b83d9 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -184,25 +184,20 @@ class StreamingContext ( } /** - * This function verify whether the stream computation is eligible to be executed. + * This function validate whether the stream computation is eligible to be executed. */ - private def verify() { - if (batchDuration == null) { - throw new Exception("Batch duration has not been set") - } - if (batchDuration < Milliseconds(100)) { - logWarning("Batch duration of " + batchDuration + " is very low") - } - if (graph.getOutputStreams().size == 0) { - throw new Exception("No output streams registered, so nothing to execute") - } + private def validate() { + assert(batchDuration != null, "Batch duration has not been set") + assert(batchDuration > Milliseconds(100), "Batch duration of " + batchDuration + " is very low") + assert(graph != null, "Graph is null") + assert(graph.getOutputStreams().size > 0, "No output streams registered, so nothing to execute") } /** * This function starts the execution of the streams. */ def start() { - verify() + validate() val networkInputStreams = graph.getInputStreams().filter(s => s match { case n: NetworkInputDStream[_] => true case _ => false diff --git a/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala b/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala index cb95c36782..91ffc0c098 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala @@ -61,9 +61,9 @@ trait DStreamSuiteBase extends FunSuite with Logging { // Setup the stream computation val inputStream = new TestInputStream(ssc, input) - ssc.registerInputStream(inputStream) val operatedStream = operation(inputStream) val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[V]] with SynchronizedBuffer[Seq[V]]) + ssc.registerInputStream(inputStream) ssc.registerOutputStream(outputStream) ssc } -- cgit v1.2.3 From 4a3fb06ac2d11125feb08acbbd4df76d1e91b677 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 16 Oct 2012 01:10:01 -0700 Subject: Updated Kryo to 2.20. --- core/src/main/scala/spark/KryoSerializer.scala | 205 ++++++++----------------- project/SparkBuild.scala | 2 +- 2 files changed, 69 insertions(+), 138 deletions(-) diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 44b630e478..f24196ea49 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -9,153 +9,80 @@ import scala.collection.mutable import com.esotericsoftware.kryo._ import com.esotericsoftware.kryo.{Serializer => KSerializer} -import com.esotericsoftware.kryo.serialize.ClassSerializer -import com.esotericsoftware.kryo.serialize.SerializableSerializer +import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} +import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} import de.javakaffee.kryoserializers.KryoReflectionFactorySupport import serializer.{SerializerInstance, DeserializationStream, SerializationStream} import spark.broadcast._ import spark.storage._ -/** - * Zig-zag encoder used to write object sizes to serialization streams. - * Based on Kryo's integer encoder. - */ -private[spark] object ZigZag { - def writeInt(n: Int, out: OutputStream) { - var value = n - if ((value & ~0x7F) == 0) { - out.write(value) - return - } - out.write(((value & 0x7F) | 0x80)) - value >>>= 7 - if ((value & ~0x7F) == 0) { - out.write(value) - return - } - out.write(((value & 0x7F) | 0x80)) - value >>>= 7 - if ((value & ~0x7F) == 0) { - out.write(value) - return - } - out.write(((value & 0x7F) | 0x80)) - value >>>= 7 - if ((value & ~0x7F) == 0) { - out.write(value) - return - } - out.write(((value & 0x7F) | 0x80)) - value >>>= 7 - out.write(value) - } +private[spark] +class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream { - def readInt(in: InputStream): Int = { - var offset = 0 - var result = 0 - while (offset < 32) { - val b = in.read() - if (b == -1) { - throw new EOFException("End of stream") - } - result |= ((b & 0x7F) << offset) - if ((b & 0x80) == 0) { - return result - } - offset += 7 - } - throw new SparkException("Malformed zigzag-encoded integer") - } -} - -private[spark] -class KryoSerializationStream(kryo: Kryo, threadBuffer: ByteBuffer, out: OutputStream) -extends SerializationStream { - val channel = Channels.newChannel(out) + val output = new KryoOutput(outStream) def writeObject[T](t: T): SerializationStream = { - kryo.writeClassAndObject(threadBuffer, t) - ZigZag.writeInt(threadBuffer.position(), out) - threadBuffer.flip() - channel.write(threadBuffer) - threadBuffer.clear() + kryo.writeClassAndObject(output, t) this } - def flush() { out.flush() } - def close() { out.close() } + def flush() { output.flush() } + def close() { output.close() } } -private[spark] -class KryoDeserializationStream(objectBuffer: ObjectBuffer, in: InputStream) -extends DeserializationStream { +private[spark] +class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream { + + val input = new KryoInput(inStream) + def readObject[T](): T = { - val len = ZigZag.readInt(in) - objectBuffer.readClassAndObject(in, len).asInstanceOf[T] + try { + kryo.readClassAndObject(input).asInstanceOf[T] + } catch { + // DeserializationStream uses the EOF exception to indicate stopping condition. + case e: com.esotericsoftware.kryo.KryoException => throw new java.io.EOFException + } } - def close() { in.close() } + def close() { + input.close() + inStream.close() + } } private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { - val kryo = ks.kryo - val threadBuffer = ks.threadBuffer.get() - val objectBuffer = ks.objectBuffer.get() + + val kryo = ks.kryo.get() + val output = ks.output.get() + val input = ks.input.get() def serialize[T](t: T): ByteBuffer = { - // Write it to our thread-local scratch buffer first to figure out the size, then return a new - // ByteBuffer of the appropriate size - threadBuffer.clear() - kryo.writeClassAndObject(threadBuffer, t) - val newBuf = ByteBuffer.allocate(threadBuffer.position) - threadBuffer.flip() - newBuf.put(threadBuffer) - newBuf.flip() - newBuf + output.clear() + kryo.writeClassAndObject(output, t) + ByteBuffer.wrap(output.toBytes) } def deserialize[T](bytes: ByteBuffer): T = { - kryo.readClassAndObject(bytes).asInstanceOf[T] + input.setBuffer(bytes.array) + kryo.readClassAndObject(input).asInstanceOf[T] } def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { val oldClassLoader = kryo.getClassLoader kryo.setClassLoader(loader) - val obj = kryo.readClassAndObject(bytes).asInstanceOf[T] + input.setBuffer(bytes.array) + val obj = kryo.readClassAndObject(input).asInstanceOf[T] kryo.setClassLoader(oldClassLoader) obj } def serializeStream(s: OutputStream): SerializationStream = { - threadBuffer.clear() - new KryoSerializationStream(kryo, threadBuffer, s) + new KryoSerializationStream(kryo, s) } def deserializeStream(s: InputStream): DeserializationStream = { - new KryoDeserializationStream(objectBuffer, s) - } - - override def serializeMany[T](iterator: Iterator[T]): ByteBuffer = { - threadBuffer.clear() - while (iterator.hasNext) { - val element = iterator.next() - // TODO: Do we also want to write the object's size? Doesn't seem necessary. - kryo.writeClassAndObject(threadBuffer, element) - } - val newBuf = ByteBuffer.allocate(threadBuffer.position) - threadBuffer.flip() - newBuf.put(threadBuffer) - newBuf.flip() - newBuf - } - - override def deserializeMany(buffer: ByteBuffer): Iterator[Any] = { - buffer.rewind() - new Iterator[Any] { - override def hasNext: Boolean = buffer.remaining > 0 - override def next(): Any = kryo.readClassAndObject(buffer) - } + new KryoDeserializationStream(kryo, s) } } @@ -171,18 +98,19 @@ trait KryoRegistrator { * A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]]. */ class KryoSerializer extends spark.serializer.Serializer with Logging { - // Make this lazy so that it only gets called once we receive our first task on each executor, - // so we can pull out any custom Kryo registrator from the user's JARs. - lazy val kryo = createKryo() - val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "32").toInt * 1024 * 1024 + val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024 - val objectBuffer = new ThreadLocal[ObjectBuffer] { - override def initialValue = new ObjectBuffer(kryo, bufferSize) + val kryo = new ThreadLocal[Kryo] { + override def initialValue = createKryo() } - val threadBuffer = new ThreadLocal[ByteBuffer] { - override def initialValue = ByteBuffer.allocate(bufferSize) + val output = new ThreadLocal[KryoOutput] { + override def initialValue = new KryoOutput(bufferSize) + } + + val input = new ThreadLocal[KryoInput] { + override def initialValue = new KryoInput(bufferSize) } def createKryo(): Kryo = { @@ -213,41 +141,44 @@ class KryoSerializer extends spark.serializer.Serializer with Logging { kryo.register(obj.getClass) } - // Register the following classes for passing closures. - kryo.register(classOf[Class[_]], new ClassSerializer(kryo)) - kryo.setRegistrationOptional(true) - // Allow sending SerializableWritable - kryo.register(classOf[SerializableWritable[_]], new SerializableSerializer()) - kryo.register(classOf[HttpBroadcast[_]], new SerializableSerializer()) + kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) + kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) // Register some commonly used Scala singleton objects. Because these // are singletons, we must return the exact same local object when we // deserialize rather than returning a clone as FieldSerializer would. - class SingletonSerializer(obj: AnyRef) extends KSerializer { - override def writeObjectData(buf: ByteBuffer, obj: AnyRef) {} - override def readObjectData[T](buf: ByteBuffer, cls: Class[T]): T = obj.asInstanceOf[T] + class SingletonSerializer[T](obj: T) extends KSerializer[T] { + override def write(kryo: Kryo, output: KryoOutput, obj: T) {} + override def read(kryo: Kryo, input: KryoInput, cls: java.lang.Class[T]): T = obj } - kryo.register(None.getClass, new SingletonSerializer(None)) - kryo.register(Nil.getClass, new SingletonSerializer(Nil)) + kryo.register(None.getClass, new SingletonSerializer[AnyRef](None)) + kryo.register(Nil.getClass, new SingletonSerializer[AnyRef](Nil)) // Register maps with a special serializer since they have complex internal structure class ScalaMapSerializer(buildMap: Array[(Any, Any)] => scala.collection.Map[Any, Any]) - extends KSerializer { - override def writeObjectData(buf: ByteBuffer, obj: AnyRef) { + extends KSerializer[Array[(Any, Any)] => scala.collection.Map[Any, Any]] { + override def write( + kryo: Kryo, + output: KryoOutput, + obj: Array[(Any, Any)] => scala.collection.Map[Any, Any]) { val map = obj.asInstanceOf[scala.collection.Map[Any, Any]] - kryo.writeObject(buf, map.size.asInstanceOf[java.lang.Integer]) + kryo.writeObject(output, map.size.asInstanceOf[java.lang.Integer]) for ((k, v) <- map) { - kryo.writeClassAndObject(buf, k) - kryo.writeClassAndObject(buf, v) + kryo.writeClassAndObject(output, k) + kryo.writeClassAndObject(output, v) } } - override def readObjectData[T](buf: ByteBuffer, cls: Class[T]): T = { - val size = kryo.readObject(buf, classOf[java.lang.Integer]).intValue + override def read ( + kryo: Kryo, + input: KryoInput, + cls: Class[Array[(Any, Any)] => scala.collection.Map[Any, Any]]) + : Array[(Any, Any)] => scala.collection.Map[Any, Any] = { + val size = kryo.readObject(input, classOf[java.lang.Integer]).intValue val elems = new Array[(Any, Any)](size) for (i <- 0 until size) - elems(i) = (kryo.readClassAndObject(buf), kryo.readClassAndObject(buf)) - buildMap(elems).asInstanceOf[T] + elems(i) = (kryo.readClassAndObject(input), kryo.readClassAndObject(input)) + buildMap(elems).asInstanceOf[Array[(Any, Any)] => scala.collection.Map[Any, Any]] } } kryo.register(mutable.HashMap().getClass, new ScalaMapSerializer(mutable.HashMap() ++ _)) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index c9cf17d90a..1023019d24 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -120,7 +120,7 @@ object SparkBuild extends Build { "org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION, "asm" % "asm-all" % "3.3.1", "com.google.protobuf" % "protobuf-java" % "2.4.1", - "de.javakaffee" % "kryo-serializers" % "0.9", + "de.javakaffee" % "kryo-serializers" % "0.20", "com.typesafe.akka" % "akka-actor" % "2.0.3", "com.typesafe.akka" % "akka-remote" % "2.0.3", "com.typesafe.akka" % "akka-slf4j" % "2.0.3", -- cgit v1.2.3 From 52989c8a2c8c10d7f5610c033f6782e58fd3abc2 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 19 Oct 2012 10:24:49 -0700 Subject: Update Python API for v0.6.0 compatibility. --- core/src/main/scala/spark/api/python/PythonRDD.scala | 18 +++++++++++------- core/src/main/scala/spark/broadcast/Broadcast.scala | 2 +- pyspark/pyspark/broadcast.py | 18 +++++++++--------- pyspark/pyspark/context.py | 2 +- pyspark/pyspark/java_gateway.py | 3 ++- pyspark/pyspark/serializers.py | 18 ++++++++++++++---- pyspark/pyspark/worker.py | 8 ++++---- 7 files changed, 42 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 4d3bdb3963..528885fe5c 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -5,11 +5,15 @@ import java.io._ import scala.collection.Map import scala.collection.JavaConversions._ import scala.io.Source -import spark._ -import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} -import broadcast.Broadcast -import scala.collection -import java.nio.charset.Charset + +import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} +import spark.broadcast.Broadcast +import spark.SparkEnv +import spark.Split +import spark.RDD +import spark.OneToOneDependency +import spark.rdd.PipedRDD + trait PythonRDDBase { def compute[T](split: Split, envVars: Map[String, String], @@ -43,9 +47,9 @@ trait PythonRDDBase { SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) val dOut = new DataOutputStream(proc.getOutputStream) - out.println(broadcastVars.length) + dOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { - out.print(broadcast.uuid.toString) + dOut.writeLong(broadcast.id) dOut.writeInt(broadcast.value.length) dOut.write(broadcast.value) dOut.flush() diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala index 6055bfd045..2ffe7f741d 100644 --- a/core/src/main/scala/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/spark/broadcast/Broadcast.scala @@ -5,7 +5,7 @@ import java.util.concurrent.atomic.AtomicLong import spark._ -abstract class Broadcast[T](id: Long) extends Serializable { +abstract class Broadcast[T](private[spark] val id: Long) extends Serializable { def value: T // We cannot have an abstract readObject here due to some weird issues with diff --git a/pyspark/pyspark/broadcast.py b/pyspark/pyspark/broadcast.py index 1ea17d59af..4cff02b36d 100644 --- a/pyspark/pyspark/broadcast.py +++ b/pyspark/pyspark/broadcast.py @@ -6,7 +6,7 @@ [1, 2, 3, 4, 5] >>> from pyspark.broadcast import _broadcastRegistry ->>> _broadcastRegistry[b.uuid] = b +>>> _broadcastRegistry[b.bid] = b >>> from cPickle import dumps, loads >>> loads(dumps(b)).value [1, 2, 3, 4, 5] @@ -14,27 +14,27 @@ >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() [1, 2, 3, 4, 5, 1, 2, 3, 4, 5] """ -# Holds broadcasted data received from Java, keyed by UUID. +# Holds broadcasted data received from Java, keyed by its id. _broadcastRegistry = {} -def _from_uuid(uuid): +def _from_id(bid): from pyspark.broadcast import _broadcastRegistry - if uuid not in _broadcastRegistry: - raise Exception("Broadcast variable '%s' not loaded!" % uuid) - return _broadcastRegistry[uuid] + if bid not in _broadcastRegistry: + raise Exception("Broadcast variable '%s' not loaded!" % bid) + return _broadcastRegistry[bid] class Broadcast(object): - def __init__(self, uuid, value, java_broadcast=None, pickle_registry=None): + def __init__(self, bid, value, java_broadcast=None, pickle_registry=None): self.value = value - self.uuid = uuid + self.bid = bid self._jbroadcast = java_broadcast self._pickle_registry = pickle_registry def __reduce__(self): self._pickle_registry.add(self) - return (_from_uuid, (self.uuid, )) + return (_from_id, (self.bid, )) def _test(): diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index 04932c93f2..3f4db26644 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -66,5 +66,5 @@ class SparkContext(object): def broadcast(self, value): jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value))) - return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast, + return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) diff --git a/pyspark/pyspark/java_gateway.py b/pyspark/pyspark/java_gateway.py index bcb405ba72..3726bcbf17 100644 --- a/pyspark/pyspark/java_gateway.py +++ b/pyspark/pyspark/java_gateway.py @@ -7,7 +7,8 @@ SPARK_HOME = os.environ["SPARK_HOME"] assembly_jar = glob.glob(os.path.join(SPARK_HOME, "core/target") + \ - "/spark-core-assembly-*-SNAPSHOT.jar")[0] + "/spark-core-assembly-*.jar")[0] + # TODO: what if multiple assembly jars are found? def launch_gateway(): diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py index faa1e683c7..21ef8b106c 100644 --- a/pyspark/pyspark/serializers.py +++ b/pyspark/pyspark/serializers.py @@ -9,16 +9,26 @@ def dump_pickle(obj): load_pickle = cPickle.loads +def read_long(stream): + length = stream.read(8) + if length == "": + raise EOFError + return struct.unpack("!q", length)[0] + + +def read_int(stream): + length = stream.read(4) + if length == "": + raise EOFError + return struct.unpack("!i", length)[0] + def write_with_length(obj, stream): stream.write(struct.pack("!i", len(obj))) stream.write(obj) def read_with_length(stream): - length = stream.read(4) - if length == "": - raise EOFError - length = struct.unpack("!i", length)[0] + length = read_int(stream) obj = stream.read(length) if obj == "": raise EOFError diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index a9ed71892f..62824a1c9b 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -8,7 +8,7 @@ from base64 import standard_b64decode from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.serializers import write_with_length, read_with_length, \ - dump_pickle, load_pickle + read_long, read_int, dump_pickle, load_pickle # Redirect stdout to stderr so that users must return values from functions. @@ -29,11 +29,11 @@ def read_input(): def main(): - num_broadcast_variables = int(sys.stdin.readline().strip()) + num_broadcast_variables = read_int(sys.stdin) for _ in range(num_broadcast_variables): - uuid = sys.stdin.read(36) + bid = read_long(sys.stdin) value = read_with_length(sys.stdin) - _broadcastRegistry[uuid] = Broadcast(uuid, load_pickle(value)) + _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value)) func = load_obj() bypassSerializer = load_obj() if bypassSerializer: -- cgit v1.2.3 From 6d5eb4b40ccad150c967fee8557a4e5d5664b4bd Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 19 Oct 2012 12:11:44 -0700 Subject: Added functionality to forget RDDs from DStreams. --- .../main/scala/spark/streaming/Checkpoint.scala | 2 - .../src/main/scala/spark/streaming/DStream.scala | 90 +++++++++++++++------- .../main/scala/spark/streaming/DStreamGraph.scala | 40 ++++++++-- .../spark/streaming/PairDStreamFunctions.scala | 9 ++- .../spark/streaming/ReducedWindowedDStream.scala | 14 +++- .../src/main/scala/spark/streaming/Scheduler.scala | 5 +- .../scala/spark/streaming/StreamingContext.scala | 34 +++----- .../scala/spark/streaming/WindowedDStream.scala | 6 +- .../scala/spark/streaming/CheckpointSuite.scala | 9 ++- .../scala/spark/streaming/DStreamBasicSuite.scala | 87 ++++++++++++++++++--- .../scala/spark/streaming/DStreamSuiteBase.scala | 7 +- 11 files changed, 224 insertions(+), 79 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index f7936bdc5f..23fd0f2434 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -14,7 +14,6 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) ext val sparkHome = ssc.sc.sparkHome val jars = ssc.sc.jars val graph = ssc.graph - val batchDuration = ssc.batchDuration val checkpointFile = ssc.checkpointFile val checkpointInterval = ssc.checkpointInterval @@ -24,7 +23,6 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) ext assert(master != null, "Checkpoint.master is null") assert(framework != null, "Checkpoint.framework is null") assert(graph != null, "Checkpoint.graph is null") - assert(batchDuration != null, "Checkpoint.batchDuration is null") assert(checkpointTime != null, "Checkpoint.checkpointTime is null") } diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 0a43a042d0..645636b603 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -39,22 +39,30 @@ extends Serializable with Logging { * --------------------------------------- */ - // Variable to store the RDDs generated earlier in time - protected val generatedRDDs = new HashMap[Time, RDD[T]] () + // RDDs generated, marked as protected[streaming] so that testsuites can access it + protected[streaming] val generatedRDDs = new HashMap[Time, RDD[T]] () - // Variable to be set to the first time seen by the DStream (effective time zero) + // Time zero for the DStream protected var zeroTime: Time = null - // Variable to specify storage level + // Time after which RDDs will be forgotten + protected var forgetTime: Time = null + + // Storage level of the RDDs in the stream protected var storageLevel: StorageLevel = StorageLevel.NONE // Checkpoint level and checkpoint interval protected var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint protected var checkpointInterval: Time = null - // Reference to whole DStream graph, so that checkpointing process can lock it + // Reference to whole DStream graph protected var graph: DStreamGraph = null + def isInitialized = (zeroTime != null) + + // Time gap for forgetting old RDDs (i.e. removing them from generatedRDDs) + def parentForgetTime = forgetTime + // Change this RDD's storage level def persist( storageLevel: StorageLevel, @@ -79,8 +87,6 @@ extends Serializable with Logging { // Turn on the default caching level for this RDD def cache(): DStream[T] = persist() - def isInitialized() = (zeroTime != null) - /** * This method initializes the DStream by setting the "zero" time, based on which * the validity of future times is calculated. This method also recursively initializes @@ -91,31 +97,43 @@ extends Serializable with Logging { throw new Exception("ZeroTime is already initialized, cannot initialize it again") } zeroTime = time - logInfo(this + " initialized") dependencies.foreach(_.initialize(zeroTime)) + logInfo("Initialized " + this) } protected[streaming] def setContext(s: StreamingContext) { if (ssc != null && ssc != s) { - throw new Exception("Context is already set, cannot set it again") + throw new Exception("Context is already set in " + this + ", cannot set it again") } ssc = s - logInfo("Set context for " + this.getClass.getSimpleName) + logInfo("Set context for " + this) dependencies.foreach(_.setContext(ssc)) } protected[streaming] def setGraph(g: DStreamGraph) { if (graph != null && graph != g) { - throw new Exception("Graph is already set, cannot set it again") + throw new Exception("Graph is already set in " + this + ", cannot set it again") } graph = g dependencies.foreach(_.setGraph(graph)) } + protected[streaming] def setForgetTime(time: Time = slideTime) { + if (time == null) { + throw new Exception("Time gap for forgetting RDDs cannot be set to null for " + this) + } else if (forgetTime != null && time < forgetTime) { + throw new Exception("Time gap for forgetting RDDs cannot be reduced from " + forgetTime + + " to " + time + " for " + this) + } + forgetTime = time + dependencies.foreach(_.setForgetTime(parentForgetTime)) + logInfo("Time gap for forgetting RDDs set to " + forgetTime + " for " + this) + } + /** This method checks whether the 'time' is valid wrt slideTime for generating RDD */ protected def isTimeValid(time: Time): Boolean = { if (!isInitialized) { - throw new Exception (this.toString + " has not been initialized") + throw new Exception (this + " has not been initialized") } else if (time < zeroTime || ! (time - zeroTime).isMultipleOf(slideTime)) { false } else { @@ -178,6 +196,21 @@ extends Serializable with Logging { } } + def forgetOldRDDs(time: Time) { + val keys = generatedRDDs.keys + var numForgotten = 0 + + keys.foreach(t => { + if (t < (time - forgetTime)) { + generatedRDDs.remove(t) + numForgotten += 1 + //logInfo("Forgot RDD of time " + t + " from " + this) + } + }) + logInfo("Forgot " + numForgotten + " RDDs from " + this) + dependencies.foreach(_.forgetOldRDDs(time)) + } + @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { logDebug(this.getClass().getSimpleName + ".writeObject used") @@ -257,7 +290,7 @@ extends Serializable with Logging { new TransformedDStream(this, ssc.sc.clean(transformFunc)) } - def toBlockingQueue = { + def toBlockingQueue() = { val queue = new ArrayBlockingQueue[RDD[T]](10000) this.foreachRDD(rdd => { queue.add(rdd) @@ -265,7 +298,7 @@ extends Serializable with Logging { queue } - def print() = { + def print() { def foreachFunc = (rdd: RDD[T], time: Time) => { val first11 = rdd.take(11) println ("-------------------------------------------") @@ -277,33 +310,38 @@ extends Serializable with Logging { } val newStream = new PerRDDForEachDStream(this, ssc.sc.clean(foreachFunc)) ssc.registerOutputStream(newStream) - newStream } - def window(windowTime: Time, slideTime: Time) = new WindowedDStream(this, windowTime, slideTime) + def window(windowTime: Time): DStream[T] = window(windowTime, this.slideTime) + + def window(windowTime: Time, slideTime: Time): DStream[T] = { + new WindowedDStream(this, windowTime, slideTime) + } - def batch(batchTime: Time) = window(batchTime, batchTime) + def tumble(batchTime: Time): DStream[T] = window(batchTime, batchTime) - def reduceByWindow(reduceFunc: (T, T) => T, windowTime: Time, slideTime: Time) = + def reduceByWindow(reduceFunc: (T, T) => T, windowTime: Time, slideTime: Time): DStream[T] = { this.window(windowTime, slideTime).reduce(reduceFunc) + } def reduceByWindow( - reduceFunc: (T, T) => T, - invReduceFunc: (T, T) => T, - windowTime: Time, - slideTime: Time) = { + reduceFunc: (T, T) => T, + invReduceFunc: (T, T) => T, + windowTime: Time, + slideTime: Time + ): DStream[T] = { this.map(x => (1, x)) .reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowTime, slideTime, 1) .map(_._2) } - def countByWindow(windowTime: Time, slideTime: Time) = { + def countByWindow(windowTime: Time, slideTime: Time): DStream[Int] = { def add(v1: Int, v2: Int) = (v1 + v2) def subtract(v1: Int, v2: Int) = (v1 - v2) this.map(_ => 1).reduceByWindow(add _, subtract _, windowTime, slideTime) } - def union(that: DStream[T]) = new UnifiedDStream(Array(this, that)) + def union(that: DStream[T]): DStream[T] = new UnifiedDStream[T](Array(this, that)) def slice(interval: Interval): Seq[RDD[T]] = { slice(interval.beginTime, interval.endTime) @@ -336,8 +374,8 @@ abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContex override def slideTime = { if (ssc == null) throw new Exception("ssc is null") - if (ssc.batchDuration == null) throw new Exception("ssc.batchDuration is null") - ssc.batchDuration + if (ssc.graph.batchDuration == null) throw new Exception("batchDuration is null") + ssc.graph.batchDuration } def start() diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index bcd365e932..964c8a26a0 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -11,7 +11,8 @@ final class DStreamGraph extends Serializable with Logging { private val outputStreams = new ArrayBuffer[DStream[_]]() private[streaming] var zeroTime: Time = null - private[streaming] var checkpointInProgress = false; + private[streaming] var batchDuration: Time = null + private[streaming] var checkpointInProgress = false def start(time: Time) { this.synchronized { @@ -20,6 +21,7 @@ final class DStreamGraph extends Serializable with Logging { } zeroTime = time outputStreams.foreach(_.initialize(zeroTime)) + outputStreams.foreach(_.setForgetTime()) inputStreams.par.foreach(_.start()) } } @@ -36,14 +38,28 @@ final class DStreamGraph extends Serializable with Logging { } } + def setBatchDuration(duration: Time) { + this.synchronized { + if (batchDuration != null) { + throw new Exception("Batch duration already set as " + batchDuration + + ". cannot set it again.") + } + } + batchDuration = duration + } + def addInputStream(inputStream: InputDStream[_]) { - inputStream.setGraph(this) - inputStreams += inputStream + this.synchronized { + inputStream.setGraph(this) + inputStreams += inputStream + } } def addOutputStream(outputStream: DStream[_]) { - outputStream.setGraph(this) - outputStreams += outputStream + this.synchronized { + outputStream.setGraph(this) + outputStreams += outputStream + } } def getInputStreams() = inputStreams.toArray @@ -56,6 +72,20 @@ final class DStreamGraph extends Serializable with Logging { } } + def forgetOldRDDs(time: Time) { + this.synchronized { + outputStreams.foreach(_.forgetOldRDDs(time)) + } + } + + def validate() { + this.synchronized { + assert(batchDuration != null, "Batch duration has not been set") + assert(batchDuration > Milliseconds(100), "Batch duration of " + batchDuration + " is very low") + assert(getOutputStreams().size > 0, "No output streams registered, so nothing to execute") + } + } + @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { this.synchronized { diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index 3fd0a16bf0..0bd0321928 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -75,6 +75,13 @@ extends Serializable { stream.window(windowTime, slideTime).groupByKey(partitioner) } + def reduceByKeyAndWindow( + reduceFunc: (V, V) => V, + windowTime: Time + ): DStream[(K, V)] = { + reduceByKeyAndWindow(reduceFunc, windowTime, stream.slideTime, defaultPartitioner()) + } + def reduceByKeyAndWindow( reduceFunc: (V, V) => V, windowTime: Time, @@ -106,7 +113,7 @@ extends Serializable { // so that new elements introduced in the window can be "added" using // reduceFunc to the previous window's result and old elements can be // "subtracted using invReduceFunc. - + def reduceByKeyAndWindow( reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala index e161b5ba92..f3e95c9e2b 100644 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -31,12 +31,15 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( @transient val reducedStream = parent.reduceByKey(reduceFunc, partitioner) - override def dependencies = List(reducedStream) - def windowTime: Time = _windowTime + override def dependencies = List(reducedStream) + override def slideTime: Time = _slideTime + //TODO: This is wrong. This should depend on the checkpointInterval + override def parentForgetTime: Time = forgetTime + windowTime + override def persist( storageLevel: StorageLevel, checkpointLevel: StorageLevel, @@ -46,6 +49,13 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( this } + protected[streaming] override def setForgetTime(time: Time) { + if (forgetTime == null || forgetTime < time) { + forgetTime = time + dependencies.foreach(_.setForgetTime(forgetTime + windowTime)) + } + } + override def compute(validTime: Time): Option[RDD[(K, V)]] = { val reduceF = reduceFunc val invReduceF = invReduceFunc diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index 1e1425a88a..99e30b6110 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -20,7 +20,7 @@ extends Logging { val jobManager = new JobManager(ssc, concurrentJobs) val clockClass = System.getProperty("spark.streaming.clock", "spark.streaming.util.SystemClock") val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock] - val timer = new RecurringTimer(clock, ssc.batchDuration, generateRDDs(_)) + val timer = new RecurringTimer(clock, ssc.graph.batchDuration, generateRDDs(_)) def start() { // If context was started from checkpoint, then restart timer such that @@ -53,11 +53,12 @@ extends Logging { def generateRDDs(time: Time) { SparkEnv.set(ssc.env) logInfo("\n-----------------------------------------------------\n") - logInfo("Generating RDDs for time " + time) graph.generateRDDs(time).foreach(submitJob) logInfo("Generated RDDs for time " + time) + graph.forgetOldRDDs(time) if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) { ssc.doCheckpoint(time) + logInfo("Checkpointed at time " + time) } } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 62d21b83d9..b5f4571798 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -6,16 +6,11 @@ import spark.SparkEnv import spark.SparkContext import spark.storage.StorageLevel -import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.Queue import java.io.InputStream -import java.io.IOException -import java.net.InetSocketAddress import java.util.concurrent.atomic.AtomicInteger -import org.apache.hadoop.fs.Path -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.LongWritable import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} @@ -65,20 +60,15 @@ class StreamingContext ( } val nextNetworkInputStreamId = new AtomicInteger(0) - - var batchDuration: Time = if (isCheckpointPresent) cp_.batchDuration else null - var checkpointFile: String = if (isCheckpointPresent) cp_.checkpointFile else null - var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null var networkInputTracker: NetworkInputTracker = null - var receiverJobThread: Thread = null - var scheduler: Scheduler = null + + private[streaming] var checkpointFile: String = if (isCheckpointPresent) cp_.checkpointFile else null + private[streaming] var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null + private[streaming] var receiverJobThread: Thread = null + private[streaming] var scheduler: Scheduler = null def setBatchDuration(duration: Time) { - if (batchDuration != null) { - throw new Exception("Batch duration alread set as " + batchDuration + - ". cannot set it again.") - } - batchDuration = duration + graph.setBatchDuration(duration) } def setCheckpointDetails(file: String, interval: Time) { @@ -183,21 +173,17 @@ class StreamingContext ( graph.addOutputStream(outputStream) } - /** - * This function validate whether the stream computation is eligible to be executed. - */ - private def validate() { - assert(batchDuration != null, "Batch duration has not been set") - assert(batchDuration > Milliseconds(100), "Batch duration of " + batchDuration + " is very low") + def validate() { assert(graph != null, "Graph is null") - assert(graph.getOutputStreams().size > 0, "No output streams registered, so nothing to execute") + graph.validate() } - + /** * This function starts the execution of the streams. */ def start() { validate() + val networkInputStreams = graph.getInputStreams().filter(s => s match { case n: NetworkInputDStream[_] => true case _ => false diff --git a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala index 93c1291691..2984f88284 100644 --- a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala @@ -18,12 +18,14 @@ class WindowedDStream[T: ClassManifest]( throw new Exception("The slide duration of WindowedDStream (" + _slideTime + ") " + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") - override def dependencies = List(parent) - def windowTime: Time = _windowTime + override def dependencies = List(parent) + override def slideTime: Time = _slideTime + override def parentForgetTime: Time = forgetTime + windowTime + override def compute(validTime: Time): Option[RDD[T]] = { val currentWindow = Interval(validTime - windowTime + parent.slideTime, validTime) Some(new UnionRDD(ssc.sc, parent.slice(currentWindow))) diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index 11cecf9822..061b331a16 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -1,6 +1,7 @@ package spark.streaming import spark.streaming.StreamingContext._ +import java.io.File class CheckpointSuite extends DStreamSuiteBase { @@ -14,17 +15,16 @@ class CheckpointSuite extends DStreamSuiteBase { expectedOutput: Seq[Seq[V]], useSet: Boolean = false ) { - System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") // Current code assumes that: // number of inputs = number of outputs = number of batches to be run - // Do half the computation (half the number of batches), create checkpoint file and quit val totalNumBatches = input.size val initialNumBatches = input.size / 2 val nextNumBatches = totalNumBatches - initialNumBatches val initialNumExpectedOutputs = initialNumBatches + // Do half the computation (half the number of batches), create checkpoint file and quit val ssc = setupStreams[U, V](input, operation) val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs) verifyOutput[V](output, expectedOutput.take(initialNumBatches), useSet) @@ -35,6 +35,11 @@ class CheckpointSuite extends DStreamSuiteBase { sscNew.setCheckpointDetails(null, null) val outputNew = runStreams[V](sscNew, nextNumBatches, expectedOutput.size) verifyOutput[V](outputNew, expectedOutput, useSet) + + new File(checkpointFile).delete() + new File(checkpointFile + ".bk").delete() + new File("." + checkpointFile + ".crc").delete() + new File("." + checkpointFile + ".bk.crc").delete() } test("simple per-batch operation") { diff --git a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala index f8ca7febe7..5dd8b675b1 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala @@ -2,16 +2,21 @@ package spark.streaming import spark.streaming.StreamingContext._ import scala.runtime.RichInt +import util.ManualClock class DStreamBasicSuite extends DStreamSuiteBase { - test("map-like operations") { + test("map") { + val input = Seq(1 to 4, 5 to 8, 9 to 12) + testOperation( + input, + (r: DStream[Int]) => r.map(_.toString), + input.map(_.map(_.toString)) + ) + } + + test("flatmap") { val input = Seq(1 to 4, 5 to 8, 9 to 12) - - // map - testOperation(input, (r: DStream[Int]) => r.map(_.toString), input.map(_.map(_.toString))) - - // flatMap testOperation( input, (r: DStream[Int]) => r.flatMap(x => Seq(x, x * 2)), @@ -19,16 +24,16 @@ class DStreamBasicSuite extends DStreamSuiteBase { ) } - test("shuffle-based operations") { - // reduceByKey + test("reduceByKey") { testOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ), true ) + } - // reduce + test("reduce") { testOperation( Seq(1 to 4, 5 to 8, 9 to 12), (s: DStream[Int]) => s.reduce(_ + _), @@ -57,7 +62,7 @@ class DStreamBasicSuite extends DStreamSuiteBase { Seq(("a", 5), ("b", 3), ("c", 1)) ) - val updateStateOp = (s: DStream[String]) => { + val updateStateOperation = (s: DStream[String]) => { val updateFunc = (values: Seq[Int], state: RichInt) => { var newState = 0 if (values != null && values.size > 0) newState += values.reduce(_ + _) @@ -67,6 +72,66 @@ class DStreamBasicSuite extends DStreamSuiteBase { s.map(x => (x, 1)).updateStateByKey[RichInt](updateFunc).map(t => (t._1, t._2.self)) } - testOperation(inputData, updateStateOp, outputData, true) + testOperation(inputData, updateStateOperation, outputData, true) + } + + test("forgetting of RDDs") { + assert(batchDuration === Seconds(1), "Batch duration has changed from 1 second") + + val input = Seq(1 to 4, 5 to 8, 9 to 12, 13 to 16, 17 to 20, 21 to 24, 25 to 28, 29 to 32) + + assert(input.size % 4 === 0, "Number of inputs should be a multiple of 4") + + def operation(s: DStream[Int]): DStream[(Int, Int)] = { + s.map(x => (x % 10, 1)) + .window(Seconds(2), Seconds(1)) + .reduceByKeyAndWindow(_ + _, _ - _, Seconds(4), Seconds(1)) + } + + val ssc = setupStreams(input, operation _) + runStreams[(Int, Int)](ssc, input.size, input.size) + + val reducedWindowedStream = ssc.graph.getOutputStreams().head.dependencies.head + .asInstanceOf[ReducedWindowedDStream[Int, Int]] + val windowedStream = reducedWindowedStream.dependencies.head.dependencies.head + .asInstanceOf[WindowedDStream[(Int, Int)]] + val mappedStream = windowedStream.dependencies.head + + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val finalTime = Seconds(7) + //assert(clock.time === finalTime.milliseconds) + + // ReducedWindowedStream should remember the last RDD created + assert(reducedWindowedStream.generatedRDDs.contains(finalTime)) + + // ReducedWindowedStream should have forgotten the previous to last RDD created + assert(!reducedWindowedStream.generatedRDDs.contains(finalTime - reducedWindowedStream.slideTime)) + + // WindowedStream should remember the last RDD created + assert(windowedStream.generatedRDDs.contains(finalTime)) + + // WindowedStream should still remember the previous to last RDD created + // as the last RDD of ReducedWindowedStream requires that RDD + assert(windowedStream.generatedRDDs.contains(finalTime - windowedStream.slideTime)) + + // WindowedStream should have forgotten this RDD as the last RDD of + // ReducedWindowedStream DOES NOT require this RDD + assert(!windowedStream.generatedRDDs.contains(finalTime - windowedStream.slideTime - reducedWindowedStream.windowTime)) + + // MappedStream should remember the last RDD created + assert(mappedStream.generatedRDDs.contains(finalTime)) + + // MappedStream should still remember the previous to last RDD created + // as the last RDD of WindowedStream requires that RDD + assert(mappedStream.generatedRDDs.contains(finalTime - mappedStream.slideTime)) + + // MappedStream should still remember this RDD as the last RDD of + // ReducedWindowedStream requires that RDD (even though the last RDD of + // WindowedStream does not need it) + assert(mappedStream.generatedRDDs.contains(finalTime - windowedStream.windowTime)) + + // MappedStream should have forgotten this RDD as the last RDD of + // ReducedWindowedStream DOES NOT require this RDD + assert(!mappedStream.generatedRDDs.contains(finalTime - mappedStream.slideTime - windowedStream.windowTime - reducedWindowedStream.windowTime)) } } diff --git a/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala b/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala index 91ffc0c098..6e5a7a58bb 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala @@ -35,6 +35,8 @@ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBu trait DStreamSuiteBase extends FunSuite with Logging { + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + def framework() = "DStreamSuiteBase" def master() = "local[2]" @@ -73,6 +75,9 @@ trait DStreamSuiteBase extends FunSuite with Logging { numBatches: Int, numExpectedOutput: Int ): Seq[Seq[V]] = { + + assert(numBatches > 0, "Number of batches to run stream computation is zero") + assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero") logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput) // Get the output buffer @@ -150,8 +155,6 @@ trait DStreamSuiteBase extends FunSuite with Logging { numBatches: Int, useSet: Boolean ) { - System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") - val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size val ssc = setupStreams[U, V](input, operation) val output = runStreams[V](ssc, numBatches_, expectedOutput.size) -- cgit v1.2.3 From c23bf1aff4b9a1faf9d32c7b64acad2213f9515c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 20 Oct 2012 00:16:41 +0000 Subject: Add PySpark README and run scripts. --- core/src/main/scala/spark/SparkContext.scala | 2 +- pyspark/README | 58 ++++++++++++++++++++++++++++ pyspark/pyspark-shell | 3 ++ pyspark/pyspark/context.py | 5 +-- pyspark/pyspark/examples/wordcount.py | 17 ++++++++ pyspark/pyspark/shell.py | 21 ++++++++++ pyspark/run-pyspark | 23 +++++++++++ 7 files changed, 125 insertions(+), 4 deletions(-) create mode 100644 pyspark/README create mode 100755 pyspark/pyspark-shell create mode 100644 pyspark/pyspark/examples/wordcount.py create mode 100644 pyspark/pyspark/shell.py create mode 100755 pyspark/run-pyspark diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index becf737597..acb38ae33d 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -113,7 +113,7 @@ class SparkContext( // Environment variables to pass to our executors private[spark] val executorEnvs = HashMap[String, String]() for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", - "SPARK_TESTING")) { + "SPARK_TESTING", "PYTHONPATH")) { val value = System.getenv(key) if (value != null) { executorEnvs(key) = value diff --git a/pyspark/README b/pyspark/README new file mode 100644 index 0000000000..63a1def141 --- /dev/null +++ b/pyspark/README @@ -0,0 +1,58 @@ +# PySpark + +PySpark is a Python API for Spark. + +PySpark jobs are writen in Python and executed using a standard Python +interpreter; this supports modules that use Python C extensions. The +API is based on the Spark Scala API and uses regular Python functions +and lambdas to support user-defined functions. PySpark supports +interactive use through a standard Python interpreter; it can +automatically serialize closures and ship them to worker processes. + +PySpark is built on top of the Spark Java API. Data is uniformly +represented as serialized Python objects and stored in Spark Java +processes, which communicate with PySpark worker processes over pipes. + +## Features + +PySpark supports most of the Spark API, including broadcast variables. +RDDs are dynamically typed and can hold any Python object. + +PySpark does not support: + +- Special functions on RDDs of doubles +- Accumulators + +## Examples and Documentation + +The PySpark source contains docstrings and doctests that document its +API. The public classes are in `context.py` and `rdd.py`. + +The `pyspark/pyspark/examples` directory contains a few complete +examples. + +## Installing PySpark + +PySpark requires a development version of Py4J, a Python library for +interacting with Java processes. It can be installed from +https://github.com/bartdag/py4j; make sure to install a version that +contains at least the commits through 3dbf380d3d. + +PySpark uses the `PYTHONPATH` environment variable to search for Python +classes; Py4J should be on this path, along with any libraries used by +PySpark programs. `PYTHONPATH` will be automatically shipped to worker +machines, but the files that it points to must be present on each +machine. + +PySpark requires the Spark assembly JAR, which can be created by running +`sbt/sbt assembly` in the Spark directory. + +Additionally, `SPARK_HOME` should be set to the location of the Spark +package. + +## Running PySpark + +The easiest way to run PySpark is to use the `run-pyspark` and +`pyspark-shell` scripts, which are included in the `pyspark` directory. +These scripts automatically load the `spark-conf.sh` file, set +`SPARK_HOME`, and add the `pyspark` package to the `PYTHONPATH`. diff --git a/pyspark/pyspark-shell b/pyspark/pyspark-shell new file mode 100755 index 0000000000..4ed3e6010c --- /dev/null +++ b/pyspark/pyspark-shell @@ -0,0 +1,3 @@ +#!/bin/sh +FWDIR="`dirname $0`" +exec $FWDIR/run-pyspark $FWDIR/pyspark/shell.py "$@" diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index 3f4db26644..50d57e5317 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -18,14 +18,13 @@ class SparkContext(object): asPickle = jvm.spark.api.python.PythonRDD.asPickle arrayAsPickle = jvm.spark.api.python.PythonRDD.arrayAsPickle - def __init__(self, master, name, defaultParallelism=None, - pythonExec='python'): + def __init__(self, master, name, defaultParallelism=None): self.master = master self.name = name self._jsc = self.jvm.JavaSparkContext(master, name) self.defaultParallelism = \ defaultParallelism or self._jsc.sc().defaultParallelism() - self.pythonExec = pythonExec + self.pythonExec = os.environ.get("PYSPARK_PYTHON_EXEC", 'python') # Broadcast's __reduce__ method stores Broadcast instances here. # This allows other code to determine which Broadcast instances have # been pickled, so it can determine which Java broadcast objects to diff --git a/pyspark/pyspark/examples/wordcount.py b/pyspark/pyspark/examples/wordcount.py new file mode 100644 index 0000000000..8365c070e8 --- /dev/null +++ b/pyspark/pyspark/examples/wordcount.py @@ -0,0 +1,17 @@ +import sys +from operator import add +from pyspark.context import SparkContext + +if __name__ == "__main__": + if len(sys.argv) < 3: + print >> sys.stderr, \ + "Usage: PythonWordCount " + exit(-1) + sc = SparkContext(sys.argv[1], "PythonWordCount") + lines = sc.textFile(sys.argv[2], 1) + counts = lines.flatMap(lambda x: x.split(' ')) \ + .map(lambda x: (x, 1)) \ + .reduceByKey(add) + output = counts.collect() + for (word, count) in output: + print "%s : %i" % (word, count) diff --git a/pyspark/pyspark/shell.py b/pyspark/pyspark/shell.py new file mode 100644 index 0000000000..7ef30894cb --- /dev/null +++ b/pyspark/pyspark/shell.py @@ -0,0 +1,21 @@ +""" +An interactive shell. +""" +import code +import sys + +from pyspark.context import SparkContext + + +def main(master='local'): + sc = SparkContext(master, 'PySparkShell') + print "Spark context available as sc." + code.interact(local={'sc': sc}) + + +if __name__ == '__main__': + if len(sys.argv) > 1: + master = sys.argv[1] + else: + master = 'local' + main(master) diff --git a/pyspark/run-pyspark b/pyspark/run-pyspark new file mode 100755 index 0000000000..9c5e027962 --- /dev/null +++ b/pyspark/run-pyspark @@ -0,0 +1,23 @@ +#!/bin/bash + +# Figure out where the Scala framework is installed +FWDIR="$(cd `dirname $0`; cd ../; pwd)" + +# Export this as SPARK_HOME +export SPARK_HOME="$FWDIR" + +# Load environment variables from conf/spark-env.sh, if it exists +if [ -e $FWDIR/conf/spark-env.sh ] ; then + . $FWDIR/conf/spark-env.sh +fi + +# Figure out which Python executable to use +if [ -z "$PYSPARK_PYTHON" ] ; then + PYSPARK_PYTHON="python" +fi +export PYSPARK_PYTHON + +# Add the PySpark classes to the Python path: +export PYTHONPATH=$SPARK_HOME/pyspark/:$PYTHONPATH + +exec "$PYSPARK_PYTHON" "$@" -- cgit v1.2.3 From c4a2b6f636040bacd3d4b443e65cc33dafd0aa7e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 21 Oct 2012 10:41:25 -0700 Subject: Fixed some bugs in tests for forgetting RDDs, and made sure that use of manual clock leads to a zeroTime of 0 in the DStreams (more intuitive). --- .../main/scala/spark/streaming/Checkpoint.scala | 2 +- .../src/main/scala/spark/streaming/DStream.scala | 40 +++++------ .../main/scala/spark/streaming/DStreamGraph.scala | 17 ++++- .../spark/streaming/ReducedWindowedDStream.scala | 10 +-- .../src/main/scala/spark/streaming/Scheduler.scala | 4 +- .../scala/spark/streaming/StreamingContext.scala | 4 ++ .../scala/spark/streaming/WindowedDStream.scala | 2 +- .../scala/spark/streaming/DStreamBasicSuite.scala | 83 ++++++++++------------ .../scala/spark/streaming/DStreamSuiteBase.scala | 2 + .../scala/spark/streaming/DStreamWindowSuite.scala | 25 ++++--- 10 files changed, 100 insertions(+), 89 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index 23fd0f2434..ebff9bdb51 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -33,7 +33,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) ext if (fs.exists(path)) { val bkPath = new Path(path.getParent, path.getName + ".bk") FileUtil.copy(fs, path, fs, bkPath, true, true, conf) - println("Moved existing checkpoint file to " + bkPath) + //logInfo("Moved existing checkpoint file to " + bkPath) } val fos = fs.create(path) val oos = new ObjectOutputStream(fos) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 645636b603..f6cd135e59 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -30,7 +30,7 @@ extends Serializable with Logging { // List of parent DStreams on which this DStream depends on def dependencies: List[DStream[_]] - // Key method that computes RDD for a valid time + // Key method that computes RDD for a valid time def compute (validTime: Time): Option[RDD[T]] /** @@ -45,8 +45,8 @@ extends Serializable with Logging { // Time zero for the DStream protected var zeroTime: Time = null - // Time after which RDDs will be forgotten - protected var forgetTime: Time = null + // Duration for which the DStream will remember each RDD created + protected var rememberDuration: Time = null // Storage level of the RDDs in the stream protected var storageLevel: StorageLevel = StorageLevel.NONE @@ -60,8 +60,8 @@ extends Serializable with Logging { def isInitialized = (zeroTime != null) - // Time gap for forgetting old RDDs (i.e. removing them from generatedRDDs) - def parentForgetTime = forgetTime + // Duration for which the DStream requires its parent DStream to remember each RDD created + def parentRememberDuration = rememberDuration // Change this RDD's storage level def persist( @@ -118,23 +118,24 @@ extends Serializable with Logging { dependencies.foreach(_.setGraph(graph)) } - protected[streaming] def setForgetTime(time: Time = slideTime) { - if (time == null) { - throw new Exception("Time gap for forgetting RDDs cannot be set to null for " + this) - } else if (forgetTime != null && time < forgetTime) { - throw new Exception("Time gap for forgetting RDDs cannot be reduced from " + forgetTime - + " to " + time + " for " + this) + protected[streaming] def setRememberDuration(duration: Time = slideTime) { + if (duration == null) { + throw new Exception("Duration for remembering RDDs cannot be set to null for " + this) + } else if (rememberDuration != null && duration < rememberDuration) { + logWarning("Duration for remembering RDDs cannot be reduced from " + rememberDuration + + " to " + duration + " for " + this) + } else { + rememberDuration = duration + dependencies.foreach(_.setRememberDuration(parentRememberDuration)) + logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this) } - forgetTime = time - dependencies.foreach(_.setForgetTime(parentForgetTime)) - logInfo("Time gap for forgetting RDDs set to " + forgetTime + " for " + this) } /** This method checks whether the 'time' is valid wrt slideTime for generating RDD */ protected def isTimeValid(time: Time): Boolean = { if (!isInitialized) { throw new Exception (this + " has not been initialized") - } else if (time < zeroTime || ! (time - zeroTime).isMultipleOf(slideTime)) { + } else if (time <= zeroTime || ! (time - zeroTime).isMultipleOf(slideTime)) { false } else { true @@ -143,7 +144,7 @@ extends Serializable with Logging { /** * This method either retrieves a precomputed RDD of this DStream, - * or computes the RDD (if the time is valid) + * or computes the RDD (if the time is valid) */ def getOrCompute(time: Time): Option[RDD[T]] = { // If this DStream was not initialized (i.e., zeroTime not set), then do it @@ -154,7 +155,7 @@ extends Serializable with Logging { // probably all RDDs in this DStream will be reused and hence should be cached case Some(oldRDD) => Some(oldRDD) - // if RDD was not generated, and if the time is valid + // if RDD was not generated, and if the time is valid // (based on sliding time of this DStream), then generate the RDD case None => { if (isTimeValid(time)) { @@ -199,9 +200,8 @@ extends Serializable with Logging { def forgetOldRDDs(time: Time) { val keys = generatedRDDs.keys var numForgotten = 0 - keys.foreach(t => { - if (t < (time - forgetTime)) { + if (t <= (time - rememberDuration)) { generatedRDDs.remove(t) numForgotten += 1 //logInfo("Forgot RDD of time " + t + " from " + this) @@ -530,7 +530,7 @@ class UnifiedDStream[T: ClassManifest](parents: Array[DStream[T]]) val rdds = new ArrayBuffer[RDD[T]]() parents.map(_.getOrCompute(validTime)).foreach(_ match { case Some(rdd) => rdds += rdd - case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime) + case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime) }) if (rdds.size > 0) { Some(new UnionRDD(ssc.sc, rdds)) diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index 964c8a26a0..ac44d7a2a6 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -12,6 +12,7 @@ final class DStreamGraph extends Serializable with Logging { private[streaming] var zeroTime: Time = null private[streaming] var batchDuration: Time = null + private[streaming] var rememberDuration: Time = null private[streaming] var checkpointInProgress = false def start(time: Time) { @@ -21,7 +22,11 @@ final class DStreamGraph extends Serializable with Logging { } zeroTime = time outputStreams.foreach(_.initialize(zeroTime)) - outputStreams.foreach(_.setForgetTime()) + outputStreams.foreach(_.setRememberDuration()) // first set the rememberDuration to default values + if (rememberDuration != null) { + // if custom rememberDuration has been provided, set the rememberDuration + outputStreams.foreach(_.setRememberDuration(rememberDuration)) + } inputStreams.par.foreach(_.start()) } } @@ -48,6 +53,16 @@ final class DStreamGraph extends Serializable with Logging { batchDuration = duration } + def setRememberDuration(duration: Time) { + this.synchronized { + if (rememberDuration != null) { + throw new Exception("Batch duration already set as " + batchDuration + + ". cannot set it again.") + } + } + rememberDuration = duration + } + def addInputStream(inputStream: InputDStream[_]) { this.synchronized { inputStream.setGraph(this) diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala index f3e95c9e2b..fcf57aced7 100644 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -38,7 +38,7 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( override def slideTime: Time = _slideTime //TODO: This is wrong. This should depend on the checkpointInterval - override def parentForgetTime: Time = forgetTime + windowTime + override def parentRememberDuration: Time = rememberDuration + windowTime override def persist( storageLevel: StorageLevel, @@ -49,10 +49,10 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( this } - protected[streaming] override def setForgetTime(time: Time) { - if (forgetTime == null || forgetTime < time) { - forgetTime = time - dependencies.foreach(_.setForgetTime(forgetTime + windowTime)) + protected[streaming] override def setRememberDuration(time: Time) { + if (rememberDuration == null || rememberDuration < time) { + rememberDuration = time + dependencies.foreach(_.setRememberDuration(rememberDuration + windowTime)) } } diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index 99e30b6110..7d52e2eddf 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -37,8 +37,8 @@ extends Logging { timer.restart(graph.zeroTime.milliseconds) logInfo("Scheduler's timer restarted") } else { - val zeroTime = Time(timer.start()) - graph.start(zeroTime) + val firstTime = Time(timer.start()) + graph.start(firstTime - ssc.graph.batchDuration) logInfo("Scheduler's timer started") } logInfo("Scheduler started") diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index b5f4571798..7022056f7c 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -71,6 +71,10 @@ class StreamingContext ( graph.setBatchDuration(duration) } + def setRememberDuration(duration: Time) { + graph.setRememberDuration(duration) + } + def setCheckpointDetails(file: String, interval: Time) { checkpointFile = file checkpointInterval = interval diff --git a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala index 2984f88284..b90e22351b 100644 --- a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala @@ -24,7 +24,7 @@ class WindowedDStream[T: ClassManifest]( override def slideTime: Time = _slideTime - override def parentForgetTime: Time = forgetTime + windowTime + override def parentRememberDuration: Time = rememberDuration + windowTime override def compute(validTime: Time): Option[RDD[T]] = { val currentWindow = Interval(validTime - windowTime + parent.slideTime, validTime) diff --git a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala index 5dd8b675b1..28bbb152ca 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala @@ -41,7 +41,7 @@ class DStreamBasicSuite extends DStreamSuiteBase { ) } - test("stateful operations") { + test("updateStateByKey") { val inputData = Seq( Seq("a"), @@ -75,63 +75,54 @@ class DStreamBasicSuite extends DStreamSuiteBase { testOperation(inputData, updateStateOperation, outputData, true) } - test("forgetting of RDDs") { + test("forgetting of RDDs - map and window operations") { assert(batchDuration === Seconds(1), "Batch duration has changed from 1 second") - val input = Seq(1 to 4, 5 to 8, 9 to 12, 13 to 16, 17 to 20, 21 to 24, 25 to 28, 29 to 32) + val input = (0 until 10).map(x => Seq(x, x + 1)).toSeq + val rememberDuration = Seconds(3) - assert(input.size % 4 === 0, "Number of inputs should be a multiple of 4") + assert(input.size === 10, "Number of inputs have changed") def operation(s: DStream[Int]): DStream[(Int, Int)] = { s.map(x => (x % 10, 1)) .window(Seconds(2), Seconds(1)) - .reduceByKeyAndWindow(_ + _, _ - _, Seconds(4), Seconds(1)) + .window(Seconds(4), Seconds(2)) } val ssc = setupStreams(input, operation _) - runStreams[(Int, Int)](ssc, input.size, input.size) + ssc.setRememberDuration(rememberDuration) + runStreams[(Int, Int)](ssc, input.size, input.size / 2) - val reducedWindowedStream = ssc.graph.getOutputStreams().head.dependencies.head - .asInstanceOf[ReducedWindowedDStream[Int, Int]] - val windowedStream = reducedWindowedStream.dependencies.head.dependencies.head - .asInstanceOf[WindowedDStream[(Int, Int)]] - val mappedStream = windowedStream.dependencies.head + val windowedStream2 = ssc.graph.getOutputStreams().head.dependencies.head + val windowedStream1 = windowedStream2.dependencies.head + val mappedStream = windowedStream1.dependencies.head val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val finalTime = Seconds(7) - //assert(clock.time === finalTime.milliseconds) - - // ReducedWindowedStream should remember the last RDD created - assert(reducedWindowedStream.generatedRDDs.contains(finalTime)) - - // ReducedWindowedStream should have forgotten the previous to last RDD created - assert(!reducedWindowedStream.generatedRDDs.contains(finalTime - reducedWindowedStream.slideTime)) - - // WindowedStream should remember the last RDD created - assert(windowedStream.generatedRDDs.contains(finalTime)) - - // WindowedStream should still remember the previous to last RDD created - // as the last RDD of ReducedWindowedStream requires that RDD - assert(windowedStream.generatedRDDs.contains(finalTime - windowedStream.slideTime)) - - // WindowedStream should have forgotten this RDD as the last RDD of - // ReducedWindowedStream DOES NOT require this RDD - assert(!windowedStream.generatedRDDs.contains(finalTime - windowedStream.slideTime - reducedWindowedStream.windowTime)) - - // MappedStream should remember the last RDD created - assert(mappedStream.generatedRDDs.contains(finalTime)) - - // MappedStream should still remember the previous to last RDD created - // as the last RDD of WindowedStream requires that RDD - assert(mappedStream.generatedRDDs.contains(finalTime - mappedStream.slideTime)) - - // MappedStream should still remember this RDD as the last RDD of - // ReducedWindowedStream requires that RDD (even though the last RDD of - // WindowedStream does not need it) - assert(mappedStream.generatedRDDs.contains(finalTime - windowedStream.windowTime)) - - // MappedStream should have forgotten this RDD as the last RDD of - // ReducedWindowedStream DOES NOT require this RDD - assert(!mappedStream.generatedRDDs.contains(finalTime - mappedStream.slideTime - windowedStream.windowTime - reducedWindowedStream.windowTime)) + assert(clock.time === Seconds(10).milliseconds) + + // IDEALLY + // WindowedStream2 should remember till 7 seconds: 10, 8, + // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5 + // MappedStream should remember till 7 seconds: 10, 9, 8, 7, 6, 5, 4, 3, + + // IN THIS TEST + // WindowedStream2 should remember till 7 seconds: 10, 8, + // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5, 4 + // MappedStream should remember till 7 seconds: 10, 9, 8, 7, 6, 5, 4, 3, 2 + + // WindowedStream2 + assert(windowedStream2.generatedRDDs.contains(Seconds(10))) + assert(windowedStream2.generatedRDDs.contains(Seconds(8))) + assert(!windowedStream2.generatedRDDs.contains(Seconds(6))) + + // WindowedStream1 + assert(windowedStream1.generatedRDDs.contains(Seconds(10))) + assert(windowedStream1.generatedRDDs.contains(Seconds(4))) + assert(!windowedStream1.generatedRDDs.contains(Seconds(3))) + + // MappedStream + assert(mappedStream.generatedRDDs.contains(Seconds(10))) + assert(mappedStream.generatedRDDs.contains(Seconds(2))) + assert(!mappedStream.generatedRDDs.contains(Seconds(1))) } } diff --git a/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala b/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala index 6e5a7a58bb..59fa5a6f22 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala @@ -104,6 +104,8 @@ trait DStreamSuiteBase extends FunSuite with Logging { assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") assert(output.size === numExpectedOutput, "Unexpected number of outputs generated") + + Thread.sleep(500) // Give some time for the forgetting old RDDs to complete } catch { case e: Exception => e.printStackTrace(); throw e; } finally { diff --git a/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala index 8dd18f491a..cfcab6298d 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala @@ -9,25 +9,24 @@ class DStreamWindowSuite extends DStreamSuiteBase { override def maxWaitTimeMillis() = 20000 val largerSlideInput = Seq( - Seq(("a", 1)), // 1st window from here - Seq(("a", 2)), - Seq(("a", 3)), // 2nd window from here - Seq(("a", 4)), - Seq(("a", 5)), // 3rd window from here - Seq(("a", 6)), - Seq(), // 4th window from here + Seq(("a", 1)), + Seq(("a", 2)), // 1st window from here + Seq(("a", 3)), + Seq(("a", 4)), // 2nd window from here + Seq(("a", 5)), + Seq(("a", 6)), // 3rd window from here Seq(), - Seq() // 5th window from here + Seq() // 4th window from here ) val largerSlideOutput = Seq( - Seq(("a", 1)), - Seq(("a", 6)), - Seq(("a", 14)), - Seq(("a", 15)), - Seq(("a", 6)) + Seq(("a", 3)), + Seq(("a", 10)), + Seq(("a", 18)), + Seq(("a", 11)) ) + val bigInput = Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)), -- cgit v1.2.3 From d85c66636ba3b5d32f7e3b47c5b68e1064f8f588 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 21 Oct 2012 17:40:08 -0700 Subject: Added MapValueDStream, FlatMappedValuesDStream and CoGroupedDStream, and therefore DStream operations mapValue, flatMapValues, cogroup, and join. Also, added tests for DStream operations filter, glom, mapPartitions, groupByKey, mapValues, flatMapValues, cogroup, and join. --- .../scala/spark/streaming/CoGroupedDStream.scala | 37 ++++++++++ .../src/main/scala/spark/streaming/DStream.scala | 63 +++++++++++++--- .../spark/streaming/PairDStreamFunctions.scala | 75 +++++++++++++++---- .../scala/spark/streaming/examples/CountRaw.scala | 2 +- .../scala/spark/streaming/examples/GrepRaw.scala | 2 +- .../streaming/examples/TopKWordCountRaw.scala | 2 +- .../spark/streaming/examples/WordCountRaw.scala | 2 +- .../scala/spark/streaming/DStreamBasicSuite.scala | 86 ++++++++++++++++++++++ .../scala/spark/streaming/DStreamSuiteBase.scala | 59 ++++++++++++++- 9 files changed, 293 insertions(+), 35 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala diff --git a/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala b/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala new file mode 100644 index 0000000000..5522e2ee21 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala @@ -0,0 +1,37 @@ +package spark.streaming + +import spark.{CoGroupedRDD, RDD, Partitioner} + +class CoGroupedDStream[K : ClassManifest]( + parents: Seq[DStream[(_, _)]], + partitioner: Partitioner + ) extends DStream[(K, Seq[Seq[_]])](parents.head.ssc) { + + if (parents.length == 0) { + throw new IllegalArgumentException("Empty array of parents") + } + + if (parents.map(_.ssc).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different StreamingContexts") + } + + if (parents.map(_.slideTime).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different slide times") + } + + override def dependencies = parents.toList + + override def slideTime = parents.head.slideTime + + override def compute(validTime: Time): Option[RDD[(K, Seq[Seq[_]])]] = { + val part = partitioner + val rdds = parents.flatMap(_.getOrCompute(validTime)) + if (rdds.size > 0) { + val q = new CoGroupedRDD[K](rdds, part) + Some(q) + } else { + None + } + } + +} diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index f6cd135e59..38bb7c8b94 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -244,27 +244,27 @@ extends Serializable with Logging { * DStream operations * -------------- */ - def map[U: ClassManifest](mapFunc: T => U) = { + def map[U: ClassManifest](mapFunc: T => U): DStream[U] = { new MappedDStream(this, ssc.sc.clean(mapFunc)) } - def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]) = { + def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]): DStream[U] = { new FlatMappedDStream(this, ssc.sc.clean(flatMapFunc)) } - def filter(filterFunc: T => Boolean) = new FilteredDStream(this, filterFunc) + def filter(filterFunc: T => Boolean): DStream[T] = new FilteredDStream(this, filterFunc) - def glom() = new GlommedDStream(this) + def glom(): DStream[Array[T]] = new GlommedDStream(this) - def mapPartitions[U: ClassManifest](mapPartFunc: Iterator[T] => Iterator[U]) = { + def mapPartitions[U: ClassManifest](mapPartFunc: Iterator[T] => Iterator[U]): DStream[U] = { new MapPartitionedDStream(this, ssc.sc.clean(mapPartFunc)) } - def reduce(reduceFunc: (T, T) => T) = this.map(x => (null, x)).reduceByKey(reduceFunc, 1).map(_._2) + def reduce(reduceFunc: (T, T) => T): DStream[T] = this.map(x => (null, x)).reduceByKey(reduceFunc, 1).map(_._2) - def count() = this.map(_ => 1).reduce(_ + _) + def count(): DStream[Int] = this.map(_ => 1).reduce(_ + _) - def collect() = this.map(x => (1, x)).groupByKey(1).map(_._2) + def collect(): DStream[Seq[T]] = this.map(x => (null, x)).groupByKey(1).map(_._2) def foreach(foreachFunc: T => Unit) { val newStream = new PerElementForEachDStream(this, ssc.sc.clean(foreachFunc)) @@ -341,7 +341,7 @@ extends Serializable with Logging { this.map(_ => 1).reduceByWindow(add _, subtract _, windowTime, slideTime) } - def union(that: DStream[T]): DStream[T] = new UnifiedDStream[T](Array(this, that)) + def union(that: DStream[T]): DStream[T] = new UnionDStream[T](Array(this, that)) def slice(interval: Interval): Seq[RDD[T]] = { slice(interval.beginTime, interval.endTime) @@ -507,8 +507,47 @@ class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( * TODO */ -class UnifiedDStream[T: ClassManifest](parents: Array[DStream[T]]) - extends DStream[T](parents(0).ssc) { +class MapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( + parent: DStream[(K, V)], + mapValueFunc: V => U + ) extends DStream[(K, U)](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[(K, U)]] = { + parent.getOrCompute(validTime).map(_.mapValues[U](mapValueFunc)) + } +} + + +/** + * TODO + */ + +class FlatMapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( + parent: DStream[(K, V)], + flatMapValueFunc: V => TraversableOnce[U] + ) extends DStream[(K, U)](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[(K, U)]] = { + parent.getOrCompute(validTime).map(_.flatMapValues[U](flatMapValueFunc)) + } +} + + + +/** + * TODO + */ + +class UnionDStream[T: ClassManifest](parents: Array[DStream[T]]) + extends DStream[T](parents.head.ssc) { if (parents.length == 0) { throw new IllegalArgumentException("Empty array of parents") @@ -524,7 +563,7 @@ class UnifiedDStream[T: ClassManifest](parents: Array[DStream[T]]) override def dependencies = parents.toList - override def slideTime: Time = parents(0).slideTime + override def slideTime: Time = parents.head.slideTime override def compute(validTime: Time): Option[RDD[T]] = { val rdds = new ArrayBuffer[RDD[T]]() diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index 0bd0321928..5de57eb2fd 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -1,17 +1,16 @@ package spark.streaming import scala.collection.mutable.ArrayBuffer -import spark.Partitioner -import spark.HashPartitioner +import spark.{Manifests, RDD, Partitioner, HashPartitioner} import spark.streaming.StreamingContext._ import javax.annotation.Nullable -class PairDStreamFunctions[K: ClassManifest, V: ClassManifest](stream: DStream[(K,V)]) +class PairDStreamFunctions[K: ClassManifest, V: ClassManifest](self: DStream[(K,V)]) extends Serializable { - def ssc = stream.ssc + def ssc = self.ssc - def defaultPartitioner(numPartitions: Int = stream.ssc.sc.defaultParallelism) = { + def defaultPartitioner(numPartitions: Int = self.ssc.sc.defaultParallelism) = { new HashPartitioner(numPartitions) } @@ -28,10 +27,10 @@ extends Serializable { } def groupByKey(partitioner: Partitioner): DStream[(K, Seq[V])] = { - def createCombiner(v: V) = ArrayBuffer[V](v) - def mergeValue(c: ArrayBuffer[V], v: V) = (c += v) - def mergeCombiner(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = (c1 ++ c2) - combineByKey(createCombiner _, mergeValue _, mergeCombiner _, partitioner).asInstanceOf[DStream[(K, Seq[V])]] + val createCombiner = (v: V) => ArrayBuffer[V](v) + val mergeValue = (c: ArrayBuffer[V], v: V) => (c += v) + val mergeCombiner = (c1: ArrayBuffer[V], c2: ArrayBuffer[V]) => (c1 ++ c2) + combineByKey(createCombiner, mergeValue, mergeCombiner, partitioner).asInstanceOf[DStream[(K, Seq[V])]] } def reduceByKey(reduceFunc: (V, V) => V): DStream[(K, V)] = { @@ -52,7 +51,7 @@ extends Serializable { mergeValue: (C, V) => C, mergeCombiner: (C, C) => C, partitioner: Partitioner) : ShuffledDStream[K, V, C] = { - new ShuffledDStream[K, V, C](stream, createCombiner, mergeValue, mergeCombiner, partitioner) + new ShuffledDStream[K, V, C](self, createCombiner, mergeValue, mergeCombiner, partitioner) } def groupByKeyAndWindow(windowTime: Time, slideTime: Time): DStream[(K, Seq[V])] = { @@ -72,14 +71,14 @@ extends Serializable { slideTime: Time, partitioner: Partitioner ): DStream[(K, Seq[V])] = { - stream.window(windowTime, slideTime).groupByKey(partitioner) + self.window(windowTime, slideTime).groupByKey(partitioner) } def reduceByKeyAndWindow( reduceFunc: (V, V) => V, windowTime: Time ): DStream[(K, V)] = { - reduceByKeyAndWindow(reduceFunc, windowTime, stream.slideTime, defaultPartitioner()) + reduceByKeyAndWindow(reduceFunc, windowTime, self.slideTime, defaultPartitioner()) } def reduceByKeyAndWindow( @@ -105,7 +104,7 @@ extends Serializable { slideTime: Time, partitioner: Partitioner ): DStream[(K, V)] = { - stream.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), partitioner) + self.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), partitioner) } // This method is the efficient sliding window reduce operation, @@ -148,7 +147,7 @@ extends Serializable { val cleanedReduceFunc = ssc.sc.clean(reduceFunc) val cleanedInvReduceFunc = ssc.sc.clean(invReduceFunc) new ReducedWindowedDStream[K, V]( - stream, cleanedReduceFunc, cleanedInvReduceFunc, windowTime, slideTime, partitioner) + self, cleanedReduceFunc, cleanedInvReduceFunc, windowTime, slideTime, partitioner) } // TODO: @@ -184,7 +183,53 @@ extends Serializable { partitioner: Partitioner, rememberPartitioner: Boolean ): DStream[(K, S)] = { - new StateDStream(stream, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner) + new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner) + } + + + def mapValues[U: ClassManifest](mapValuesFunc: V => U): DStream[(K, U)] = { + new MapValuesDStream[K, V, U](self, mapValuesFunc) + } + + def flatMapValues[U: ClassManifest]( + flatMapValuesFunc: V => TraversableOnce[U] + ): DStream[(K, U)] = { + new FlatMapValuesDStream[K, V, U](self, flatMapValuesFunc) + } + + def cogroup[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (Seq[V], Seq[W]))] = { + cogroup(other, defaultPartitioner()) + } + + def cogroup[W: ClassManifest]( + other: DStream[(K, W)], + partitioner: Partitioner + ): DStream[(K, (Seq[V], Seq[W]))] = { + + val cgd = new CoGroupedDStream[K]( + Seq(self.asInstanceOf[DStream[(_, _)]], other.asInstanceOf[DStream[(_, _)]]), + partitioner + ) + val pdfs = new PairDStreamFunctions[K, Seq[Seq[_]]](cgd)( + classManifest[K], + Manifests.seqSeqManifest + ) + pdfs.mapValues { + case Seq(vs, ws) => + (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]]) + } + } + + def join[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (V, W))] = { + join[W](other, defaultPartitioner()) + } + + def join[W: ClassManifest](other: DStream[(K, W)], partitioner: Partitioner): DStream[(K, (V, W))] = { + this.cogroup(other, partitioner) + .flatMapValues{ + case (vs, ws) => + for (v <- vs.iterator; w <- ws.iterator) yield (v, w) + } } } diff --git a/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala index c78c1e9660..ed571d22e3 100644 --- a/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala @@ -25,7 +25,7 @@ object CountRaw { val rawStreams = (1 to numStreams).map(_ => ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray - val union = new UnifiedDStream(rawStreams) + val union = new UnionDStream(rawStreams) union.map(_.length + 2).reduce(_ + _).foreachRDD(r => println("Byte count: " + r.collect().mkString)) ssc.start() } diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala index cc52da7bd4..6af1c36891 100644 --- a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala @@ -25,7 +25,7 @@ object GrepRaw { val rawStreams = (1 to numStreams).map(_ => ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray - val union = new UnifiedDStream(rawStreams) + val union = new UnionDStream(rawStreams) union.filter(_.contains("Culpepper")).count().foreachRDD(r => println("Grep count: " + r.collect().mkString)) ssc.start() diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala index 3ba07d0448..af0a3bf98a 100644 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala @@ -37,7 +37,7 @@ object TopKWordCountRaw { val rawStreams = (1 to streams).map(_ => ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray - val union = new UnifiedDStream(rawStreams) + val union = new UnionDStream(rawStreams) val windowedCounts = union.mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala index 9702003805..98bafec529 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -37,7 +37,7 @@ object WordCountRaw { val rawStreams = (1 to streams).map(_ => ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray - val union = new UnifiedDStream(rawStreams) + val union = new UnionDStream(rawStreams) val windowedCounts = union.mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) diff --git a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala index 28bbb152ca..db95c2cfaa 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala @@ -24,6 +24,44 @@ class DStreamBasicSuite extends DStreamSuiteBase { ) } + test("filter") { + val input = Seq(1 to 4, 5 to 8, 9 to 12) + testOperation( + input, + (r: DStream[Int]) => r.filter(x => (x % 2 == 0)), + input.map(_.filter(x => (x % 2 == 0))) + ) + } + + test("glom") { + assert(numInputPartitions === 2, "Number of input partitions has been changed from 2") + val input = Seq(1 to 4, 5 to 8, 9 to 12) + val output = Seq( + Seq( Seq(1, 2), Seq(3, 4) ), + Seq( Seq(5, 6), Seq(7, 8) ), + Seq( Seq(9, 10), Seq(11, 12) ) + ) + val operation = (r: DStream[Int]) => r.glom().map(_.toSeq) + testOperation(input, operation, output) + } + + test("mapPartitions") { + assert(numInputPartitions === 2, "Number of input partitions has been changed from 2") + val input = Seq(1 to 4, 5 to 8, 9 to 12) + val output = Seq(Seq(3, 7), Seq(11, 15), Seq(19, 23)) + val operation = (r: DStream[Int]) => r.mapPartitions(x => Iterator(x.reduce(_ + _))) + testOperation(input, operation, output, true) + } + + test("groupByKey") { + testOperation( + Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), + (s: DStream[String]) => s.map(x => (x, 1)).groupByKey(), + Seq( Seq(("a", Seq(1, 1)), ("b", Seq(1))), Seq(("", Seq(1, 1))), Seq() ), + true + ) + } + test("reduceByKey") { testOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), @@ -41,6 +79,54 @@ class DStreamBasicSuite extends DStreamSuiteBase { ) } + test("mapValues") { + testOperation( + Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), + (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).mapValues(_ + 10), + Seq( Seq(("a", 12), ("b", 11)), Seq(("", 12)), Seq() ), + true + ) + } + + test("flatMapValues") { + testOperation( + Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), + (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)), + Seq( Seq(("a", 2), ("a", 12), ("b", 1), ("b", 11)), Seq(("", 2), ("", 12)), Seq() ), + true + ) + } + + test("cogroup") { + val inputData1 = Seq( Seq("a", "a", "b"), Seq("a", ""), Seq(""), Seq() ) + val inputData2 = Seq( Seq("a", "a", "b"), Seq("b", ""), Seq(), Seq() ) + val outputData = Seq( + Seq( ("a", (Seq(1, 1), Seq("x", "x"))), ("b", (Seq(1), Seq("x"))) ), + Seq( ("a", (Seq(1), Seq())), ("b", (Seq(), Seq("x"))), ("", (Seq(1), Seq("x"))) ), + Seq( ("", (Seq(1), Seq())) ), + Seq( ) + ) + val operation = (s1: DStream[String], s2: DStream[String]) => { + s1.map(x => (x,1)).cogroup(s2.map(x => (x, "x"))) + } + testOperation(inputData1, inputData2, operation, outputData, true) + } + + test("join") { + val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() ) + val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") ) + val outputData = Seq( + Seq( ("a", (1, "x")), ("b", (1, "x")) ), + Seq( ("", (1, "x")) ), + Seq( ), + Seq( ) + ) + val operation = (s1: DStream[String], s2: DStream[String]) => { + s1.map(x => (x,1)).join(s2.map(x => (x,"x"))) + } + testOperation(inputData1, inputData2, operation, outputData, true) + } + test("updateStateByKey") { val inputData = Seq( diff --git a/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala b/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala index 59fa5a6f22..2a4b37c965 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala @@ -6,7 +6,7 @@ import collection.mutable.ArrayBuffer import org.scalatest.FunSuite import collection.mutable.SynchronizedBuffer -class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, val input: Seq[Seq[T]]) +class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[T]], numPartitions: Int) extends InputDStream[T](ssc_) { var currentIndex = 0 @@ -17,9 +17,9 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, val input: Seq[ def compute(validTime: Time): Option[RDD[T]] = { logInfo("Computing RDD for time " + validTime) val rdd = if (currentIndex < input.size) { - ssc.sc.makeRDD(input(currentIndex), 2) + ssc.sc.makeRDD(input(currentIndex), numPartitions) } else { - ssc.sc.makeRDD(Seq[T](), 2) + ssc.sc.makeRDD(Seq[T](), numPartitions) } logInfo("Created RDD " + rdd.id) currentIndex += 1 @@ -47,6 +47,8 @@ trait DStreamSuiteBase extends FunSuite with Logging { def checkpointInterval() = batchDuration + def numInputPartitions() = 2 + def maxWaitTimeMillis() = 10000 def setupStreams[U: ClassManifest, V: ClassManifest]( @@ -62,7 +64,7 @@ trait DStreamSuiteBase extends FunSuite with Logging { } // Setup the stream computation - val inputStream = new TestInputStream(ssc, input) + val inputStream = new TestInputStream(ssc, input, numInputPartitions) val operatedStream = operation(inputStream) val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[V]] with SynchronizedBuffer[Seq[V]]) ssc.registerInputStream(inputStream) @@ -70,6 +72,31 @@ trait DStreamSuiteBase extends FunSuite with Logging { ssc } + def setupStreams[U: ClassManifest, V: ClassManifest, W: ClassManifest]( + input1: Seq[Seq[U]], + input2: Seq[Seq[V]], + operation: (DStream[U], DStream[V]) => DStream[W] + ): StreamingContext = { + + // Create StreamingContext + val ssc = new StreamingContext(master, framework) + ssc.setBatchDuration(batchDuration) + if (checkpointFile != null) { + ssc.setCheckpointDetails(checkpointFile, checkpointInterval()) + } + + // Setup the stream computation + val inputStream1 = new TestInputStream(ssc, input1, numInputPartitions) + val inputStream2 = new TestInputStream(ssc, input2, numInputPartitions) + val operatedStream = operation(inputStream1, inputStream2) + val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[W]] with SynchronizedBuffer[Seq[W]]) + ssc.registerInputStream(inputStream1) + ssc.registerInputStream(inputStream2) + ssc.registerOutputStream(outputStream) + ssc + } + + def runStreams[V: ClassManifest]( ssc: StreamingContext, numBatches: Int, @@ -162,4 +189,28 @@ trait DStreamSuiteBase extends FunSuite with Logging { val output = runStreams[V](ssc, numBatches_, expectedOutput.size) verifyOutput[V](output, expectedOutput, useSet) } + + def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest]( + input1: Seq[Seq[U]], + input2: Seq[Seq[V]], + operation: (DStream[U], DStream[V]) => DStream[W], + expectedOutput: Seq[Seq[W]], + useSet: Boolean + ) { + testOperation[U, V, W](input1, input2, operation, expectedOutput, -1, useSet) + } + + def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest]( + input1: Seq[Seq[U]], + input2: Seq[Seq[V]], + operation: (DStream[U], DStream[V]) => DStream[W], + expectedOutput: Seq[Seq[W]], + numBatches: Int, + useSet: Boolean + ) { + val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size + val ssc = setupStreams[U, V, W](input1, input2, operation) + val output = runStreams[W](ssc, numBatches_, expectedOutput.size) + verifyOutput[W](output, expectedOutput, useSet) + } } -- cgit v1.2.3 From d4f2e5b0ef38db9d42bb0d5fbbbe6103ce047efe Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 22 Oct 2012 10:28:59 -0700 Subject: Remove PYTHONPATH from SparkContext's executorEnvs. It makes more sense to pass it in the dictionary of environment variables that is used to construct PythonRDD. --- core/src/main/scala/spark/SparkContext.scala | 2 +- core/src/main/scala/spark/api/python/PythonRDD.scala | 15 +++++++-------- pyspark/pyspark/rdd.py | 8 ++++++-- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index acb38ae33d..becf737597 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -113,7 +113,7 @@ class SparkContext( // Environment variables to pass to our executors private[spark] val executorEnvs = HashMap[String, String]() for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", - "SPARK_TESTING", "PYTHONPATH")) { + "SPARK_TESTING")) { val value = System.getenv(key) if (value != null) { executorEnvs(key) = value diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 528885fe5c..a593e53efd 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -131,18 +131,17 @@ trait PythonRDDBase { } class PythonRDD[T: ClassManifest]( - parent: RDD[T], command: Seq[String], envVars: Map[String, String], + parent: RDD[T], command: Seq[String], envVars: java.util.Map[String, String], preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) extends RDD[Array[Byte]](parent.context) with PythonRDDBase { - def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, - pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = - this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars) - // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = - this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars) + def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String], + preservePartitoning: Boolean, pythonExec: String, + broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = + this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec, + broadcastVars) override def splits = parent.splits @@ -151,7 +150,7 @@ class PythonRDD[T: ClassManifest]( override val partitioner = if (preservePartitoning) parent.partitioner else None override def compute(split: Split): Iterator[Array[Byte]] = - compute(split, envVars, command, parent, pythonExec, broadcastVars) + compute(split, envVars.toMap, command, parent, pythonExec, broadcastVars) val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index e2137fe06c..e4878c08ba 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -1,6 +1,7 @@ from base64 import standard_b64encode as b64enc from collections import defaultdict from itertools import chain, ifilter, imap +import os import shlex from subprocess import Popen, PIPE from threading import Thread @@ -10,7 +11,7 @@ from pyspark.serializers import dump_pickle, load_pickle from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup -from py4j.java_collections import ListConverter +from py4j.java_collections import ListConverter, MapConverter class RDD(object): @@ -447,8 +448,11 @@ class PipelinedRDD(RDD): self.ctx.gateway._gateway_client) self.ctx._pickled_broadcast_vars.clear() class_manifest = self._prev_jrdd.classManifest() + env = MapConverter().convert( + {'PYTHONPATH' : os.environ.get("PYTHONPATH", "")}, + self.ctx.gateway._gateway_client) python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), - pipe_command, self.preservesPartitioning, self.ctx.pythonExec, + pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec, broadcast_vars, class_manifest) self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val -- cgit v1.2.3 From 2c87c853ba24f55c142e4864b14c62d0a82a82df Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 22 Oct 2012 15:31:19 -0700 Subject: Renamed examples --- .../spark/streaming/examples/ExampleOne.scala | 41 ------------------- .../spark/streaming/examples/ExampleTwo.scala | 47 ---------------------- .../spark/streaming/examples/FileStream.scala | 47 ++++++++++++++++++++++ .../spark/streaming/examples/QueueStream.scala | 41 +++++++++++++++++++ 4 files changed, 88 insertions(+), 88 deletions(-) delete mode 100644 streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/FileStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/QueueStream.scala diff --git a/streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala b/streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala deleted file mode 100644 index 2ff8790e77..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/ExampleOne.scala +++ /dev/null @@ -1,41 +0,0 @@ -package spark.streaming.examples - -import spark.RDD -import spark.streaming.StreamingContext -import spark.streaming.StreamingContext._ -import spark.streaming.Seconds - -import scala.collection.mutable.SynchronizedQueue - -object ExampleOne { - - def main(args: Array[String]) { - if (args.length < 1) { - System.err.println("Usage: ExampleOne ") - System.exit(1) - } - - // Create the context and set the batch size - val ssc = new StreamingContext(args(0), "ExampleOne") - ssc.setBatchDuration(Seconds(1)) - - // Create the queue through which RDDs can be pushed to - // a QueueInputDStream - val rddQueue = new SynchronizedQueue[RDD[Int]]() - - // Create the QueueInputDStream and use it do some processing - val inputStream = ssc.createQueueStream(rddQueue) - val mappedStream = inputStream.map(x => (x % 10, 1)) - val reducedStream = mappedStream.reduceByKey(_ + _) - reducedStream.print() - ssc.start() - - // Create and push some RDDs into - for (i <- 1 to 30) { - rddQueue += ssc.sc.makeRDD(1 to 1000, 10) - Thread.sleep(1000) - } - ssc.stop() - System.exit(0) - } -} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala b/streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala deleted file mode 100644 index ad563e2c75..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/ExampleTwo.scala +++ /dev/null @@ -1,47 +0,0 @@ -package spark.streaming.examples - -import spark.streaming.StreamingContext -import spark.streaming.StreamingContext._ -import spark.streaming.Seconds -import org.apache.hadoop.fs.Path -import org.apache.hadoop.conf.Configuration - - -object ExampleTwo { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: ExampleOne ") - System.exit(1) - } - - // Create the context and set the batch size - val ssc = new StreamingContext(args(0), "ExampleTwo") - ssc.setBatchDuration(Seconds(2)) - - // Create the new directory - val directory = new Path(args(1)) - val fs = directory.getFileSystem(new Configuration()) - if (fs.exists(directory)) throw new Exception("This directory already exists") - fs.mkdirs(directory) - - // Create the FileInputDStream on the directory and use the - // stream to count words in new files created - val inputStream = ssc.createTextFileStream(directory.toString) - val words = inputStream.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - ssc.start() - - // Creating new files in the directory - val text = "This is a text file" - for (i <- 1 to 30) { - ssc.sc.parallelize((1 to (i * 10)).map(_ => text), 10) - .saveAsTextFile(new Path(directory, i.toString).toString) - Thread.sleep(1000) - } - Thread.sleep(5000) // Waiting for the file to be processed - ssc.stop() - fs.delete(directory) - System.exit(0) - } -} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStream.scala b/streaming/src/main/scala/spark/streaming/examples/FileStream.scala new file mode 100644 index 0000000000..301da56014 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/FileStream.scala @@ -0,0 +1,47 @@ +package spark.streaming.examples + +import spark.streaming.StreamingContext +import spark.streaming.StreamingContext._ +import spark.streaming.Seconds +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration + + +object FileStream { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: FileStream ") + System.exit(1) + } + + // Create the context and set the batch size + val ssc = new StreamingContext(args(0), "FileStream") + ssc.setBatchDuration(Seconds(2)) + + // Create the new directory + val directory = new Path(args(1)) + val fs = directory.getFileSystem(new Configuration()) + if (fs.exists(directory)) throw new Exception("This directory already exists") + fs.mkdirs(directory) + fs.deleteOnExit(directory) + + // Create the FileInputDStream on the directory and use the + // stream to count words in new files created + val inputStream = ssc.createTextFileStream(directory.toString) + val words = inputStream.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + + // Creating new files in the directory + val text = "This is a text file" + for (i <- 1 to 30) { + ssc.sc.parallelize((1 to (i * 10)).map(_ => text), 10) + .saveAsTextFile(new Path(directory, i.toString).toString) + Thread.sleep(1000) + } + Thread.sleep(5000) // Waiting for the file to be processed + ssc.stop() + System.exit(0) + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/examples/QueueStream.scala b/streaming/src/main/scala/spark/streaming/examples/QueueStream.scala new file mode 100644 index 0000000000..ae701bba6d --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/QueueStream.scala @@ -0,0 +1,41 @@ +package spark.streaming.examples + +import spark.RDD +import spark.streaming.StreamingContext +import spark.streaming.StreamingContext._ +import spark.streaming.Seconds + +import scala.collection.mutable.SynchronizedQueue + +object QueueStream { + + def main(args: Array[String]) { + if (args.length < 1) { + System.err.println("Usage: QueueStream ") + System.exit(1) + } + + // Create the context and set the batch size + val ssc = new StreamingContext(args(0), "QueueStream") + ssc.setBatchDuration(Seconds(1)) + + // Create the queue through which RDDs can be pushed to + // a QueueInputDStream + val rddQueue = new SynchronizedQueue[RDD[Int]]() + + // Create the QueueInputDStream and use it do some processing + val inputStream = ssc.createQueueStream(rddQueue) + val mappedStream = inputStream.map(x => (x % 10, 1)) + val reducedStream = mappedStream.reduceByKey(_ + _) + reducedStream.print() + ssc.start() + + // Create and push some RDDs into + for (i <- 1 to 30) { + rddQueue += ssc.sc.makeRDD(1 to 1000, 10) + Thread.sleep(1000) + } + ssc.stop() + System.exit(0) + } +} \ No newline at end of file -- cgit v1.2.3 From a6de5758f1a48e6c25b441440d8cd84546857326 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 23 Oct 2012 01:41:13 -0700 Subject: Modified API of NetworkInputDStreams and got ObjectInputDStream and RawInputDStream working. --- .../scala/spark/streaming/FileInputDStream.scala | 18 -- .../spark/streaming/NetworkInputDStream.scala | 139 +++++++++++- .../streaming/NetworkInputReceiverMessage.scala | 7 - .../spark/streaming/NetworkInputTracker.scala | 84 +++---- .../scala/spark/streaming/ObjectInputDStream.scala | 169 +++++++++++++- .../spark/streaming/ObjectInputReceiver.scala | 244 --------------------- .../scala/spark/streaming/RawInputDStream.scala | 77 ++----- .../scala/spark/streaming/StreamingContext.scala | 4 +- 8 files changed, 360 insertions(+), 382 deletions(-) delete mode 100644 streaming/src/main/scala/spark/streaming/NetworkInputReceiverMessage.scala delete mode 100644 streaming/src/main/scala/spark/streaming/ObjectInputReceiver.scala diff --git a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala index 29ae89616e..78537b8794 100644 --- a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala @@ -19,15 +19,6 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K @transient private var path_ : Path = null @transient private var fs_ : FileSystem = null - /* - @transient @noinline lazy val path = { - //if (directory == null) throw new Exception("directory is null") - //println(directory) - new Path(directory) - } - @transient lazy val fs = path.getFileSystem(new Configuration()) - */ - var lastModTime: Long = 0 def path(): Path = { @@ -79,15 +70,6 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K file => ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString))) Some(newRDD) } - /* - @throws(classOf[IOException]) - private def readObject(ois: ObjectInputStream) { - println(this.getClass().getSimpleName + ".readObject used") - ois.defaultReadObject() - println("HERE HERE" + this.directory) - } - */ - } object FileInputDStream { diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala index bf83f98ec4..6b41e4d2c8 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala @@ -1,13 +1,30 @@ package spark.streaming -import spark.RDD -import spark.BlockRDD +import spark.{Logging, SparkEnv, RDD, BlockRDD} +import spark.storage.StorageLevel -abstract class NetworkInputDStream[T: ClassManifest](@transient ssc: StreamingContext) - extends InputDStream[T](ssc) { +import java.nio.ByteBuffer - val id = ssc.getNewNetworkStreamId() - +import akka.actor.{Props, Actor} +import akka.pattern.ask +import akka.dispatch.Await +import akka.util.duration._ + +abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : StreamingContext) + extends InputDStream[T](ssc_) { + + // This is an unique identifier that is used to match the network receiver with the + // corresponding network input stream. + val id = ssc.getNewNetworkStreamId() + + /** + * This method creates the receiver object that will be sent to the workers + * to receive data. This method needs to defined by any specific implementation + * of a NetworkInputDStream. + */ + def createReceiver(): NetworkReceiver[T] + + // Nothing to start or stop as both taken care of by the NetworkInputTracker. def start() {} def stop() {} @@ -16,8 +33,114 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc: StreamingCo val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) Some(new BlockRDD[T](ssc.sc, blockIds)) } +} + + +sealed trait NetworkReceiverMessage +case class StopReceiver(msg: String) extends NetworkReceiverMessage +case class ReportBlock(blockId: String) extends NetworkReceiverMessage +case class ReportError(msg: String) extends NetworkReceiverMessage + +abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializable with Logging { + + initLogging() + + lazy protected val env = SparkEnv.get + + lazy protected val actor = env.actorSystem.actorOf( + Props(new NetworkReceiverActor()), "NetworkReceiver-" + streamId) + + lazy protected val receivingThread = Thread.currentThread() + + /** This method will be called to start receiving data. */ + protected def onStart() + + /** This method will be called to stop receiving data. */ + protected def onStop() + + /** + * This method starts the receiver. First is accesses all the lazy members to + * materialize them. Then it calls the user-defined onStart() method to start + * other threads, etc required to receiver the data. + */ + def start() { + try { + // Access the lazy vals to materialize them + env + actor + receivingThread + + // Call user-defined onStart() + onStart() + } catch { + case ie: InterruptedException => + logWarning("Receiving thread interrupted") + case e: Exception => + stopOnError(e) + } + } + + /** + * This method stops the receiver. First it interrupts the main receiving thread, + * that is, the thread that called receiver.start(). Then it calls the user-defined + * onStop() method to stop other threads and/or do cleanup. + */ + def stop() { + receivingThread.interrupt() + onStop() + //TODO: terminate the actor + } + + /** + * This method stops the receiver and reports to exception to the tracker. + * This should be called whenever an exception has happened on any thread + * of the receiver. + */ + protected def stopOnError(e: Exception) { + logError("Error receiving data", e) + stop() + actor ! ReportError(e.toString) + } - /** Called on workers to run a receiver for the stream. */ - def runReceiver(): Unit + /** + * This method pushes a block (as iterator of values) into the block manager. + */ + protected def pushBlock(blockId: String, iterator: Iterator[T], level: StorageLevel) { + env.blockManager.put(blockId, iterator, level) + actor ! ReportBlock(blockId) + } + + /** + * This method pushes a block (as bytes) into the block manager. + */ + protected def pushBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + env.blockManager.putBytes(blockId, bytes, level) + actor ! ReportBlock(blockId) + } + + /** A helper actor that communicates with the NetworkInputTracker */ + private class NetworkReceiverActor extends Actor { + logInfo("Attempting to register with tracker") + val ip = System.getProperty("spark.master.host", "localhost") + val port = System.getProperty("spark.master.port", "7077").toInt + val url = "akka://spark@%s:%s/user/NetworkInputTracker".format(ip, port) + val tracker = env.actorSystem.actorFor(url) + val timeout = 5.seconds + + override def preStart() { + val future = tracker.ask(RegisterReceiver(streamId, self))(timeout) + Await.result(future, timeout) + } + + override def receive() = { + case ReportBlock(blockId) => + tracker ! AddBlocks(streamId, Array(blockId)) + case ReportError(msg) => + tracker ! DeregisterReceiver(streamId, msg) + case StopReceiver(msg) => + stop() + tracker ! DeregisterReceiver(streamId, msg) + } + } } diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputReceiverMessage.scala b/streaming/src/main/scala/spark/streaming/NetworkInputReceiverMessage.scala deleted file mode 100644 index deaffe98c8..0000000000 --- a/streaming/src/main/scala/spark/streaming/NetworkInputReceiverMessage.scala +++ /dev/null @@ -1,7 +0,0 @@ -package spark.streaming - -sealed trait NetworkInputReceiverMessage - -case class GetBlockIds(time: Long) extends NetworkInputReceiverMessage -case class GotBlockIds(streamId: Int, blocksIds: Array[String]) extends NetworkInputReceiverMessage -case object StopReceiver extends NetworkInputReceiverMessage diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index 9f9001e4d5..9b1b8813de 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -13,13 +13,44 @@ import akka.dispatch._ trait NetworkInputTrackerMessage case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage +case class AddBlocks(streamId: Int, blockIds: Seq[String]) extends NetworkInputTrackerMessage +case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage + class NetworkInputTracker( @transient ssc: StreamingContext, - @transient networkInputStreams: Array[NetworkInputDStream[_]]) -extends Logging { + @transient networkInputStreams: Array[NetworkInputDStream[_]]) + extends Logging { + + val networkInputStreamIds = networkInputStreams.map(_.id).toArray + val receiverExecutor = new ReceiverExecutor() + val receiverInfo = new HashMap[Int, ActorRef] + val receivedBlockIds = new HashMap[Int, Queue[String]] + val timeout = 5000.milliseconds - class TrackerActor extends Actor { + var currentTime: Time = null + + def start() { + ssc.env.actorSystem.actorOf(Props(new NetworkInputTrackerActor), "NetworkInputTracker") + receiverExecutor.start() + } + + def stop() { + receiverExecutor.interrupt() + receiverExecutor.stopReceivers() + } + + def getBlockIds(receiverId: Int, time: Time): Array[String] = synchronized { + val queue = receivedBlockIds.synchronized { + receivedBlockIds.getOrElse(receiverId, new Queue[String]()) + } + val result = queue.synchronized { + queue.dequeueAll(x => true) + } + result.toArray + } + + private class NetworkInputTrackerActor extends Actor { def receive = { case RegisterReceiver(streamId, receiverActor) => { if (!networkInputStreamIds.contains(streamId)) { @@ -29,7 +60,7 @@ extends Logging { logInfo("Registered receiver for network stream " + streamId) sender ! true } - case GotBlockIds(streamId, blockIds) => { + case AddBlocks(streamId, blockIds) => { val tmp = receivedBlockIds.synchronized { if (!receivedBlockIds.contains(streamId)) { receivedBlockIds += ((streamId, new Queue[String])) @@ -40,6 +71,12 @@ extends Logging { tmp ++= blockIds } } + case DeregisterReceiver(streamId, msg) => { + receiverInfo -= streamId + logInfo("De-registered receiver for network stream " + streamId + + " with message " + msg) + //TODO: Do something about the corresponding NetworkInputDStream + } } } @@ -58,15 +95,15 @@ extends Logging { } def startReceivers() { - val tempRDD = ssc.sc.makeRDD(networkInputStreams, networkInputStreams.size) - - val startReceiver = (iterator: Iterator[NetworkInputDStream[_]]) => { + val receivers = networkInputStreams.map(_.createReceiver()) + val tempRDD = ssc.sc.makeRDD(receivers, receivers.size) + + val startReceiver = (iterator: Iterator[NetworkReceiver[_]]) => { if (!iterator.hasNext) { throw new Exception("Could not start receiver as details not found.") } - iterator.next().runReceiver() + iterator.next().start() } - ssc.sc.runJob(tempRDD, startReceiver) } @@ -77,33 +114,4 @@ extends Logging { Await.result(futureOfList, timeout) } } - - val networkInputStreamIds = networkInputStreams.map(_.id).toArray - val receiverExecutor = new ReceiverExecutor() - val receiverInfo = new HashMap[Int, ActorRef] - val receivedBlockIds = new HashMap[Int, Queue[String]] - val timeout = 5000.milliseconds - - - var currentTime: Time = null - - def start() { - ssc.env.actorSystem.actorOf(Props(new TrackerActor), "NetworkInputTracker") - receiverExecutor.start() - } - - def stop() { - // stop the actor - receiverExecutor.interrupt() - } - - def getBlockIds(receiverId: Int, time: Time): Array[String] = synchronized { - val queue = receivedBlockIds.synchronized { - receivedBlockIds.getOrElse(receiverId, new Queue[String]()) - } - val result = queue.synchronized { - queue.dequeueAll(x => true) - } - result.toArray - } } diff --git a/streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala b/streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala index 2396b374a0..89aeeda8b3 100644 --- a/streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala @@ -1,16 +1,167 @@ package spark.streaming -import java.io.InputStream +import spark.streaming.util.{RecurringTimer, SystemClock} +import spark.storage.StorageLevel + +import java.io.{EOFException, DataInputStream, BufferedInputStream, InputStream} +import java.net.Socket +import java.util.concurrent.ArrayBlockingQueue + +import scala.collection.mutable.ArrayBuffer class ObjectInputDStream[T: ClassManifest]( - @transient ssc: StreamingContext, - val host: String, - val port: Int, - val bytesToObjects: InputStream => Iterator[T]) - extends NetworkInputDStream[T](ssc) { - - override def runReceiver() { - new ObjectInputReceiver(id, host, port, bytesToObjects).run() + @transient ssc_ : StreamingContext, + host: String, + port: Int, + bytesToObjects: InputStream => Iterator[T], + storageLevel: StorageLevel + ) extends NetworkInputDStream[T](ssc_) { + + def createReceiver(): NetworkReceiver[T] = { + new ObjectInputReceiver(id, host, port, bytesToObjects, storageLevel) } } + +class ObjectInputReceiver[T: ClassManifest]( + streamId: Int, + host: String, + port: Int, + bytesToObjects: InputStream => Iterator[T], + storageLevel: StorageLevel + ) extends NetworkReceiver[T](streamId) { + + lazy protected val dataHandler = new DataHandler(this) + + protected def onStart() { + logInfo("Connecting to " + host + ":" + port) + val socket = new Socket(host, port) + logInfo("Connected to " + host + ":" + port) + dataHandler.start() + val iterator = bytesToObjects(socket.getInputStream()) + while(iterator.hasNext) { + val obj = iterator.next + dataHandler += obj + } + } + + protected def onStop() { + dataHandler.stop() + } + + /** + * This is a helper object that manages the data received from the socket. It divides + * the object received into small batches of 100s of milliseconds, pushes them as + * blocks into the block manager and reports the block IDs to the network input + * tracker. It starts two threads, one to periodically start a new batch and prepare + * the previous batch of as a block, the other to push the blocks into the block + * manager. + */ + class DataHandler(receiver: NetworkReceiver[T]) extends Serializable { + case class Block(id: String, iterator: Iterator[T]) + + val clock = new SystemClock() + val blockInterval = 200L + val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) + val blockStorageLevel = storageLevel + val blocksForPushing = new ArrayBlockingQueue[Block](1000) + val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } + + var currentBuffer = new ArrayBuffer[T] + + def start() { + blockIntervalTimer.start() + blockPushingThread.start() + logInfo("Data handler started") + } + + def stop() { + blockIntervalTimer.stop() + blockPushingThread.interrupt() + logInfo("Data handler stopped") + } + + def += (obj: T) { + currentBuffer += obj + } + + def updateCurrentBuffer(time: Long) { + try { + val newBlockBuffer = currentBuffer + currentBuffer = new ArrayBuffer[T] + if (newBlockBuffer.size > 0) { + val blockId = "input-" + streamId + "- " + (time - blockInterval) + val newBlock = new Block(blockId, newBlockBuffer.toIterator) + blocksForPushing.add(newBlock) + } + } catch { + case ie: InterruptedException => + logInfo("Block interval timer thread interrupted") + case e: Exception => + receiver.stop() + } + } + + def keepPushingBlocks() { + logInfo("Block pushing thread started") + try { + while(true) { + val block = blocksForPushing.take() + pushBlock(block.id, block.iterator, storageLevel) + } + } catch { + case ie: InterruptedException => + logInfo("Block pushing thread interrupted") + case e: Exception => + receiver.stop() + } + } + } +} + + +object ObjectInputReceiver { + def bytesToLines(inputStream: InputStream): Iterator[String] = { + val bufferedInputStream = new BufferedInputStream(inputStream) + val dataInputStream = new DataInputStream(bufferedInputStream) + + val iterator = new Iterator[String] { + var gotNext = false + var finished = false + var nextValue: String = null + + private def getNext() { + try { + nextValue = dataInputStream.readLine() + println("[" + nextValue + "]") + } catch { + case eof: EOFException => + finished = true + } + gotNext = true + } + + override def hasNext: Boolean = { + if (!gotNext) { + getNext() + } + if (finished) { + dataInputStream.close() + } + !finished + } + + override def next(): String = { + if (!gotNext) { + getNext() + } + if (finished) { + throw new NoSuchElementException("End of stream") + } + gotNext = false + nextValue + } + } + iterator + } +} diff --git a/streaming/src/main/scala/spark/streaming/ObjectInputReceiver.scala b/streaming/src/main/scala/spark/streaming/ObjectInputReceiver.scala deleted file mode 100644 index 70fa2cdf07..0000000000 --- a/streaming/src/main/scala/spark/streaming/ObjectInputReceiver.scala +++ /dev/null @@ -1,244 +0,0 @@ -package spark.streaming - -import spark.Logging -import spark.storage.BlockManager -import spark.storage.StorageLevel -import spark.SparkEnv -import spark.streaming.util.SystemClock -import spark.streaming.util.RecurringTimer - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.Queue -import scala.collection.mutable.SynchronizedPriorityQueue -import scala.math.Ordering - -import java.net.InetSocketAddress -import java.net.Socket -import java.io.InputStream -import java.io.BufferedInputStream -import java.io.DataInputStream -import java.io.EOFException -import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.ArrayBlockingQueue - -import akka.actor._ -import akka.pattern.ask -import akka.util.duration._ -import akka.dispatch._ - -class ObjectInputReceiver[T: ClassManifest]( - streamId: Int, - host: String, - port: Int, - bytesToObjects: InputStream => Iterator[T]) - extends Logging { - - class ReceiverActor extends Actor { - override def preStart() { - logInfo("Attempting to register") - val ip = System.getProperty("spark.master.host", "localhost") - val port = System.getProperty("spark.master.port", "7077").toInt - val actorName: String = "NetworkInputTracker" - val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) - val trackerActor = env.actorSystem.actorFor(url) - val timeout = 1.seconds - val future = trackerActor.ask(RegisterReceiver(streamId, self))(timeout) - Await.result(future, timeout) - } - - def receive = { - case GetBlockIds(time) => { - logInfo("Got request for block ids for " + time) - sender ! GotBlockIds(streamId, dataHandler.getPushedBlocks()) - } - - case StopReceiver => { - if (receivingThread != null) { - receivingThread.interrupt() - } - sender ! true - } - } - } - - class DataHandler { - class Block(val time: Long, val iterator: Iterator[T]) { - val blockId = "input-" + streamId + "-" + time - var pushed = true - override def toString = "input block " + blockId - } - - val clock = new SystemClock() - val blockInterval = 200L - val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) - val blockOrdering = new Ordering[Block] { - def compare(b1: Block, b2: Block) = (b1.time - b2.time).toInt - } - val blockStorageLevel = StorageLevel.DISK_AND_MEMORY - val blocksForPushing = new ArrayBlockingQueue[Block](1000) - val blocksForReporting = new SynchronizedPriorityQueue[Block]()(blockOrdering) - val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } - - var currentBuffer = new ArrayBuffer[T] - - def start() { - blockIntervalTimer.start() - blockPushingThread.start() - logInfo("Data handler started") - } - - def stop() { - blockIntervalTimer.stop() - blockPushingThread.interrupt() - } - - def += (obj: T) { - currentBuffer += obj - } - - def updateCurrentBuffer(time: Long) { - val newBlockBuffer = currentBuffer - currentBuffer = new ArrayBuffer[T] - if (newBlockBuffer.size > 0) { - val newBlock = new Block(time - blockInterval, newBlockBuffer.toIterator) - blocksForPushing.add(newBlock) - blocksForReporting.enqueue(newBlock) - } - } - - def keepPushingBlocks() { - logInfo("Block pushing thread started") - try { - while(true) { - val block = blocksForPushing.take() - if (blockManager != null) { - blockManager.put(block.blockId, block.iterator, blockStorageLevel) - block.pushed = true - } else { - logWarning(block + " not put as block manager is null") - } - } - } catch { - case ie: InterruptedException => println("Block pushing thread interrupted") - case e: Exception => e.printStackTrace() - } - } - - def getPushedBlocks(): Array[String] = { - val pushedBlocks = new ArrayBuffer[String]() - var loop = true - while(loop && !blocksForReporting.isEmpty) { - val block = blocksForReporting.dequeue() - if (block == null) { - loop = false - } else if (!block.pushed) { - blocksForReporting.enqueue(block) - } else { - pushedBlocks += block.blockId - } - } - logInfo("Got " + pushedBlocks.size + " blocks") - pushedBlocks.toArray - } - } - - val blockManager = if (SparkEnv.get != null) SparkEnv.get.blockManager else null - val dataHandler = new DataHandler() - val env = SparkEnv.get - - var receiverActor: ActorRef = null - var receivingThread: Thread = null - - def run() { - initLogging() - var socket: Socket = null - try { - if (SparkEnv.get != null) { - receiverActor = SparkEnv.get.actorSystem.actorOf(Props(new ReceiverActor), "ReceiverActor-" + streamId) - } - dataHandler.start() - socket = connect() - receivingThread = Thread.currentThread() - receive(socket) - } catch { - case ie: InterruptedException => logInfo("Receiver interrupted") - } finally { - receivingThread = null - if (socket != null) socket.close() - dataHandler.stop() - } - } - - def connect(): Socket = { - logInfo("Connecting to " + host + ":" + port) - val socket = new Socket(host, port) - logInfo("Connected to " + host + ":" + port) - socket - } - - def receive(socket: Socket) { - val iterator = bytesToObjects(socket.getInputStream()) - while(iterator.hasNext) { - val obj = iterator.next - dataHandler += obj - } - } -} - - -object ObjectInputReceiver { - def bytesToLines(inputStream: InputStream): Iterator[String] = { - val bufferedInputStream = new BufferedInputStream(inputStream) - val dataInputStream = new DataInputStream(bufferedInputStream) - - val iterator = new Iterator[String] { - var gotNext = false - var finished = false - var nextValue: String = null - - private def getNext() { - try { - nextValue = dataInputStream.readLine() - println("[" + nextValue + "]") - } catch { - case eof: EOFException => - finished = true - } - gotNext = true - } - - override def hasNext: Boolean = { - if (!gotNext) { - getNext() - } - if (finished) { - dataInputStream.close() - } - !finished - } - - override def next(): String = { - if (!gotNext) { - getNext() - } - if (finished) { - throw new NoSuchElementException("End of stream") - } - gotNext = false - nextValue - } - } - iterator - } - - def main(args: Array[String]) { - if (args.length < 2) { - println("ObjectInputReceiver ") - System.exit(1) - } - val host = args(0) - val port = args(1).toInt - val receiver = new ObjectInputReceiver(0, host, port, bytesToLines) - receiver.run() - } -} diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala index d29aea7886..e022b85fbe 100644 --- a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala @@ -1,16 +1,11 @@ package spark.streaming -import akka.actor._ -import akka.pattern.ask -import akka.util.duration._ -import akka.dispatch._ import java.net.InetSocketAddress import java.nio.ByteBuffer import java.nio.channels.{ReadableByteChannel, SocketChannel} import java.io.EOFException import java.util.concurrent.ArrayBlockingQueue -import scala.collection.mutable.ArrayBuffer -import spark.{DaemonThread, Logging, SparkEnv} +import spark._ import spark.storage.StorageLevel /** @@ -20,20 +15,23 @@ import spark.storage.StorageLevel * in the format that the system is configured with. */ class RawInputDStream[T: ClassManifest]( - @transient ssc: StreamingContext, + @transient ssc_ : StreamingContext, host: String, port: Int, - storageLevel: StorageLevel) - extends NetworkInputDStream[T](ssc) with Logging { + storageLevel: StorageLevel + ) extends NetworkInputDStream[T](ssc_ ) with Logging { - val streamId = id + def createReceiver(): NetworkReceiver[T] = { + new RawNetworkReceiver(id, host, port, storageLevel).asInstanceOf[NetworkReceiver[T]] + } +} + +class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: StorageLevel) + extends NetworkReceiver[Any](streamId) { - /** Called on workers to run a receiver for the stream. */ - def runReceiver() { - val env = SparkEnv.get - val actor = env.actorSystem.actorOf( - Props(new ReceiverActor(env, Thread.currentThread)), "ReceiverActor-" + streamId) + var blockPushingThread: Thread = null + def onStart() { // Open a socket to the target address and keep reading from it logInfo("Connecting to " + host + ":" + port) val channel = SocketChannel.open() @@ -43,18 +41,18 @@ class RawInputDStream[T: ClassManifest]( val queue = new ArrayBlockingQueue[ByteBuffer](2) - new DaemonThread { + blockPushingThread = new DaemonThread { override def run() { var nextBlockNumber = 0 while (true) { val buffer = queue.take() val blockId = "input-" + streamId + "-" + nextBlockNumber nextBlockNumber += 1 - env.blockManager.putBytes(blockId, buffer, storageLevel) - actor ! BlockPublished(blockId) + pushBlock(blockId, buffer, storageLevel) } } - }.start() + } + blockPushingThread.start() val lengthBuffer = ByteBuffer.allocate(4) while (true) { @@ -70,6 +68,10 @@ class RawInputDStream[T: ClassManifest]( } } + def onStop() { + blockPushingThread.interrupt() + } + /** Read a buffer fully from a given Channel */ private def readFully(channel: ReadableByteChannel, dest: ByteBuffer) { while (dest.position < dest.limit) { @@ -78,41 +80,4 @@ class RawInputDStream[T: ClassManifest]( } } } - - /** Message sent to ReceiverActor to tell it that a block was published */ - case class BlockPublished(blockId: String) {} - - /** A helper actor that communicates with the NetworkInputTracker */ - private class ReceiverActor(env: SparkEnv, receivingThread: Thread) extends Actor { - val newBlocks = new ArrayBuffer[String] - - logInfo("Attempting to register with tracker") - val ip = System.getProperty("spark.master.host", "localhost") - val port = System.getProperty("spark.master.port", "7077").toInt - val actorName: String = "NetworkInputTracker" - val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) - val trackerActor = env.actorSystem.actorFor(url) - val timeout = 5.seconds - - override def preStart() { - val future = trackerActor.ask(RegisterReceiver(streamId, self))(timeout) - Await.result(future, timeout) - } - - override def receive = { - case BlockPublished(blockId) => - newBlocks += blockId - val future = trackerActor ! GotBlockIds(streamId, Array(blockId)) - - case GetBlockIds(time) => - logInfo("Got request for block IDs for " + time) - sender ! GotBlockIds(streamId, newBlocks.toArray) - newBlocks.clear() - - case StopReceiver => - receivingThread.interrupt() - sender ! true - } - - } } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 7022056f7c..1dc5614a5c 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -95,7 +95,7 @@ class StreamingContext ( port: Int, converter: (InputStream) => Iterator[T] ): DStream[T] = { - val inputStream = new ObjectInputDStream[T](this, hostname, port, converter) + val inputStream = new ObjectInputDStream[T](this, hostname, port, converter, StorageLevel.DISK_AND_MEMORY_2) graph.addInputStream(inputStream) inputStream } @@ -207,7 +207,7 @@ class StreamingContext ( } /** - * This function starts the execution of the streams. + * This function stops the execution of the streams. */ def stop() { try { -- cgit v1.2.3 From 19191d178d194e4b57094ca868e1cc9c66b8d4a7 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 23 Oct 2012 14:40:24 -0700 Subject: Renamed the network input streams. --- .../scala/spark/streaming/ObjectInputDStream.scala | 167 --------------------- .../scala/spark/streaming/SocketInputDStream.scala | 167 +++++++++++++++++++++ .../scala/spark/streaming/StreamingContext.scala | 63 ++++---- .../scala/spark/streaming/examples/CountRaw.scala | 2 +- .../scala/spark/streaming/examples/GrepRaw.scala | 2 +- .../streaming/examples/TopKWordCountRaw.scala | 2 +- .../streaming/examples/WordCountNetwork.scala | 2 +- .../spark/streaming/examples/WordCountRaw.scala | 2 +- 8 files changed, 206 insertions(+), 201 deletions(-) delete mode 100644 streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/SocketInputDStream.scala diff --git a/streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala b/streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala deleted file mode 100644 index 89aeeda8b3..0000000000 --- a/streaming/src/main/scala/spark/streaming/ObjectInputDStream.scala +++ /dev/null @@ -1,167 +0,0 @@ -package spark.streaming - -import spark.streaming.util.{RecurringTimer, SystemClock} -import spark.storage.StorageLevel - -import java.io.{EOFException, DataInputStream, BufferedInputStream, InputStream} -import java.net.Socket -import java.util.concurrent.ArrayBlockingQueue - -import scala.collection.mutable.ArrayBuffer - -class ObjectInputDStream[T: ClassManifest]( - @transient ssc_ : StreamingContext, - host: String, - port: Int, - bytesToObjects: InputStream => Iterator[T], - storageLevel: StorageLevel - ) extends NetworkInputDStream[T](ssc_) { - - def createReceiver(): NetworkReceiver[T] = { - new ObjectInputReceiver(id, host, port, bytesToObjects, storageLevel) - } -} - - -class ObjectInputReceiver[T: ClassManifest]( - streamId: Int, - host: String, - port: Int, - bytesToObjects: InputStream => Iterator[T], - storageLevel: StorageLevel - ) extends NetworkReceiver[T](streamId) { - - lazy protected val dataHandler = new DataHandler(this) - - protected def onStart() { - logInfo("Connecting to " + host + ":" + port) - val socket = new Socket(host, port) - logInfo("Connected to " + host + ":" + port) - dataHandler.start() - val iterator = bytesToObjects(socket.getInputStream()) - while(iterator.hasNext) { - val obj = iterator.next - dataHandler += obj - } - } - - protected def onStop() { - dataHandler.stop() - } - - /** - * This is a helper object that manages the data received from the socket. It divides - * the object received into small batches of 100s of milliseconds, pushes them as - * blocks into the block manager and reports the block IDs to the network input - * tracker. It starts two threads, one to periodically start a new batch and prepare - * the previous batch of as a block, the other to push the blocks into the block - * manager. - */ - class DataHandler(receiver: NetworkReceiver[T]) extends Serializable { - case class Block(id: String, iterator: Iterator[T]) - - val clock = new SystemClock() - val blockInterval = 200L - val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) - val blockStorageLevel = storageLevel - val blocksForPushing = new ArrayBlockingQueue[Block](1000) - val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } - - var currentBuffer = new ArrayBuffer[T] - - def start() { - blockIntervalTimer.start() - blockPushingThread.start() - logInfo("Data handler started") - } - - def stop() { - blockIntervalTimer.stop() - blockPushingThread.interrupt() - logInfo("Data handler stopped") - } - - def += (obj: T) { - currentBuffer += obj - } - - def updateCurrentBuffer(time: Long) { - try { - val newBlockBuffer = currentBuffer - currentBuffer = new ArrayBuffer[T] - if (newBlockBuffer.size > 0) { - val blockId = "input-" + streamId + "- " + (time - blockInterval) - val newBlock = new Block(blockId, newBlockBuffer.toIterator) - blocksForPushing.add(newBlock) - } - } catch { - case ie: InterruptedException => - logInfo("Block interval timer thread interrupted") - case e: Exception => - receiver.stop() - } - } - - def keepPushingBlocks() { - logInfo("Block pushing thread started") - try { - while(true) { - val block = blocksForPushing.take() - pushBlock(block.id, block.iterator, storageLevel) - } - } catch { - case ie: InterruptedException => - logInfo("Block pushing thread interrupted") - case e: Exception => - receiver.stop() - } - } - } -} - - -object ObjectInputReceiver { - def bytesToLines(inputStream: InputStream): Iterator[String] = { - val bufferedInputStream = new BufferedInputStream(inputStream) - val dataInputStream = new DataInputStream(bufferedInputStream) - - val iterator = new Iterator[String] { - var gotNext = false - var finished = false - var nextValue: String = null - - private def getNext() { - try { - nextValue = dataInputStream.readLine() - println("[" + nextValue + "]") - } catch { - case eof: EOFException => - finished = true - } - gotNext = true - } - - override def hasNext: Boolean = { - if (!gotNext) { - getNext() - } - if (finished) { - dataInputStream.close() - } - !finished - } - - override def next(): String = { - if (!gotNext) { - getNext() - } - if (finished) { - throw new NoSuchElementException("End of stream") - } - gotNext = false - nextValue - } - } - iterator - } -} diff --git a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala new file mode 100644 index 0000000000..4dbf421687 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala @@ -0,0 +1,167 @@ +package spark.streaming + +import spark.streaming.util.{RecurringTimer, SystemClock} +import spark.storage.StorageLevel + +import java.io.{EOFException, DataInputStream, BufferedInputStream, InputStream} +import java.net.Socket +import java.util.concurrent.ArrayBlockingQueue + +import scala.collection.mutable.ArrayBuffer + +class SocketInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + host: String, + port: Int, + bytesToObjects: InputStream => Iterator[T], + storageLevel: StorageLevel + ) extends NetworkInputDStream[T](ssc_) { + + def createReceiver(): NetworkReceiver[T] = { + new ObjectInputReceiver(id, host, port, bytesToObjects, storageLevel) + } +} + + +class ObjectInputReceiver[T: ClassManifest]( + streamId: Int, + host: String, + port: Int, + bytesToObjects: InputStream => Iterator[T], + storageLevel: StorageLevel + ) extends NetworkReceiver[T](streamId) { + + lazy protected val dataHandler = new DataHandler(this) + + protected def onStart() { + logInfo("Connecting to " + host + ":" + port) + val socket = new Socket(host, port) + logInfo("Connected to " + host + ":" + port) + dataHandler.start() + val iterator = bytesToObjects(socket.getInputStream()) + while(iterator.hasNext) { + val obj = iterator.next + dataHandler += obj + } + } + + protected def onStop() { + dataHandler.stop() + } + + /** + * This is a helper object that manages the data received from the socket. It divides + * the object received into small batches of 100s of milliseconds, pushes them as + * blocks into the block manager and reports the block IDs to the network input + * tracker. It starts two threads, one to periodically start a new batch and prepare + * the previous batch of as a block, the other to push the blocks into the block + * manager. + */ + class DataHandler(receiver: NetworkReceiver[T]) extends Serializable { + case class Block(id: String, iterator: Iterator[T]) + + val clock = new SystemClock() + val blockInterval = 200L + val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) + val blockStorageLevel = storageLevel + val blocksForPushing = new ArrayBlockingQueue[Block](1000) + val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } + + var currentBuffer = new ArrayBuffer[T] + + def start() { + blockIntervalTimer.start() + blockPushingThread.start() + logInfo("Data handler started") + } + + def stop() { + blockIntervalTimer.stop() + blockPushingThread.interrupt() + logInfo("Data handler stopped") + } + + def += (obj: T) { + currentBuffer += obj + } + + def updateCurrentBuffer(time: Long) { + try { + val newBlockBuffer = currentBuffer + currentBuffer = new ArrayBuffer[T] + if (newBlockBuffer.size > 0) { + val blockId = "input-" + streamId + "- " + (time - blockInterval) + val newBlock = new Block(blockId, newBlockBuffer.toIterator) + blocksForPushing.add(newBlock) + } + } catch { + case ie: InterruptedException => + logInfo("Block interval timer thread interrupted") + case e: Exception => + receiver.stop() + } + } + + def keepPushingBlocks() { + logInfo("Block pushing thread started") + try { + while(true) { + val block = blocksForPushing.take() + pushBlock(block.id, block.iterator, storageLevel) + } + } catch { + case ie: InterruptedException => + logInfo("Block pushing thread interrupted") + case e: Exception => + receiver.stop() + } + } + } +} + + +object ObjectInputReceiver { + def bytesToLines(inputStream: InputStream): Iterator[String] = { + val bufferedInputStream = new BufferedInputStream(inputStream) + val dataInputStream = new DataInputStream(bufferedInputStream) + + val iterator = new Iterator[String] { + var gotNext = false + var finished = false + var nextValue: String = null + + private def getNext() { + try { + nextValue = dataInputStream.readLine() + println("[" + nextValue + "]") + } catch { + case eof: EOFException => + finished = true + } + gotNext = true + } + + override def hasNext: Boolean = { + if (!gotNext) { + getNext() + } + if (finished) { + dataInputStream.close() + } + !finished + } + + override def next(): String = { + if (!gotNext) { + getNext() + } + if (finished) { + throw new NoSuchElementException("End of stream") + } + gotNext = false + nextValue + } + } + iterator + } +} diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 1dc5614a5c..90654cdad9 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -86,21 +86,26 @@ class StreamingContext ( private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() - def createNetworkTextStream(hostname: String, port: Int): DStream[String] = { - createNetworkObjectStream[String](hostname, port, ObjectInputReceiver.bytesToLines) + def networkTextStream( + hostname: String, + port: Int, + storageLevel: StorageLevel = StorageLevel.DISK_AND_MEMORY_2 + ): DStream[String] = { + networkStream[String](hostname, port, ObjectInputReceiver.bytesToLines, storageLevel) } - - def createNetworkObjectStream[T: ClassManifest]( - hostname: String, - port: Int, - converter: (InputStream) => Iterator[T] + + def networkStream[T: ClassManifest]( + hostname: String, + port: Int, + converter: (InputStream) => Iterator[T], + storageLevel: StorageLevel ): DStream[T] = { - val inputStream = new ObjectInputDStream[T](this, hostname, port, converter, StorageLevel.DISK_AND_MEMORY_2) + val inputStream = new SocketInputDStream[T](this, hostname, port, converter, storageLevel) graph.addInputStream(inputStream) inputStream } - - def createRawNetworkStream[T: ClassManifest]( + + def rawNetworkStream[T: ClassManifest]( hostname: String, port: Int, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_2 @@ -109,26 +114,26 @@ class StreamingContext ( graph.addInputStream(inputStream) inputStream } - + /* def createHttpTextStream(url: String): DStream[String] = { createHttpStream(url, ObjectInputReceiver.bytesToLines) } - + def createHttpStream[T: ClassManifest]( - url: String, + url: String, converter: (InputStream) => Iterator[T] ): DStream[T] = { } */ - /** + /** * This function creates a input stream that monitors a Hadoop-compatible * for new files and executes the necessary processing on them. - */ + */ def createFileStream[ - K: ClassManifest, - V: ClassManifest, + K: ClassManifest, + V: ClassManifest, F <: NewInputFormat[K, V]: ClassManifest ](directory: String): DStream[(K, V)] = { val inputStream = new FileInputDStream[K, V, F](this, directory) @@ -139,13 +144,13 @@ class StreamingContext ( def createTextFileStream(directory: String): DStream[String] = { createFileStream[LongWritable, Text, TextInputFormat](directory).map(_._2.toString) } - + /** * This function create a input stream from an queue of RDDs. In each batch, - * it will process either one or all of the RDDs returned by the queue + * it will process either one or all of the RDDs returned by the queue */ def createQueueStream[T: ClassManifest]( - queue: Queue[RDD[T]], + queue: Queue[RDD[T]], oneAtATime: Boolean = true, defaultRDD: RDD[T] = null ): DStream[T] = { @@ -153,7 +158,7 @@ class StreamingContext ( graph.addInputStream(inputStream) inputStream } - + def createQueueStream[T: ClassManifest](array: Array[RDD[T]]): DStream[T] = { val queue = new Queue[RDD[T]] val inputStream = createQueueStream(queue, true, null) @@ -172,27 +177,27 @@ class StreamingContext ( /** * This function registers a DStream as an output stream that will be * computed every interval. - */ + */ def registerOutputStream(outputStream: DStream[_]) { graph.addOutputStream(outputStream) } - + def validate() { assert(graph != null, "Graph is null") graph.validate() } /** - * This function starts the execution of the streams. - */ + * This function starts the execution of the streams. + */ def start() { validate() val networkInputStreams = graph.getInputStreams().filter(s => s match { - case n: NetworkInputDStream[_] => true + case n: NetworkInputDStream[_] => true case _ => false }).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray - + if (networkInputStreams.length > 0) { // Start the network input tracker (must start before receivers) networkInputTracker = new NetworkInputTracker(this, networkInputStreams) @@ -203,9 +208,9 @@ class StreamingContext ( // Start the scheduler scheduler = new Scheduler(this) - scheduler.start() + scheduler.start() } - + /** * This function stops the execution of the streams. */ diff --git a/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala index ed571d22e3..d2fdabd659 100644 --- a/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala @@ -24,7 +24,7 @@ object CountRaw { ssc.sc.parallelize(1 to 1000, 1000).count() val rawStreams = (1 to numStreams).map(_ => - ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray + ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray val union = new UnionDStream(rawStreams) union.map(_.length + 2).reduce(_ + _).foreachRDD(r => println("Byte count: " + r.collect().mkString)) ssc.start() diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala index 6af1c36891..b1e1a613fe 100644 --- a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala @@ -24,7 +24,7 @@ object GrepRaw { ssc.sc.parallelize(1 to 1000, 1000).count() val rawStreams = (1 to numStreams).map(_ => - ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray + ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray val union = new UnionDStream(rawStreams) union.filter(_.contains("Culpepper")).count().foreachRDD(r => println("Grep count: " + r.collect().mkString)) diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala index af0a3bf98a..9d1b0b9eb4 100644 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala @@ -36,7 +36,7 @@ object TopKWordCountRaw { moreWarmup(ssc.sc) val rawStreams = (1 to streams).map(_ => - ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray + ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray val union = new UnionDStream(rawStreams) val windowedCounts = union.mapPartitions(splitAndCountPartitions) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala index 0aa5294a17..ba1bd1de7c 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala @@ -16,7 +16,7 @@ object WordCountNetwork { // Create a NetworkInputDStream on target ip:port and count the // words in input stream of \n delimited test (eg. generated by 'nc') - val lines = ssc.createNetworkTextStream(args(1), args(2).toInt) + val lines = ssc.networkTextStream(args(1), args(2).toInt) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) wordCounts.print() diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala index 98bafec529..d8a0664d7d 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -36,7 +36,7 @@ object WordCountRaw { moreWarmup(ssc.sc) val rawStreams = (1 to streams).map(_ => - ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray + ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray val union = new UnionDStream(rawStreams) val windowedCounts = union.mapPartitions(splitAndCountPartitions) -- cgit v1.2.3 From c2731dd3effe780d7f37487f8cbd27179055ebee Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 23 Oct 2012 15:10:27 -0700 Subject: Updated StateDStream api to use Options instead of nulls. --- .../main/scala/spark/streaming/PairDStreamFunctions.scala | 14 +++++++------- .../src/main/scala/spark/streaming/StateDStream.scala | 15 +++++++++------ .../test/scala/spark/streaming/DStreamBasicSuite.scala | 7 ++----- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index 5de57eb2fd..ce1f4ad0a0 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -156,30 +156,30 @@ extends Serializable { // // def updateStateByKey[S <: AnyRef : ClassManifest]( - updateFunc: (Seq[V], S) => S + updateFunc: (Seq[V], Option[S]) => Option[S] ): DStream[(K, S)] = { updateStateByKey(updateFunc, defaultPartitioner()) } def updateStateByKey[S <: AnyRef : ClassManifest]( - updateFunc: (Seq[V], S) => S, + updateFunc: (Seq[V], Option[S]) => Option[S], numPartitions: Int ): DStream[(K, S)] = { updateStateByKey(updateFunc, defaultPartitioner(numPartitions)) } def updateStateByKey[S <: AnyRef : ClassManifest]( - updateFunc: (Seq[V], S) => S, + updateFunc: (Seq[V], Option[S]) => Option[S], partitioner: Partitioner ): DStream[(K, S)] = { - val func = (iterator: Iterator[(K, Seq[V], S)]) => { - iterator.map(tuple => (tuple._1, updateFunc(tuple._2, tuple._3))) + val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => { + iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s))) } - updateStateByKey(func, partitioner, true) + updateStateByKey(newUpdateFunc, partitioner, true) } def updateStateByKey[S <: AnyRef : ClassManifest]( - updateFunc: (Iterator[(K, Seq[V], S)]) => Iterator[(K, S)], + updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], partitioner: Partitioner, rememberPartitioner: Boolean ): DStream[(K, S)] = { diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala index d223f25dfc..3ba8fb45fb 100644 --- a/streaming/src/main/scala/spark/streaming/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala @@ -8,14 +8,17 @@ import spark.SparkContext._ import spark.storage.StorageLevel -class StateRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: Iterator[T] => Iterator[U], rememberPartitioner: Boolean) - extends MapPartitionsRDD[U, T](prev, f) { +class StateRDD[U: ClassManifest, T: ClassManifest]( + prev: RDD[T], + f: Iterator[T] => Iterator[U], + rememberPartitioner: Boolean + ) extends MapPartitionsRDD[U, T](prev, f) { override val partitioner = if (rememberPartitioner) prev.partitioner else None } class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( - @transient parent: DStream[(K, V)], - updateFunc: (Iterator[(K, Seq[V], S)]) => Iterator[(K, S)], + parent: DStream[(K, V)], + updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], partitioner: Partitioner, rememberPartitioner: Boolean ) extends DStream[(K, S)](parent.ssc) { @@ -82,7 +85,7 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife val updateFuncLocal = updateFunc val finalFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { val i = iterator.map(t => { - (t._1, t._2._1, t._2._2.headOption.getOrElse(null.asInstanceOf[S])) + (t._1, t._2._1, t._2._2.headOption) }) updateFuncLocal(i) } @@ -108,7 +111,7 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife // and then apply the update function val updateFuncLocal = updateFunc val finalFunc = (iterator: Iterator[(K, Seq[V])]) => { - updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, null.asInstanceOf[S]))) + updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, None))) } val groupedRDD = parentRDD.groupByKey(partitioner) diff --git a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala index db95c2cfaa..290a216797 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala @@ -149,11 +149,8 @@ class DStreamBasicSuite extends DStreamSuiteBase { ) val updateStateOperation = (s: DStream[String]) => { - val updateFunc = (values: Seq[Int], state: RichInt) => { - var newState = 0 - if (values != null && values.size > 0) newState += values.reduce(_ + _) - if (state != null) newState += state.self - new RichInt(newState) + val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { + Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) } s.map(x => (x, 1)).updateStateByKey[RichInt](updateFunc).map(t => (t._1, t._2.self)) } -- cgit v1.2.3 From 0e5d9be4dfe0d072db8410fe6d254555bba9367d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 23 Oct 2012 15:17:05 -0700 Subject: Renamed APIs to create queueStream and fileStream. --- .../src/main/scala/spark/streaming/StreamingContext.scala | 12 ++++++------ .../src/main/scala/spark/streaming/examples/FileStream.scala | 2 +- .../spark/streaming/examples/FileStreamWithCheckpoint.scala | 2 +- .../main/scala/spark/streaming/examples/QueueStream.scala | 2 +- .../main/scala/spark/streaming/examples/WordCountHdfs.scala | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 90654cdad9..228f1a3616 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -131,7 +131,7 @@ class StreamingContext ( * This function creates a input stream that monitors a Hadoop-compatible * for new files and executes the necessary processing on them. */ - def createFileStream[ + def fileStream[ K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K, V]: ClassManifest @@ -141,15 +141,15 @@ class StreamingContext ( inputStream } - def createTextFileStream(directory: String): DStream[String] = { - createFileStream[LongWritable, Text, TextInputFormat](directory).map(_._2.toString) + def textFileStream(directory: String): DStream[String] = { + fileStream[LongWritable, Text, TextInputFormat](directory).map(_._2.toString) } /** * This function create a input stream from an queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue */ - def createQueueStream[T: ClassManifest]( + def queueStream[T: ClassManifest]( queue: Queue[RDD[T]], oneAtATime: Boolean = true, defaultRDD: RDD[T] = null @@ -159,9 +159,9 @@ class StreamingContext ( inputStream } - def createQueueStream[T: ClassManifest](array: Array[RDD[T]]): DStream[T] = { + def queueStream[T: ClassManifest](array: Array[RDD[T]]): DStream[T] = { val queue = new Queue[RDD[T]] - val inputStream = createQueueStream(queue, true, null) + val inputStream = queueStream(queue, true, null) queue ++= array inputStream } diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStream.scala b/streaming/src/main/scala/spark/streaming/examples/FileStream.scala index 301da56014..d68611abd6 100644 --- a/streaming/src/main/scala/spark/streaming/examples/FileStream.scala +++ b/streaming/src/main/scala/spark/streaming/examples/FileStream.scala @@ -27,7 +27,7 @@ object FileStream { // Create the FileInputDStream on the directory and use the // stream to count words in new files created - val inputStream = ssc.createTextFileStream(directory.toString) + val inputStream = ssc.textFileStream(directory.toString) val words = inputStream.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) wordCounts.print() diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala index c725035a8a..df96a811da 100644 --- a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala +++ b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala @@ -37,7 +37,7 @@ object FileStreamWithCheckpoint { ssc_.setCheckpointDetails(checkpointFile, Seconds(1)) // Setup the streaming computation - val inputStream = ssc_.createTextFileStream(directory.toString) + val inputStream = ssc_.textFileStream(directory.toString) val words = inputStream.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) wordCounts.print() diff --git a/streaming/src/main/scala/spark/streaming/examples/QueueStream.scala b/streaming/src/main/scala/spark/streaming/examples/QueueStream.scala index ae701bba6d..2af51bad28 100644 --- a/streaming/src/main/scala/spark/streaming/examples/QueueStream.scala +++ b/streaming/src/main/scala/spark/streaming/examples/QueueStream.scala @@ -24,7 +24,7 @@ object QueueStream { val rddQueue = new SynchronizedQueue[RDD[Int]]() // Create the QueueInputDStream and use it do some processing - val inputStream = ssc.createQueueStream(rddQueue) + val inputStream = ssc.queueStream(rddQueue) val mappedStream = inputStream.map(x => (x % 10, 1)) val reducedStream = mappedStream.reduceByKey(_ + _) reducedStream.print() diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala index 3b86948822..591cb141c3 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala @@ -16,7 +16,7 @@ object WordCountHdfs { // Create the FileInputDStream on the directory and use the // stream to count words in new files created - val lines = ssc.createTextFileStream(args(1)) + val lines = ssc.textFileStream(args(1)) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) wordCounts.print() -- cgit v1.2.3 From 020d6434844b22c2fe611303b338eaf53397c9db Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 23 Oct 2012 16:24:05 -0700 Subject: Renamed the streaming testsuites. --- .../spark/streaming/BasicOperationsSuite.scala | 213 ++++++++++++++++++++ .../scala/spark/streaming/CheckpointSuite.scala | 2 +- .../scala/spark/streaming/DStreamBasicSuite.scala | 211 -------------------- .../scala/spark/streaming/DStreamSuiteBase.scala | 216 --------------------- .../scala/spark/streaming/DStreamWindowSuite.scala | 188 ------------------ .../test/scala/spark/streaming/TestSuiteBase.scala | 216 +++++++++++++++++++++ .../spark/streaming/WindowOperationsSuite.scala | 188 ++++++++++++++++++ 7 files changed, 618 insertions(+), 616 deletions(-) create mode 100644 streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala delete mode 100644 streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala delete mode 100644 streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala delete mode 100644 streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala create mode 100644 streaming/src/test/scala/spark/streaming/TestSuiteBase.scala create mode 100644 streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala new file mode 100644 index 0000000000..d0aaac0f2e --- /dev/null +++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala @@ -0,0 +1,213 @@ +package spark.streaming + +import spark.streaming.StreamingContext._ +import scala.runtime.RichInt +import util.ManualClock + +class BasicOperationsSuite extends TestSuiteBase { + + override def framework() = "BasicOperationsSuite" + + test("map") { + val input = Seq(1 to 4, 5 to 8, 9 to 12) + testOperation( + input, + (r: DStream[Int]) => r.map(_.toString), + input.map(_.map(_.toString)) + ) + } + + test("flatmap") { + val input = Seq(1 to 4, 5 to 8, 9 to 12) + testOperation( + input, + (r: DStream[Int]) => r.flatMap(x => Seq(x, x * 2)), + input.map(_.flatMap(x => Array(x, x * 2))) + ) + } + + test("filter") { + val input = Seq(1 to 4, 5 to 8, 9 to 12) + testOperation( + input, + (r: DStream[Int]) => r.filter(x => (x % 2 == 0)), + input.map(_.filter(x => (x % 2 == 0))) + ) + } + + test("glom") { + assert(numInputPartitions === 2, "Number of input partitions has been changed from 2") + val input = Seq(1 to 4, 5 to 8, 9 to 12) + val output = Seq( + Seq( Seq(1, 2), Seq(3, 4) ), + Seq( Seq(5, 6), Seq(7, 8) ), + Seq( Seq(9, 10), Seq(11, 12) ) + ) + val operation = (r: DStream[Int]) => r.glom().map(_.toSeq) + testOperation(input, operation, output) + } + + test("mapPartitions") { + assert(numInputPartitions === 2, "Number of input partitions has been changed from 2") + val input = Seq(1 to 4, 5 to 8, 9 to 12) + val output = Seq(Seq(3, 7), Seq(11, 15), Seq(19, 23)) + val operation = (r: DStream[Int]) => r.mapPartitions(x => Iterator(x.reduce(_ + _))) + testOperation(input, operation, output, true) + } + + test("groupByKey") { + testOperation( + Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), + (s: DStream[String]) => s.map(x => (x, 1)).groupByKey(), + Seq( Seq(("a", Seq(1, 1)), ("b", Seq(1))), Seq(("", Seq(1, 1))), Seq() ), + true + ) + } + + test("reduceByKey") { + testOperation( + Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), + (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), + Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ), + true + ) + } + + test("reduce") { + testOperation( + Seq(1 to 4, 5 to 8, 9 to 12), + (s: DStream[Int]) => s.reduce(_ + _), + Seq(Seq(10), Seq(26), Seq(42)) + ) + } + + test("mapValues") { + testOperation( + Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), + (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).mapValues(_ + 10), + Seq( Seq(("a", 12), ("b", 11)), Seq(("", 12)), Seq() ), + true + ) + } + + test("flatMapValues") { + testOperation( + Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), + (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)), + Seq( Seq(("a", 2), ("a", 12), ("b", 1), ("b", 11)), Seq(("", 2), ("", 12)), Seq() ), + true + ) + } + + test("cogroup") { + val inputData1 = Seq( Seq("a", "a", "b"), Seq("a", ""), Seq(""), Seq() ) + val inputData2 = Seq( Seq("a", "a", "b"), Seq("b", ""), Seq(), Seq() ) + val outputData = Seq( + Seq( ("a", (Seq(1, 1), Seq("x", "x"))), ("b", (Seq(1), Seq("x"))) ), + Seq( ("a", (Seq(1), Seq())), ("b", (Seq(), Seq("x"))), ("", (Seq(1), Seq("x"))) ), + Seq( ("", (Seq(1), Seq())) ), + Seq( ) + ) + val operation = (s1: DStream[String], s2: DStream[String]) => { + s1.map(x => (x,1)).cogroup(s2.map(x => (x, "x"))) + } + testOperation(inputData1, inputData2, operation, outputData, true) + } + + test("join") { + val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() ) + val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") ) + val outputData = Seq( + Seq( ("a", (1, "x")), ("b", (1, "x")) ), + Seq( ("", (1, "x")) ), + Seq( ), + Seq( ) + ) + val operation = (s1: DStream[String], s2: DStream[String]) => { + s1.map(x => (x,1)).join(s2.map(x => (x,"x"))) + } + testOperation(inputData1, inputData2, operation, outputData, true) + } + + test("updateStateByKey") { + val inputData = + Seq( + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + val updateStateOperation = (s: DStream[String]) => { + val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { + Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) + } + s.map(x => (x, 1)).updateStateByKey[RichInt](updateFunc).map(t => (t._1, t._2.self)) + } + + testOperation(inputData, updateStateOperation, outputData, true) + } + + test("forgetting of RDDs - map and window operations") { + assert(batchDuration === Seconds(1), "Batch duration has changed from 1 second") + + val input = (0 until 10).map(x => Seq(x, x + 1)).toSeq + val rememberDuration = Seconds(3) + + assert(input.size === 10, "Number of inputs have changed") + + def operation(s: DStream[Int]): DStream[(Int, Int)] = { + s.map(x => (x % 10, 1)) + .window(Seconds(2), Seconds(1)) + .window(Seconds(4), Seconds(2)) + } + + val ssc = setupStreams(input, operation _) + ssc.setRememberDuration(rememberDuration) + runStreams[(Int, Int)](ssc, input.size, input.size / 2) + + val windowedStream2 = ssc.graph.getOutputStreams().head.dependencies.head + val windowedStream1 = windowedStream2.dependencies.head + val mappedStream = windowedStream1.dependencies.head + + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + assert(clock.time === Seconds(10).milliseconds) + + // IDEALLY + // WindowedStream2 should remember till 7 seconds: 10, 8, + // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5 + // MappedStream should remember till 7 seconds: 10, 9, 8, 7, 6, 5, 4, 3, + + // IN THIS TEST + // WindowedStream2 should remember till 7 seconds: 10, 8, + // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5, 4 + // MappedStream should remember till 7 seconds: 10, 9, 8, 7, 6, 5, 4, 3, 2 + + // WindowedStream2 + assert(windowedStream2.generatedRDDs.contains(Seconds(10))) + assert(windowedStream2.generatedRDDs.contains(Seconds(8))) + assert(!windowedStream2.generatedRDDs.contains(Seconds(6))) + + // WindowedStream1 + assert(windowedStream1.generatedRDDs.contains(Seconds(10))) + assert(windowedStream1.generatedRDDs.contains(Seconds(4))) + assert(!windowedStream1.generatedRDDs.contains(Seconds(3))) + + // MappedStream + assert(mappedStream.generatedRDDs.contains(Seconds(10))) + assert(mappedStream.generatedRDDs.contains(Seconds(2))) + assert(!mappedStream.generatedRDDs.contains(Seconds(1))) + } +} diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index 061b331a16..6dcedcf463 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -3,7 +3,7 @@ package spark.streaming import spark.streaming.StreamingContext._ import java.io.File -class CheckpointSuite extends DStreamSuiteBase { +class CheckpointSuite extends TestSuiteBase { override def framework() = "CheckpointSuite" diff --git a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala deleted file mode 100644 index 290a216797..0000000000 --- a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala +++ /dev/null @@ -1,211 +0,0 @@ -package spark.streaming - -import spark.streaming.StreamingContext._ -import scala.runtime.RichInt -import util.ManualClock - -class DStreamBasicSuite extends DStreamSuiteBase { - - test("map") { - val input = Seq(1 to 4, 5 to 8, 9 to 12) - testOperation( - input, - (r: DStream[Int]) => r.map(_.toString), - input.map(_.map(_.toString)) - ) - } - - test("flatmap") { - val input = Seq(1 to 4, 5 to 8, 9 to 12) - testOperation( - input, - (r: DStream[Int]) => r.flatMap(x => Seq(x, x * 2)), - input.map(_.flatMap(x => Array(x, x * 2))) - ) - } - - test("filter") { - val input = Seq(1 to 4, 5 to 8, 9 to 12) - testOperation( - input, - (r: DStream[Int]) => r.filter(x => (x % 2 == 0)), - input.map(_.filter(x => (x % 2 == 0))) - ) - } - - test("glom") { - assert(numInputPartitions === 2, "Number of input partitions has been changed from 2") - val input = Seq(1 to 4, 5 to 8, 9 to 12) - val output = Seq( - Seq( Seq(1, 2), Seq(3, 4) ), - Seq( Seq(5, 6), Seq(7, 8) ), - Seq( Seq(9, 10), Seq(11, 12) ) - ) - val operation = (r: DStream[Int]) => r.glom().map(_.toSeq) - testOperation(input, operation, output) - } - - test("mapPartitions") { - assert(numInputPartitions === 2, "Number of input partitions has been changed from 2") - val input = Seq(1 to 4, 5 to 8, 9 to 12) - val output = Seq(Seq(3, 7), Seq(11, 15), Seq(19, 23)) - val operation = (r: DStream[Int]) => r.mapPartitions(x => Iterator(x.reduce(_ + _))) - testOperation(input, operation, output, true) - } - - test("groupByKey") { - testOperation( - Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), - (s: DStream[String]) => s.map(x => (x, 1)).groupByKey(), - Seq( Seq(("a", Seq(1, 1)), ("b", Seq(1))), Seq(("", Seq(1, 1))), Seq() ), - true - ) - } - - test("reduceByKey") { - testOperation( - Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), - (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), - Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ), - true - ) - } - - test("reduce") { - testOperation( - Seq(1 to 4, 5 to 8, 9 to 12), - (s: DStream[Int]) => s.reduce(_ + _), - Seq(Seq(10), Seq(26), Seq(42)) - ) - } - - test("mapValues") { - testOperation( - Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), - (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).mapValues(_ + 10), - Seq( Seq(("a", 12), ("b", 11)), Seq(("", 12)), Seq() ), - true - ) - } - - test("flatMapValues") { - testOperation( - Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), - (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)), - Seq( Seq(("a", 2), ("a", 12), ("b", 1), ("b", 11)), Seq(("", 2), ("", 12)), Seq() ), - true - ) - } - - test("cogroup") { - val inputData1 = Seq( Seq("a", "a", "b"), Seq("a", ""), Seq(""), Seq() ) - val inputData2 = Seq( Seq("a", "a", "b"), Seq("b", ""), Seq(), Seq() ) - val outputData = Seq( - Seq( ("a", (Seq(1, 1), Seq("x", "x"))), ("b", (Seq(1), Seq("x"))) ), - Seq( ("a", (Seq(1), Seq())), ("b", (Seq(), Seq("x"))), ("", (Seq(1), Seq("x"))) ), - Seq( ("", (Seq(1), Seq())) ), - Seq( ) - ) - val operation = (s1: DStream[String], s2: DStream[String]) => { - s1.map(x => (x,1)).cogroup(s2.map(x => (x, "x"))) - } - testOperation(inputData1, inputData2, operation, outputData, true) - } - - test("join") { - val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() ) - val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") ) - val outputData = Seq( - Seq( ("a", (1, "x")), ("b", (1, "x")) ), - Seq( ("", (1, "x")) ), - Seq( ), - Seq( ) - ) - val operation = (s1: DStream[String], s2: DStream[String]) => { - s1.map(x => (x,1)).join(s2.map(x => (x,"x"))) - } - testOperation(inputData1, inputData2, operation, outputData, true) - } - - test("updateStateByKey") { - val inputData = - Seq( - Seq("a"), - Seq("a", "b"), - Seq("a", "b", "c"), - Seq("a", "b"), - Seq("a"), - Seq() - ) - - val outputData = - Seq( - Seq(("a", 1)), - Seq(("a", 2), ("b", 1)), - Seq(("a", 3), ("b", 2), ("c", 1)), - Seq(("a", 4), ("b", 3), ("c", 1)), - Seq(("a", 5), ("b", 3), ("c", 1)), - Seq(("a", 5), ("b", 3), ("c", 1)) - ) - - val updateStateOperation = (s: DStream[String]) => { - val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { - Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) - } - s.map(x => (x, 1)).updateStateByKey[RichInt](updateFunc).map(t => (t._1, t._2.self)) - } - - testOperation(inputData, updateStateOperation, outputData, true) - } - - test("forgetting of RDDs - map and window operations") { - assert(batchDuration === Seconds(1), "Batch duration has changed from 1 second") - - val input = (0 until 10).map(x => Seq(x, x + 1)).toSeq - val rememberDuration = Seconds(3) - - assert(input.size === 10, "Number of inputs have changed") - - def operation(s: DStream[Int]): DStream[(Int, Int)] = { - s.map(x => (x % 10, 1)) - .window(Seconds(2), Seconds(1)) - .window(Seconds(4), Seconds(2)) - } - - val ssc = setupStreams(input, operation _) - ssc.setRememberDuration(rememberDuration) - runStreams[(Int, Int)](ssc, input.size, input.size / 2) - - val windowedStream2 = ssc.graph.getOutputStreams().head.dependencies.head - val windowedStream1 = windowedStream2.dependencies.head - val mappedStream = windowedStream1.dependencies.head - - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - assert(clock.time === Seconds(10).milliseconds) - - // IDEALLY - // WindowedStream2 should remember till 7 seconds: 10, 8, - // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5 - // MappedStream should remember till 7 seconds: 10, 9, 8, 7, 6, 5, 4, 3, - - // IN THIS TEST - // WindowedStream2 should remember till 7 seconds: 10, 8, - // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5, 4 - // MappedStream should remember till 7 seconds: 10, 9, 8, 7, 6, 5, 4, 3, 2 - - // WindowedStream2 - assert(windowedStream2.generatedRDDs.contains(Seconds(10))) - assert(windowedStream2.generatedRDDs.contains(Seconds(8))) - assert(!windowedStream2.generatedRDDs.contains(Seconds(6))) - - // WindowedStream1 - assert(windowedStream1.generatedRDDs.contains(Seconds(10))) - assert(windowedStream1.generatedRDDs.contains(Seconds(4))) - assert(!windowedStream1.generatedRDDs.contains(Seconds(3))) - - // MappedStream - assert(mappedStream.generatedRDDs.contains(Seconds(10))) - assert(mappedStream.generatedRDDs.contains(Seconds(2))) - assert(!mappedStream.generatedRDDs.contains(Seconds(1))) - } -} diff --git a/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala b/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala deleted file mode 100644 index 2a4b37c965..0000000000 --- a/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala +++ /dev/null @@ -1,216 +0,0 @@ -package spark.streaming - -import spark.{RDD, Logging} -import util.ManualClock -import collection.mutable.ArrayBuffer -import org.scalatest.FunSuite -import collection.mutable.SynchronizedBuffer - -class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[T]], numPartitions: Int) - extends InputDStream[T](ssc_) { - var currentIndex = 0 - - def start() {} - - def stop() {} - - def compute(validTime: Time): Option[RDD[T]] = { - logInfo("Computing RDD for time " + validTime) - val rdd = if (currentIndex < input.size) { - ssc.sc.makeRDD(input(currentIndex), numPartitions) - } else { - ssc.sc.makeRDD(Seq[T](), numPartitions) - } - logInfo("Created RDD " + rdd.id) - currentIndex += 1 - Some(rdd) - } -} - -class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]]) - extends PerRDDForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { - val collected = rdd.collect() - output += collected - }) - -trait DStreamSuiteBase extends FunSuite with Logging { - - System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") - - def framework() = "DStreamSuiteBase" - - def master() = "local[2]" - - def batchDuration() = Seconds(1) - - def checkpointFile() = null.asInstanceOf[String] - - def checkpointInterval() = batchDuration - - def numInputPartitions() = 2 - - def maxWaitTimeMillis() = 10000 - - def setupStreams[U: ClassManifest, V: ClassManifest]( - input: Seq[Seq[U]], - operation: DStream[U] => DStream[V] - ): StreamingContext = { - - // Create StreamingContext - val ssc = new StreamingContext(master, framework) - ssc.setBatchDuration(batchDuration) - if (checkpointFile != null) { - ssc.setCheckpointDetails(checkpointFile, checkpointInterval()) - } - - // Setup the stream computation - val inputStream = new TestInputStream(ssc, input, numInputPartitions) - val operatedStream = operation(inputStream) - val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[V]] with SynchronizedBuffer[Seq[V]]) - ssc.registerInputStream(inputStream) - ssc.registerOutputStream(outputStream) - ssc - } - - def setupStreams[U: ClassManifest, V: ClassManifest, W: ClassManifest]( - input1: Seq[Seq[U]], - input2: Seq[Seq[V]], - operation: (DStream[U], DStream[V]) => DStream[W] - ): StreamingContext = { - - // Create StreamingContext - val ssc = new StreamingContext(master, framework) - ssc.setBatchDuration(batchDuration) - if (checkpointFile != null) { - ssc.setCheckpointDetails(checkpointFile, checkpointInterval()) - } - - // Setup the stream computation - val inputStream1 = new TestInputStream(ssc, input1, numInputPartitions) - val inputStream2 = new TestInputStream(ssc, input2, numInputPartitions) - val operatedStream = operation(inputStream1, inputStream2) - val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[W]] with SynchronizedBuffer[Seq[W]]) - ssc.registerInputStream(inputStream1) - ssc.registerInputStream(inputStream2) - ssc.registerOutputStream(outputStream) - ssc - } - - - def runStreams[V: ClassManifest]( - ssc: StreamingContext, - numBatches: Int, - numExpectedOutput: Int - ): Seq[Seq[V]] = { - - assert(numBatches > 0, "Number of batches to run stream computation is zero") - assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero") - logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput) - - // Get the output buffer - val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]] - val output = outputStream.output - - try { - // Start computation - ssc.start() - - // Advance manual clock - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - logInfo("Manual clock before advancing = " + clock.time) - clock.addToTime(numBatches * batchDuration.milliseconds) - logInfo("Manual clock after advancing = " + clock.time) - - // Wait until expected number of output items have been generated - val startTime = System.currentTimeMillis() - while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { - logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput) - Thread.sleep(100) - } - val timeTaken = System.currentTimeMillis() - startTime - - assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") - assert(output.size === numExpectedOutput, "Unexpected number of outputs generated") - - Thread.sleep(500) // Give some time for the forgetting old RDDs to complete - } catch { - case e: Exception => e.printStackTrace(); throw e; - } finally { - ssc.stop() - } - - output - } - - def verifyOutput[V: ClassManifest]( - output: Seq[Seq[V]], - expectedOutput: Seq[Seq[V]], - useSet: Boolean - ) { - logInfo("--------------------------------") - logInfo("output.size = " + output.size) - logInfo("output") - output.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output.size = " + expectedOutput.size) - logInfo("expected output") - expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("--------------------------------") - - // Match the output with the expected output - assert(output.size === expectedOutput.size, "Number of outputs do not match") - for (i <- 0 until output.size) { - if (useSet) { - assert(output(i).toSet === expectedOutput(i).toSet) - } else { - assert(output(i).toList === expectedOutput(i).toList) - } - } - logInfo("Output verified successfully") - } - - def testOperation[U: ClassManifest, V: ClassManifest]( - input: Seq[Seq[U]], - operation: DStream[U] => DStream[V], - expectedOutput: Seq[Seq[V]], - useSet: Boolean = false - ) { - testOperation[U, V](input, operation, expectedOutput, -1, useSet) - } - - def testOperation[U: ClassManifest, V: ClassManifest]( - input: Seq[Seq[U]], - operation: DStream[U] => DStream[V], - expectedOutput: Seq[Seq[V]], - numBatches: Int, - useSet: Boolean - ) { - val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size - val ssc = setupStreams[U, V](input, operation) - val output = runStreams[V](ssc, numBatches_, expectedOutput.size) - verifyOutput[V](output, expectedOutput, useSet) - } - - def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest]( - input1: Seq[Seq[U]], - input2: Seq[Seq[V]], - operation: (DStream[U], DStream[V]) => DStream[W], - expectedOutput: Seq[Seq[W]], - useSet: Boolean - ) { - testOperation[U, V, W](input1, input2, operation, expectedOutput, -1, useSet) - } - - def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest]( - input1: Seq[Seq[U]], - input2: Seq[Seq[V]], - operation: (DStream[U], DStream[V]) => DStream[W], - expectedOutput: Seq[Seq[W]], - numBatches: Int, - useSet: Boolean - ) { - val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size - val ssc = setupStreams[U, V, W](input1, input2, operation) - val output = runStreams[W](ssc, numBatches_, expectedOutput.size) - verifyOutput[W](output, expectedOutput, useSet) - } -} diff --git a/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala deleted file mode 100644 index cfcab6298d..0000000000 --- a/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala +++ /dev/null @@ -1,188 +0,0 @@ -package spark.streaming - -import spark.streaming.StreamingContext._ - -class DStreamWindowSuite extends DStreamSuiteBase { - - override def framework() = "DStreamWindowSuite" - - override def maxWaitTimeMillis() = 20000 - - val largerSlideInput = Seq( - Seq(("a", 1)), - Seq(("a", 2)), // 1st window from here - Seq(("a", 3)), - Seq(("a", 4)), // 2nd window from here - Seq(("a", 5)), - Seq(("a", 6)), // 3rd window from here - Seq(), - Seq() // 4th window from here - ) - - val largerSlideOutput = Seq( - Seq(("a", 3)), - Seq(("a", 10)), - Seq(("a", 18)), - Seq(("a", 11)) - ) - - - val bigInput = Seq( - Seq(("a", 1)), - Seq(("a", 1), ("b", 1)), - Seq(("a", 1), ("b", 1), ("c", 1)), - Seq(("a", 1), ("b", 1)), - Seq(("a", 1)), - Seq(), - Seq(("a", 1)), - Seq(("a", 1), ("b", 1)), - Seq(("a", 1), ("b", 1), ("c", 1)), - Seq(("a", 1), ("b", 1)), - Seq(("a", 1)), - Seq() - ) - - val bigOutput = Seq( - Seq(("a", 1)), - Seq(("a", 2), ("b", 1)), - Seq(("a", 2), ("b", 2), ("c", 1)), - Seq(("a", 2), ("b", 2), ("c", 1)), - Seq(("a", 2), ("b", 1)), - Seq(("a", 1)), - Seq(("a", 1)), - Seq(("a", 2), ("b", 1)), - Seq(("a", 2), ("b", 2), ("c", 1)), - Seq(("a", 2), ("b", 2), ("c", 1)), - Seq(("a", 2), ("b", 1)), - Seq(("a", 1)) - ) - - /* - The output of the reduceByKeyAndWindow with inverse reduce function is - difference from the naive reduceByKeyAndWindow. Even if the count of a - particular key is 0, the key does not get eliminated from the RDDs of - ReducedWindowedDStream. This causes the number of keys in these RDDs to - increase forever. A more generalized version that allows elimination of - keys should be considered. - */ - val bigOutputInv = Seq( - Seq(("a", 1)), - Seq(("a", 2), ("b", 1)), - Seq(("a", 2), ("b", 2), ("c", 1)), - Seq(("a", 2), ("b", 2), ("c", 1)), - Seq(("a", 2), ("b", 1), ("c", 0)), - Seq(("a", 1), ("b", 0), ("c", 0)), - Seq(("a", 1), ("b", 0), ("c", 0)), - Seq(("a", 2), ("b", 1), ("c", 0)), - Seq(("a", 2), ("b", 2), ("c", 1)), - Seq(("a", 2), ("b", 2), ("c", 1)), - Seq(("a", 2), ("b", 1), ("c", 0)), - Seq(("a", 1), ("b", 0), ("c", 0)) - ) - - def testReduceByKeyAndWindow( - name: String, - input: Seq[Seq[(String, Int)]], - expectedOutput: Seq[Seq[(String, Int)]], - windowTime: Time = batchDuration * 2, - slideTime: Time = batchDuration - ) { - test("reduceByKeyAndWindow - " + name) { - val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt - val operation = (s: DStream[(String, Int)]) => { - s.reduceByKeyAndWindow(_ + _, windowTime, slideTime).persist() - } - testOperation(input, operation, expectedOutput, numBatches, true) - } - } - - def testReduceByKeyAndWindowInv( - name: String, - input: Seq[Seq[(String, Int)]], - expectedOutput: Seq[Seq[(String, Int)]], - windowTime: Time = batchDuration * 2, - slideTime: Time = batchDuration - ) { - test("reduceByKeyAndWindowInv - " + name) { - val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt - val operation = (s: DStream[(String, Int)]) => { - s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist() - } - testOperation(input, operation, expectedOutput, numBatches, true) - } - } - - - // Testing naive reduceByKeyAndWindow (without invertible function) - - testReduceByKeyAndWindow( - "basic reduction", - Seq( Seq(("a", 1), ("a", 3)) ), - Seq( Seq(("a", 4)) ) - ) - - testReduceByKeyAndWindow( - "key already in window and new value added into window", - Seq( Seq(("a", 1)), Seq(("a", 1)) ), - Seq( Seq(("a", 1)), Seq(("a", 2)) ) - ) - - testReduceByKeyAndWindow( - "new key added into window", - Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ), - Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)) ) - ) - - testReduceByKeyAndWindow( - "key removed from window", - Seq( Seq(("a", 1)), Seq(("a", 1)), Seq(), Seq() ), - Seq( Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 1)), Seq() ) - ) - - testReduceByKeyAndWindow( - "larger slide time", - largerSlideInput, - largerSlideOutput, - Seconds(4), - Seconds(2) - ) - - testReduceByKeyAndWindow("big test", bigInput, bigOutput) - - - // Testing reduceByKeyAndWindow (with invertible reduce function) - - testReduceByKeyAndWindowInv( - "basic reduction", - Seq(Seq(("a", 1), ("a", 3)) ), - Seq(Seq(("a", 4)) ) - ) - - testReduceByKeyAndWindowInv( - "key already in window and new value added into window", - Seq( Seq(("a", 1)), Seq(("a", 1)) ), - Seq( Seq(("a", 1)), Seq(("a", 2)) ) - ) - - testReduceByKeyAndWindowInv( - "new key added into window", - Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ), - Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)) ) - ) - - testReduceByKeyAndWindowInv( - "key removed from window", - Seq( Seq(("a", 1)), Seq(("a", 1)), Seq(), Seq() ), - Seq( Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 1)), Seq(("a", 0)) ) - ) - - testReduceByKeyAndWindowInv( - "larger slide time", - largerSlideInput, - largerSlideOutput, - Seconds(4), - Seconds(2) - ) - - testReduceByKeyAndWindowInv("big test", bigInput, bigOutputInv) -} diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala new file mode 100644 index 0000000000..c1b7772e7b --- /dev/null +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -0,0 +1,216 @@ +package spark.streaming + +import spark.{RDD, Logging} +import util.ManualClock +import collection.mutable.ArrayBuffer +import org.scalatest.FunSuite +import collection.mutable.SynchronizedBuffer + +class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[T]], numPartitions: Int) + extends InputDStream[T](ssc_) { + var currentIndex = 0 + + def start() {} + + def stop() {} + + def compute(validTime: Time): Option[RDD[T]] = { + logInfo("Computing RDD for time " + validTime) + val rdd = if (currentIndex < input.size) { + ssc.sc.makeRDD(input(currentIndex), numPartitions) + } else { + ssc.sc.makeRDD(Seq[T](), numPartitions) + } + logInfo("Created RDD " + rdd.id) + currentIndex += 1 + Some(rdd) + } +} + +class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]]) + extends PerRDDForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { + val collected = rdd.collect() + output += collected + }) + +trait TestSuiteBase extends FunSuite with Logging { + + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + + def framework() = "TestSuiteBase" + + def master() = "local[2]" + + def batchDuration() = Seconds(1) + + def checkpointFile() = null.asInstanceOf[String] + + def checkpointInterval() = batchDuration + + def numInputPartitions() = 2 + + def maxWaitTimeMillis() = 10000 + + def setupStreams[U: ClassManifest, V: ClassManifest]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V] + ): StreamingContext = { + + // Create StreamingContext + val ssc = new StreamingContext(master, framework) + ssc.setBatchDuration(batchDuration) + if (checkpointFile != null) { + ssc.setCheckpointDetails(checkpointFile, checkpointInterval()) + } + + // Setup the stream computation + val inputStream = new TestInputStream(ssc, input, numInputPartitions) + val operatedStream = operation(inputStream) + val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[V]] with SynchronizedBuffer[Seq[V]]) + ssc.registerInputStream(inputStream) + ssc.registerOutputStream(outputStream) + ssc + } + + def setupStreams[U: ClassManifest, V: ClassManifest, W: ClassManifest]( + input1: Seq[Seq[U]], + input2: Seq[Seq[V]], + operation: (DStream[U], DStream[V]) => DStream[W] + ): StreamingContext = { + + // Create StreamingContext + val ssc = new StreamingContext(master, framework) + ssc.setBatchDuration(batchDuration) + if (checkpointFile != null) { + ssc.setCheckpointDetails(checkpointFile, checkpointInterval()) + } + + // Setup the stream computation + val inputStream1 = new TestInputStream(ssc, input1, numInputPartitions) + val inputStream2 = new TestInputStream(ssc, input2, numInputPartitions) + val operatedStream = operation(inputStream1, inputStream2) + val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[W]] with SynchronizedBuffer[Seq[W]]) + ssc.registerInputStream(inputStream1) + ssc.registerInputStream(inputStream2) + ssc.registerOutputStream(outputStream) + ssc + } + + + def runStreams[V: ClassManifest]( + ssc: StreamingContext, + numBatches: Int, + numExpectedOutput: Int + ): Seq[Seq[V]] = { + + assert(numBatches > 0, "Number of batches to run stream computation is zero") + assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero") + logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput) + + // Get the output buffer + val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]] + val output = outputStream.output + + try { + // Start computation + ssc.start() + + // Advance manual clock + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + logInfo("Manual clock before advancing = " + clock.time) + clock.addToTime(numBatches * batchDuration.milliseconds) + logInfo("Manual clock after advancing = " + clock.time) + + // Wait until expected number of output items have been generated + val startTime = System.currentTimeMillis() + while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput) + Thread.sleep(100) + } + val timeTaken = System.currentTimeMillis() - startTime + + assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") + assert(output.size === numExpectedOutput, "Unexpected number of outputs generated") + + Thread.sleep(500) // Give some time for the forgetting old RDDs to complete + } catch { + case e: Exception => e.printStackTrace(); throw e; + } finally { + ssc.stop() + } + + output + } + + def verifyOutput[V: ClassManifest]( + output: Seq[Seq[V]], + expectedOutput: Seq[Seq[V]], + useSet: Boolean + ) { + logInfo("--------------------------------") + logInfo("output.size = " + output.size) + logInfo("output") + output.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("expected output.size = " + expectedOutput.size) + logInfo("expected output") + expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("--------------------------------") + + // Match the output with the expected output + assert(output.size === expectedOutput.size, "Number of outputs do not match") + for (i <- 0 until output.size) { + if (useSet) { + assert(output(i).toSet === expectedOutput(i).toSet) + } else { + assert(output(i).toList === expectedOutput(i).toList) + } + } + logInfo("Output verified successfully") + } + + def testOperation[U: ClassManifest, V: ClassManifest]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]], + useSet: Boolean = false + ) { + testOperation[U, V](input, operation, expectedOutput, -1, useSet) + } + + def testOperation[U: ClassManifest, V: ClassManifest]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]], + numBatches: Int, + useSet: Boolean + ) { + val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size + val ssc = setupStreams[U, V](input, operation) + val output = runStreams[V](ssc, numBatches_, expectedOutput.size) + verifyOutput[V](output, expectedOutput, useSet) + } + + def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest]( + input1: Seq[Seq[U]], + input2: Seq[Seq[V]], + operation: (DStream[U], DStream[V]) => DStream[W], + expectedOutput: Seq[Seq[W]], + useSet: Boolean + ) { + testOperation[U, V, W](input1, input2, operation, expectedOutput, -1, useSet) + } + + def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest]( + input1: Seq[Seq[U]], + input2: Seq[Seq[V]], + operation: (DStream[U], DStream[V]) => DStream[W], + expectedOutput: Seq[Seq[W]], + numBatches: Int, + useSet: Boolean + ) { + val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size + val ssc = setupStreams[U, V, W](input1, input2, operation) + val output = runStreams[W](ssc, numBatches_, expectedOutput.size) + verifyOutput[W](output, expectedOutput, useSet) + } +} diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala new file mode 100644 index 0000000000..90d67844bb --- /dev/null +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -0,0 +1,188 @@ +package spark.streaming + +import spark.streaming.StreamingContext._ + +class WindowOperationsSuite extends TestSuiteBase { + + override def framework() = "WindowOperationsSuite" + + override def maxWaitTimeMillis() = 20000 + + val largerSlideInput = Seq( + Seq(("a", 1)), + Seq(("a", 2)), // 1st window from here + Seq(("a", 3)), + Seq(("a", 4)), // 2nd window from here + Seq(("a", 5)), + Seq(("a", 6)), // 3rd window from here + Seq(), + Seq() // 4th window from here + ) + + val largerSlideOutput = Seq( + Seq(("a", 3)), + Seq(("a", 10)), + Seq(("a", 18)), + Seq(("a", 11)) + ) + + + val bigInput = Seq( + Seq(("a", 1)), + Seq(("a", 1), ("b", 1)), + Seq(("a", 1), ("b", 1), ("c", 1)), + Seq(("a", 1), ("b", 1)), + Seq(("a", 1)), + Seq(), + Seq(("a", 1)), + Seq(("a", 1), ("b", 1)), + Seq(("a", 1), ("b", 1), ("c", 1)), + Seq(("a", 1), ("b", 1)), + Seq(("a", 1)), + Seq() + ) + + val bigOutput = Seq( + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 1)), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 1)) + ) + + /* + The output of the reduceByKeyAndWindow with inverse reduce function is + difference from the naive reduceByKeyAndWindow. Even if the count of a + particular key is 0, the key does not get eliminated from the RDDs of + ReducedWindowedDStream. This causes the number of keys in these RDDs to + increase forever. A more generalized version that allows elimination of + keys should be considered. + */ + val bigOutputInv = Seq( + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 1), ("c", 0)), + Seq(("a", 1), ("b", 0), ("c", 0)), + Seq(("a", 1), ("b", 0), ("c", 0)), + Seq(("a", 2), ("b", 1), ("c", 0)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 2), ("c", 1)), + Seq(("a", 2), ("b", 1), ("c", 0)), + Seq(("a", 1), ("b", 0), ("c", 0)) + ) + + def testReduceByKeyAndWindow( + name: String, + input: Seq[Seq[(String, Int)]], + expectedOutput: Seq[Seq[(String, Int)]], + windowTime: Time = batchDuration * 2, + slideTime: Time = batchDuration + ) { + test("reduceByKeyAndWindow - " + name) { + val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val operation = (s: DStream[(String, Int)]) => { + s.reduceByKeyAndWindow(_ + _, windowTime, slideTime).persist() + } + testOperation(input, operation, expectedOutput, numBatches, true) + } + } + + def testReduceByKeyAndWindowInv( + name: String, + input: Seq[Seq[(String, Int)]], + expectedOutput: Seq[Seq[(String, Int)]], + windowTime: Time = batchDuration * 2, + slideTime: Time = batchDuration + ) { + test("reduceByKeyAndWindowInv - " + name) { + val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val operation = (s: DStream[(String, Int)]) => { + s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist() + } + testOperation(input, operation, expectedOutput, numBatches, true) + } + } + + + // Testing naive reduceByKeyAndWindow (without invertible function) + + testReduceByKeyAndWindow( + "basic reduction", + Seq( Seq(("a", 1), ("a", 3)) ), + Seq( Seq(("a", 4)) ) + ) + + testReduceByKeyAndWindow( + "key already in window and new value added into window", + Seq( Seq(("a", 1)), Seq(("a", 1)) ), + Seq( Seq(("a", 1)), Seq(("a", 2)) ) + ) + + testReduceByKeyAndWindow( + "new key added into window", + Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ), + Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)) ) + ) + + testReduceByKeyAndWindow( + "key removed from window", + Seq( Seq(("a", 1)), Seq(("a", 1)), Seq(), Seq() ), + Seq( Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 1)), Seq() ) + ) + + testReduceByKeyAndWindow( + "larger slide time", + largerSlideInput, + largerSlideOutput, + Seconds(4), + Seconds(2) + ) + + testReduceByKeyAndWindow("big test", bigInput, bigOutput) + + + // Testing reduceByKeyAndWindow (with invertible reduce function) + + testReduceByKeyAndWindowInv( + "basic reduction", + Seq(Seq(("a", 1), ("a", 3)) ), + Seq(Seq(("a", 4)) ) + ) + + testReduceByKeyAndWindowInv( + "key already in window and new value added into window", + Seq( Seq(("a", 1)), Seq(("a", 1)) ), + Seq( Seq(("a", 1)), Seq(("a", 2)) ) + ) + + testReduceByKeyAndWindowInv( + "new key added into window", + Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ), + Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)) ) + ) + + testReduceByKeyAndWindowInv( + "key removed from window", + Seq( Seq(("a", 1)), Seq(("a", 1)), Seq(), Seq() ), + Seq( Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 1)), Seq(("a", 0)) ) + ) + + testReduceByKeyAndWindowInv( + "larger slide time", + largerSlideInput, + largerSlideOutput, + Seconds(4), + Seconds(2) + ) + + testReduceByKeyAndWindowInv("big test", bigInput, bigOutputInv) +} -- cgit v1.2.3 From 1ef6ea25135fd33a7913944628b67f24c87db1f5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 24 Oct 2012 14:44:20 -0700 Subject: Added tests for testing network input stream. --- .../src/main/scala/spark/streaming/DStream.scala | 5 +- .../spark/streaming/NetworkInputTracker.scala | 9 +- .../scala/spark/streaming/SocketInputDStream.scala | 17 +++- .../scala/spark/streaming/StreamingContext.scala | 14 +-- .../scala/spark/streaming/InputStreamsSuite.scala | 112 +++++++++++++++++++++ 5 files changed, 134 insertions(+), 23 deletions(-) create mode 100644 streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 38bb7c8b94..4bc063719c 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -93,8 +93,9 @@ extends Serializable with Logging { * its parent DStreams. */ protected[streaming] def initialize(time: Time) { - if (zeroTime != null) { - throw new Exception("ZeroTime is already initialized, cannot initialize it again") + if (zeroTime != null && zeroTime != time) { + throw new Exception("ZeroTime is already initialized to " + zeroTime + + ", cannot initialize it again to " + time) } zeroTime = time dependencies.foreach(_.initialize(zeroTime)) diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index 9b1b8813de..07ef79415d 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -108,10 +108,11 @@ class NetworkInputTracker( } def stopReceivers() { - implicit val ec = env.actorSystem.dispatcher - val listOfFutures = receiverInfo.values.map(_.ask(StopReceiver)(timeout)).toList - val futureOfList = Future.sequence(listOfFutures) - Await.result(futureOfList, timeout) + //implicit val ec = env.actorSystem.dispatcher + receiverInfo.values.foreach(_ ! StopReceiver) + //val listOfFutures = receiverInfo.values.map(_.ask(StopReceiver)(timeout)).toList + //val futureOfList = Future.sequence(listOfFutures) + //Await.result(futureOfList, timeout) } } } diff --git a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala index 4dbf421687..8ff7865ca4 100644 --- a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala @@ -18,12 +18,12 @@ class SocketInputDStream[T: ClassManifest]( ) extends NetworkInputDStream[T](ssc_) { def createReceiver(): NetworkReceiver[T] = { - new ObjectInputReceiver(id, host, port, bytesToObjects, storageLevel) + new SocketReceiver(id, host, port, bytesToObjects, storageLevel) } } -class ObjectInputReceiver[T: ClassManifest]( +class SocketReceiver[T: ClassManifest]( streamId: Int, host: String, port: Int, @@ -120,7 +120,12 @@ class ObjectInputReceiver[T: ClassManifest]( } -object ObjectInputReceiver { +object SocketReceiver { + + /** + * This methods translates the data from an inputstream (say, from a socket) + * to '\n' delimited strings and returns an iterator to access the strings. + */ def bytesToLines(inputStream: InputStream): Iterator[String] = { val bufferedInputStream = new BufferedInputStream(inputStream) val dataInputStream = new DataInputStream(bufferedInputStream) @@ -133,7 +138,11 @@ object ObjectInputReceiver { private def getNext() { try { nextValue = dataInputStream.readLine() - println("[" + nextValue + "]") + if (nextValue != null) { + println("[" + nextValue + "]") + } else { + gotNext = false + } } catch { case eof: EOFException => finished = true diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 228f1a3616..e124b8cfa0 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -91,7 +91,7 @@ class StreamingContext ( port: Int, storageLevel: StorageLevel = StorageLevel.DISK_AND_MEMORY_2 ): DStream[String] = { - networkStream[String](hostname, port, ObjectInputReceiver.bytesToLines, storageLevel) + networkStream[String](hostname, port, SocketReceiver.bytesToLines, storageLevel) } def networkStream[T: ClassManifest]( @@ -115,18 +115,6 @@ class StreamingContext ( inputStream } - /* - def createHttpTextStream(url: String): DStream[String] = { - createHttpStream(url, ObjectInputReceiver.bytesToLines) - } - - def createHttpStream[T: ClassManifest]( - url: String, - converter: (InputStream) => Iterator[T] - ): DStream[T] = { - } - */ - /** * This function creates a input stream that monitors a Hadoop-compatible * for new files and executes the necessary processing on them. diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala new file mode 100644 index 0000000000..dd872059ea --- /dev/null +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -0,0 +1,112 @@ +package spark.streaming + +import java.net.{SocketException, Socket, ServerSocket} +import java.io.{BufferedWriter, OutputStreamWriter} +import java.util.concurrent.{TimeUnit, ArrayBlockingQueue} +import collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import util.ManualClock +import spark.storage.StorageLevel + + +class InputStreamsSuite extends TestSuiteBase { + + test("network input stream") { + val serverPort = 9999 + val server = new TestServer(9999) + server.start() + val ssc = new StreamingContext(master, framework) + ssc.setBatchDuration(batchDuration) + + val networkStream = ssc.networkTextStream("localhost", serverPort, StorageLevel.DISK_AND_MEMORY) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String ]] + val outputStream = new TestOutputStream(networkStream, outputBuffer) + ssc.registerOutputStream(outputStream) + ssc.start() + + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val input = Seq(1, 2, 3) + val expectedOutput = input.map(_.toString) + for (i <- 0 until input.size) { + server.send(input(i).toString + "\n") + Thread.sleep(1000) + clock.addToTime(1000) + } + val startTime = System.currentTimeMillis() + while (outputBuffer.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + logInfo("output.size = " + outputBuffer.size + ", expectedOutput.size = " + expectedOutput.size) + Thread.sleep(100) + } + Thread.sleep(5000) + val timeTaken = System.currentTimeMillis() - startTime + assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") + + ssc.stop() + server.stop() + + assert(outputBuffer.size === expectedOutput.size) + for (i <- 0 until outputBuffer.size) { + assert(outputBuffer(i).size === 1) + assert(outputBuffer(i).head === expectedOutput(i)) + } + } +} + + +class TestServer(port: Int) { + + val queue = new ArrayBlockingQueue[String](100) + + val serverSocket = new ServerSocket(port) + + val servingThread = new Thread() { + override def run() { + try { + while(true) { + println("Accepting connections on port " + port) + val clientSocket = serverSocket.accept() + println("New connection") + try { + clientSocket.setTcpNoDelay(true) + val outputStream = new BufferedWriter(new OutputStreamWriter(clientSocket.getOutputStream)) + + while(clientSocket.isConnected) { + val msg = queue.poll(100, TimeUnit.MILLISECONDS) + if (msg != null) { + outputStream.write(msg) + outputStream.flush() + println("Message '" + msg + "' sent") + } + } + } catch { + case e: SocketException => println(e) + } finally { + println("Connection closed") + if (!clientSocket.isClosed) clientSocket.close() + } + } + } catch { + case ie: InterruptedException => + + } finally { + serverSocket.close() + } + } + } + + def start() { servingThread.start() } + + def send(msg: String) { queue.add(msg) } + + def stop() { servingThread.interrupt() } +} + +object TestServer { + def main(args: Array[String]) { + val s = new TestServer(9999) + s.start() + while(true) { + Thread.sleep(1000) + s.send("hello") + } + } +} -- cgit v1.2.3 From ed71df46cddc9a4f1363b937c10bfa2a928e564c Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 24 Oct 2012 16:49:40 -0700 Subject: Minor fixes. --- .../main/scala/spark/streaming/JobManager.scala | 2 +- .../spark/streaming/NetworkInputDStream.scala | 3 +- .../scala/spark/streaming/SocketInputDStream.scala | 33 ++++++++++------------ .../scala/spark/streaming/InputStreamsSuite.scala | 16 ++++++----- 4 files changed, 27 insertions(+), 27 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index 230d806a89..9bf9251519 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -12,7 +12,7 @@ class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { SparkEnv.set(ssc.env) try { val timeTaken = job.run() - println("Total delay: %.5f s for job %s (execution: %.5f s)".format( + logInfo("Total delay: %.5f s for job %s (execution: %.5f s)".format( (System.currentTimeMillis() - job.time) / 1000.0, job.id, timeTaken / 1000.0)) } catch { case e: Exception => diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala index 6b41e4d2c8..5669d7fedf 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala @@ -74,7 +74,8 @@ abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializ onStart() } catch { case ie: InterruptedException => - logWarning("Receiving thread interrupted") + logInfo("Receiving thread interrupted") + //println("Receiving thread interrupted") case e: Exception => stopOnError(e) } diff --git a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala index 8ff7865ca4..b566200273 100644 --- a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala @@ -3,11 +3,12 @@ package spark.streaming import spark.streaming.util.{RecurringTimer, SystemClock} import spark.storage.StorageLevel -import java.io.{EOFException, DataInputStream, BufferedInputStream, InputStream} +import java.io._ import java.net.Socket import java.util.concurrent.ArrayBlockingQueue import scala.collection.mutable.ArrayBuffer +import scala.Serializable class SocketInputDStream[T: ClassManifest]( @transient ssc_ : StreamingContext, @@ -127,8 +128,7 @@ object SocketReceiver { * to '\n' delimited strings and returns an iterator to access the strings. */ def bytesToLines(inputStream: InputStream): Iterator[String] = { - val bufferedInputStream = new BufferedInputStream(inputStream) - val dataInputStream = new DataInputStream(bufferedInputStream) + val dataInputStream = new BufferedReader(new InputStreamReader(inputStream, "UTF-8")) val iterator = new Iterator[String] { var gotNext = false @@ -138,35 +138,32 @@ object SocketReceiver { private def getNext() { try { nextValue = dataInputStream.readLine() - if (nextValue != null) { - println("[" + nextValue + "]") - } else { - gotNext = false - } - } catch { - case eof: EOFException => + if (nextValue == null) { finished = true + } } gotNext = true } override def hasNext: Boolean = { - if (!gotNext) { - getNext() - } - if (finished) { - dataInputStream.close() + if (!finished) { + if (!gotNext) { + getNext() + if (finished) { + dataInputStream.close() + } + } } !finished } override def next(): String = { - if (!gotNext) { - getNext() - } if (finished) { throw new NoSuchElementException("End of stream") } + if (!gotNext) { + getNext() + } gotNext = false nextValue } diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index dd872059ea..a3f213ebd0 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -6,6 +6,7 @@ import java.util.concurrent.{TimeUnit, ArrayBlockingQueue} import collection.mutable.{SynchronizedBuffer, ArrayBuffer} import util.ManualClock import spark.storage.StorageLevel +import spark.Logging class InputStreamsSuite extends TestSuiteBase { @@ -39,9 +40,10 @@ class InputStreamsSuite extends TestSuiteBase { Thread.sleep(5000) val timeTaken = System.currentTimeMillis() - startTime assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") - - ssc.stop() + logInfo("Stopping server") server.stop() + logInfo("Stopping context") + ssc.stop() assert(outputBuffer.size === expectedOutput.size) for (i <- 0 until outputBuffer.size) { @@ -52,7 +54,7 @@ class InputStreamsSuite extends TestSuiteBase { } -class TestServer(port: Int) { +class TestServer(port: Int) extends Logging { val queue = new ArrayBlockingQueue[String](100) @@ -62,9 +64,9 @@ class TestServer(port: Int) { override def run() { try { while(true) { - println("Accepting connections on port " + port) + logInfo("Accepting connections on port " + port) val clientSocket = serverSocket.accept() - println("New connection") + logInfo("New connection") try { clientSocket.setTcpNoDelay(true) val outputStream = new BufferedWriter(new OutputStreamWriter(clientSocket.getOutputStream)) @@ -74,13 +76,13 @@ class TestServer(port: Int) { if (msg != null) { outputStream.write(msg) outputStream.flush() - println("Message '" + msg + "' sent") + logInfo("Message '" + msg + "' sent") } } } catch { case e: SocketException => println(e) } finally { - println("Connection closed") + logInfo("Connection closed") if (!clientSocket.isClosed) clientSocket.close() } } -- cgit v1.2.3 From 926e05b0300ad2850d48e5692d73c209c1c90100 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 24 Oct 2012 23:14:37 -0700 Subject: Added tests for the file input stream. --- .../scala/spark/streaming/FileInputDStream.scala | 4 +- .../scala/spark/streaming/InputStreamsSuite.scala | 68 ++++++++++++++++++++-- 2 files changed, 64 insertions(+), 8 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala index 78537b8794..69d3504c72 100644 --- a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala @@ -49,7 +49,7 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K if (!filter.accept(path)) { return false } else { - val modTime = fs.getFileStatus(path).getModificationTime() + val modTime = fs.getFileStatus(path).getModificationTime() if (modTime <= lastModTime) { return false } @@ -60,7 +60,7 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K } } } - + val newFiles = fs.listStatus(path, newFilter) logInfo("New files: " + newFiles.map(_.getPath).mkString(", ")) if (newFiles.length > 0) { diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index a3f213ebd0..fcf5d22f5c 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -1,43 +1,48 @@ package spark.streaming import java.net.{SocketException, Socket, ServerSocket} -import java.io.{BufferedWriter, OutputStreamWriter} +import java.io.{File, BufferedWriter, OutputStreamWriter} import java.util.concurrent.{TimeUnit, ArrayBlockingQueue} import collection.mutable.{SynchronizedBuffer, ArrayBuffer} import util.ManualClock import spark.storage.StorageLevel import spark.Logging +import scala.util.Random +import org.apache.commons.io.FileUtils class InputStreamsSuite extends TestSuiteBase { test("network input stream") { + // Start the server val serverPort = 9999 val server = new TestServer(9999) server.start() + + // Set up the streaming context and input streams val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) - val networkStream = ssc.networkTextStream("localhost", serverPort, StorageLevel.DISK_AND_MEMORY) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String ]] val outputStream = new TestOutputStream(networkStream, outputBuffer) ssc.registerOutputStream(outputStream) ssc.start() + // Feed data to the server to send to the Spark Streaming network receiver val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val input = Seq(1, 2, 3) + val input = Seq(1, 2, 3, 4, 5) val expectedOutput = input.map(_.toString) for (i <- 0 until input.size) { server.send(input(i).toString + "\n") - Thread.sleep(1000) - clock.addToTime(1000) + Thread.sleep(500) + clock.addToTime(batchDuration.milliseconds) } val startTime = System.currentTimeMillis() while (outputBuffer.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { logInfo("output.size = " + outputBuffer.size + ", expectedOutput.size = " + expectedOutput.size) Thread.sleep(100) } - Thread.sleep(5000) + Thread.sleep(1000) val timeTaken = System.currentTimeMillis() - startTime assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") logInfo("Stopping server") @@ -45,6 +50,57 @@ class InputStreamsSuite extends TestSuiteBase { logInfo("Stopping context") ssc.stop() + // Verify whether data received by Spark Streaming was as expected + assert(outputBuffer.size === expectedOutput.size) + for (i <- 0 until outputBuffer.size) { + assert(outputBuffer(i).size === 1) + assert(outputBuffer(i).head === expectedOutput(i)) + } + } + + test("file input stream") { + // Create a temporary directory + val dir = { + var temp = File.createTempFile(".temp.", Random.nextInt().toString) + temp.delete() + temp.mkdirs() + temp.deleteOnExit() + println("Created temp dir " + temp) + temp + } + + // Set up the streaming context and input streams + val ssc = new StreamingContext(master, framework) + ssc.setBatchDuration(batchDuration) + val filestream = ssc.textFileStream(dir.toString) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String ]] + val outputStream = new TestOutputStream(filestream, outputBuffer) + ssc.registerOutputStream(outputStream) + ssc.start() + + // Create files in the temporary directory so that Spark Streaming can read data from it + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val input = Seq(1, 2, 3, 4, 5) + val expectedOutput = input.map(_.toString) + Thread.sleep(1000) + for (i <- 0 until input.size) { + FileUtils.writeStringToFile(new File(dir, i.toString), input(i).toString + "\n") + Thread.sleep(500) + clock.addToTime(batchDuration.milliseconds) + Thread.sleep(500) + } + val startTime = System.currentTimeMillis() + while (outputBuffer.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + println("output.size = " + outputBuffer.size + ", expectedOutput.size = " + expectedOutput.size) + Thread.sleep(100) + } + Thread.sleep(1000) + val timeTaken = System.currentTimeMillis() - startTime + assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") + println("Stopping context") + ssc.stop() + + // Verify whether data received by Spark Streaming was as expected assert(outputBuffer.size === expectedOutput.size) for (i <- 0 until outputBuffer.size) { assert(outputBuffer(i).size === 1) -- cgit v1.2.3 From 1b900183c8bb4063d8ae7bd5134fdadd52b3a155 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 27 Oct 2012 18:55:50 -0700 Subject: Added save operations to DStreams. --- .../src/main/scala/spark/streaming/DStream.scala | 16 ++++++ .../spark/streaming/PairDStreamFunctions.scala | 61 ++++++++++++++++++++-- .../scala/spark/streaming/StreamingContext.scala | 10 ++++ 3 files changed, 84 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 12d7ba97ea..175ebf104f 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -363,6 +363,22 @@ extends Serializable with Logging { rdds.toSeq } + def saveAsObjectFiles(prefix: String, suffix: String = "") { + val saveFunc = (rdd: RDD[T], time: Time) => { + val file = rddToFileName(prefix, suffix, time) + rdd.saveAsObjectFile(file) + } + this.foreachRDD(saveFunc) + } + + def saveAsTextFiles(prefix: String, suffix: String = "") { + val saveFunc = (rdd: RDD[T], time: Time) => { + val file = rddToFileName(prefix, suffix, time) + rdd.saveAsTextFile(file) + } + this.foreachRDD(saveFunc) + } + def register() { ssc.registerOutputStream(this) } diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index ce1f4ad0a0..f88247708b 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -1,9 +1,16 @@ package spark.streaming -import scala.collection.mutable.ArrayBuffer -import spark.{Manifests, RDD, Partitioner, HashPartitioner} import spark.streaming.StreamingContext._ -import javax.annotation.Nullable + +import spark.{Manifests, RDD, Partitioner, HashPartitioner} +import spark.SparkContext._ + +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.mapred.{JobConf, OutputFormat} +import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} +import org.apache.hadoop.mapred.OutputFormat +import org.apache.hadoop.conf.Configuration class PairDStreamFunctions[K: ClassManifest, V: ClassManifest](self: DStream[(K,V)]) extends Serializable { @@ -231,6 +238,54 @@ extends Serializable { for (v <- vs.iterator; w <- ws.iterator) yield (v, w) } } + + def saveAsHadoopFiles[F <: OutputFormat[K, V]]( + prefix: String, + suffix: String + )(implicit fm: ClassManifest[F]) { + saveAsHadoopFiles(prefix, suffix, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) + } + + def saveAsHadoopFiles( + prefix: String, + suffix: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: OutputFormat[_, _]], + conf: JobConf = new JobConf + ) { + val saveFunc = (rdd: RDD[(K, V)], time: Time) => { + val file = rddToFileName(prefix, suffix, time) + rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, conf) + } + self.foreachRDD(saveFunc) + } + + def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]]( + prefix: String, + suffix: String + )(implicit fm: ClassManifest[F]) { + saveAsNewAPIHadoopFiles(prefix, suffix, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) + } + + def saveAsNewAPIHadoopFiles( + prefix: String, + suffix: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: NewOutputFormat[_, _]], + conf: Configuration = new Configuration + ) { + val saveFunc = (rdd: RDD[(K, V)], time: Time) => { + val file = rddToFileName(prefix, suffix, time) + rdd.saveAsNewAPIHadoopFile(file, keyClass, valueClass, outputFormatClass, conf) + } + self.foreachRDD(saveFunc) + } + + private def getKeyClass() = implicitly[ClassManifest[K]].erasure + + private def getValueClass() = implicitly[ClassManifest[V]].erasure } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 7c7b3afe47..b3148eaa97 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -225,5 +225,15 @@ object StreamingContext { implicit def toPairDStreamFunctions[K: ClassManifest, V: ClassManifest](stream: DStream[(K,V)]) = { new PairDStreamFunctions[K, V](stream) } + + def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { + if (prefix == null) { + time.millis.toString + } else if (suffix == null || suffix.length ==0) { + prefix + "-" + time.milliseconds + } else { + prefix + "-" + time.milliseconds + "." + suffix + } + } } -- cgit v1.2.3 From 7859879aaa1860ff6b383e32a18fd9a410a97416 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 28 Oct 2012 16:46:31 -0700 Subject: Bump required Py4J version and add test for large broadcast variables. --- pyspark/README | 2 +- pyspark/pyspark/broadcast.py | 2 ++ pyspark/requirements.txt | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pyspark/README b/pyspark/README index 63a1def141..55490e1a83 100644 --- a/pyspark/README +++ b/pyspark/README @@ -36,7 +36,7 @@ examples. PySpark requires a development version of Py4J, a Python library for interacting with Java processes. It can be installed from https://github.com/bartdag/py4j; make sure to install a version that -contains at least the commits through 3dbf380d3d. +contains at least the commits through b7924aabe9. PySpark uses the `PYTHONPATH` environment variable to search for Python classes; Py4J should be on this path, along with any libraries used by diff --git a/pyspark/pyspark/broadcast.py b/pyspark/pyspark/broadcast.py index 4cff02b36d..93876fa738 100644 --- a/pyspark/pyspark/broadcast.py +++ b/pyspark/pyspark/broadcast.py @@ -13,6 +13,8 @@ >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() [1, 2, 3, 4, 5, 1, 2, 3, 4, 5] + +>>> large_broadcast = sc.broadcast(list(range(10000))) """ # Holds broadcasted data received from Java, keyed by its id. _broadcastRegistry = {} diff --git a/pyspark/requirements.txt b/pyspark/requirements.txt index 71e2bc2b89..48fa2ab105 100644 --- a/pyspark/requirements.txt +++ b/pyspark/requirements.txt @@ -3,4 +3,4 @@ # package is not at the root of the git repository. It may be possible to # install Py4J from git once https://github.com/pypa/pip/pull/526 is merged. -# git+git://github.com/bartdag/py4j.git@3dbf380d3d2cdeb9aab394454ea74d80c4aba1ea +# git+git://github.com/bartdag/py4j.git@b7924aabe9c5e63f0a4d8bbd17019534c7ec014e -- cgit v1.2.3 From 2ccf3b665280bf5b0919e3801d028126cb070dbd Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 28 Oct 2012 22:30:28 -0700 Subject: Fix PySpark hash partitioning bug. A Java array's hashCode is based on its object identify, not its elements, so this was causing serialized keys to be hashed incorrectly. This commit adds a PySpark-specific workaround and adds more tests. --- .../scala/spark/api/python/PythonPartitioner.scala | 41 ++++++++++++++++++++++ .../main/scala/spark/api/python/PythonRDD.scala | 10 +++--- pyspark/pyspark/rdd.py | 12 +++++-- 3 files changed, 54 insertions(+), 9 deletions(-) create mode 100644 core/src/main/scala/spark/api/python/PythonPartitioner.scala diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala new file mode 100644 index 0000000000..ef9f808fb2 --- /dev/null +++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala @@ -0,0 +1,41 @@ +package spark.api.python + +import spark.Partitioner + +import java.util.Arrays + +/** + * A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API. + */ +class PythonPartitioner(override val numPartitions: Int) extends Partitioner { + + override def getPartition(key: Any): Int = { + if (key == null) { + return 0 + } + else { + val hashCode = { + if (key.isInstanceOf[Array[Byte]]) { + System.err.println("Dumping a byte array!" + Arrays.hashCode(key.asInstanceOf[Array[Byte]]) + ) + Arrays.hashCode(key.asInstanceOf[Array[Byte]]) + } + else + key.hashCode() + } + val mod = hashCode % numPartitions + if (mod < 0) { + mod + numPartitions + } else { + mod // Guard against negative hash codes + } + } + } + + override def equals(other: Any): Boolean = other match { + case h: PythonPartitioner => + h.numPartitions == numPartitions + case _ => + false + } +} diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index a593e53efd..50094d6b0f 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -179,14 +179,12 @@ object PythonRDD { val dOut = new DataOutputStream(baos); if (elem.isInstanceOf[Array[Byte]]) { elem.asInstanceOf[Array[Byte]] - } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) { - val t = elem.asInstanceOf[scala.Tuple2[_, _]] - val t1 = t._1.asInstanceOf[Array[Byte]] - val t2 = t._2.asInstanceOf[Array[Byte]] + } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) { + val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]] dOut.writeByte(Pickle.PROTO) dOut.writeByte(Pickle.TWO) - dOut.write(PythonRDD.stripPickle(t1)) - dOut.write(PythonRDD.stripPickle(t2)) + dOut.write(PythonRDD.stripPickle(t._1)) + dOut.write(PythonRDD.stripPickle(t._2)) dOut.writeByte(Pickle.TUPLE2) dOut.writeByte(Pickle.STOP) baos.toByteArray() diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index e4878c08ba..85a24c6854 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -310,6 +310,12 @@ class RDD(object): return python_right_outer_join(self, other, numSplits) def partitionBy(self, numSplits, hashFunc=hash): + """ + >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x)) + >>> sets = pairs.partitionBy(2).glom().collect() + >>> set(sets[0]).intersection(set(sets[1])) + set([]) + """ if numSplits is None: numSplits = self.ctx.defaultParallelism def add_shuffle_key(iterator): @@ -319,7 +325,7 @@ class RDD(object): keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() - partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) + partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits) jrdd = pairRDD.partitionBy(partitioner) jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) return RDD(jrdd, self.ctx) @@ -391,7 +397,7 @@ class RDD(object): """ >>> x = sc.parallelize([("a", 1), ("b", 4)]) >>> y = sc.parallelize([("a", 2)]) - >>> x.cogroup(y).collect() + >>> sorted(x.cogroup(y).collect()) [('a', ([1], [2])), ('b', ([4], []))] """ return python_cogroup(self, other, numSplits) @@ -462,7 +468,7 @@ def _test(): import doctest from pyspark.context import SparkContext globs = globals().copy() - globs['sc'] = SparkContext('local', 'PythonTest') + globs['sc'] = SparkContext('local[4]', 'PythonTest') doctest.testmod(globs=globs) globs['sc'].stop() -- cgit v1.2.3 From ac12abc17ff90ec99192f3c3de4d1d390969e635 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 29 Oct 2012 11:55:27 -0700 Subject: Modified RDD API to make dependencies a var (therefore can be changed to checkpointed hadoop rdd) and othere references to parent RDDs either through dependencies or through a weak reference (to allow finalizing when dependencies do not refer to it any more). --- core/src/main/scala/spark/PairRDDFunctions.scala | 24 ++++++++++----------- core/src/main/scala/spark/ParallelCollection.scala | 8 +++---- core/src/main/scala/spark/RDD.scala | 25 ++++++++++++++++------ core/src/main/scala/spark/SparkContext.scala | 4 ++++ core/src/main/scala/spark/rdd/CartesianRDD.scala | 21 ++++++++++++------ core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 13 +++++++---- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 19 ++++++++++------ core/src/main/scala/spark/rdd/FilteredRDD.scala | 12 +++++++---- core/src/main/scala/spark/rdd/FlatMappedRDD.scala | 10 ++++----- core/src/main/scala/spark/rdd/GlommedRDD.scala | 9 ++++---- core/src/main/scala/spark/rdd/HadoopRDD.scala | 4 +--- .../main/scala/spark/rdd/MapPartitionsRDD.scala | 12 +++++------ .../spark/rdd/MapPartitionsWithSplitRDD.scala | 10 ++++----- core/src/main/scala/spark/rdd/MappedRDD.scala | 12 +++++------ core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 9 ++++---- core/src/main/scala/spark/rdd/PipedRDD.scala | 16 +++++++------- core/src/main/scala/spark/rdd/SampledRDD.scala | 15 ++++++------- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 13 ++++++----- core/src/main/scala/spark/rdd/UnionRDD.scala | 20 ++++++++++------- 19 files changed, 149 insertions(+), 107 deletions(-) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index e5bb639cfd..f52af08125 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -23,6 +23,7 @@ import spark.partial.BoundedDouble import spark.partial.PartialResult import spark.rdd._ import spark.SparkContext._ +import java.lang.ref.WeakReference /** * Extra functions available on RDDs of (key, value) pairs through an implicit conversion. @@ -624,23 +625,22 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest]( } private[spark] -class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev.context) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override val partitioner = prev.partitioner - override def compute(split: Split) = prev.iterator(split).map{case (k, v) => (k, f(v))} +class MappedValuesRDD[K, V, U](@transient prev: WeakReference[RDD[(K, V)]], f: V => U) + extends RDD[(K, U)](prev.get) { + + override def splits = firstParent[(K, V)].splits + override val partitioner = firstParent[(K, V)].partitioner + override def compute(split: Split) = firstParent[(K, V)].iterator(split).map{case (k, v) => (k, f(v))} } private[spark] -class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U]) - extends RDD[(K, U)](prev.context) { - - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override val partitioner = prev.partitioner +class FlatMappedValuesRDD[K, V, U](@transient prev: WeakReference[RDD[(K, V)]], f: V => TraversableOnce[U]) + extends RDD[(K, U)](prev.get) { + override def splits = firstParent[(K, V)].splits + override val partitioner = firstParent[(K, V)].partitioner override def compute(split: Split) = { - prev.iterator(split).flatMap { case (k, v) => f(v).map(x => (k, x)) } + firstParent[(K, V)].iterator(split).flatMap { case (k, v) => f(v).map(x => (k, x)) } } } diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index 9b57ae3b4f..ad06ee9736 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -22,13 +22,13 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest]( } private[spark] class ParallelCollection[T: ClassManifest]( - sc: SparkContext, + @transient sc_ : SparkContext, @transient data: Seq[T], numSlices: Int) - extends RDD[T](sc) { + extends RDD[T](sc_, Nil) { // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split - // instead. + // instead. UPDATE: With the new changes to enable checkpointing, this an be done. @transient val splits_ = { @@ -41,8 +41,6 @@ private[spark] class ParallelCollection[T: ClassManifest]( override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator override def preferredLocations(s: Split): Seq[String] = Nil - - override val dependencies: List[Dependency[_]] = Nil } private object ParallelCollection { diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 338dff4061..c9f3763f73 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -72,7 +72,14 @@ import SparkContext._ * [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details * on RDD internals. */ -abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serializable { +abstract class RDD[T: ClassManifest]( + @transient var sc: SparkContext, + @transient var dependencies_ : List[Dependency[_]] = Nil + ) extends Serializable { + + + def this(@transient oneParent: RDD[_]) = + this(oneParent.context , List(new OneToOneDependency(oneParent))) // Methods that must be implemented by subclasses: @@ -83,10 +90,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial def compute(split: Split): Iterator[T] /** How this RDD depends on any parent RDDs. */ - @transient val dependencies: List[Dependency[_]] + def dependencies: List[Dependency[_]] = dependencies_ + //var dependencies: List[Dependency[_]] = dependencies_ - // Methods available on all RDDs: - /** Record user function generating this RDD. */ private[spark] val origin = Utils.getSparkCallSite @@ -106,8 +112,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial // Variables relating to persistence private var storageLevel: StorageLevel = StorageLevel.NONE - - /** + + private[spark] def firstParent[U: ClassManifest] = dependencies.head.rdd.asInstanceOf[RDD[U]] + private[spark] def parent[U: ClassManifest](id: Int) = dependencies(id).rdd.asInstanceOf[RDD[U]] + + // Methods available on all RDDs: + + /** * Set this RDD's storage level to persist its values across operations after the first time * it is computed. Can only be called once on each RDD. */ @@ -129,7 +140,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel - + private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = { if (!level.useDisk && level.replication < 2) { throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")") diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 0d37075ef3..6b957a6356 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -3,6 +3,7 @@ package spark import java.io._ import java.util.concurrent.atomic.AtomicInteger import java.net.{URI, URLClassLoader} +import java.lang.ref.WeakReference import scala.collection.Map import scala.collection.generic.Growable @@ -695,6 +696,9 @@ object SparkContext { /** Find the JAR that contains the class of a particular object */ def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass) + + implicit def rddToWeakRefRDD[T: ClassManifest](rdd: RDD[T]) = new WeakReference(rdd) + } diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 7c354b6b2e..c97b835630 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -4,6 +4,7 @@ import spark.NarrowDependency import spark.RDD import spark.SparkContext import spark.Split +import java.lang.ref.WeakReference private[spark] class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable { @@ -13,13 +14,17 @@ class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with private[spark] class CartesianRDD[T: ClassManifest, U:ClassManifest]( sc: SparkContext, - rdd1: RDD[T], - rdd2: RDD[U]) + rdd1_ : WeakReference[RDD[T]], + rdd2_ : WeakReference[RDD[U]]) extends RDD[Pair[T, U]](sc) with Serializable { - + + def rdd1 = rdd1_.get + def rdd2 = rdd2_.get + val numSplitsInRdd2 = rdd2.splits.size - + + // TODO: make this null when finishing checkpoint @transient val splits_ = { // create the cross product split @@ -31,6 +36,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( array } + // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ override def preferredLocations(split: Split) = { @@ -42,8 +48,9 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( val currSplit = split.asInstanceOf[CartesianSplit] for (x <- rdd1.iterator(currSplit.s1); y <- rdd2.iterator(currSplit.s2)) yield (x, y) } - - override val dependencies = List( + + // TODO: make this null when finishing checkpoint + var deps = List( new NarrowDependency(rdd1) { def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2) }, @@ -51,4 +58,6 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( def getParents(id: Int): Seq[Int] = List(id % numSplitsInRdd2) } ) + + override def dependencies = deps } diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 50bec9e63b..af54ac2fa0 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -31,12 +31,13 @@ private[spark] class CoGroupAggregator with Serializable class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) - extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging { + extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) with Logging { val aggr = new CoGroupAggregator - + + // TODO: make this null when finishing checkpoint @transient - override val dependencies = { + var deps = { val deps = new ArrayBuffer[Dependency[_]] for ((rdd, index) <- rdds.zipWithIndex) { val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true) @@ -50,7 +51,10 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) } deps.toList } - + + override def dependencies = deps + + // TODO: make this null when finishing checkpoint @transient val splits_ : Array[Split] = { val firstRdd = rdds.head @@ -68,6 +72,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) array } + // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ override val partitioner = Some(part) diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 0967f4f5df..573acf8893 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -14,11 +14,14 @@ private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) exten * This transformation is useful when an RDD with many partitions gets filtered into a smaller one, * or to avoid having a large number of small tasks when processing a directory with many files. */ -class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int) - extends RDD[T](prev.context) { +class CoalescedRDD[T: ClassManifest]( + @transient prev: RDD[T], // TODO: Make this a weak reference + maxPartitions: Int) + extends RDD[T](prev.context, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs + // TODO: make this null when finishing checkpoint @transient val splits_ : Array[Split] = { - val prevSplits = prev.splits + val prevSplits = firstParent[T].splits if (prevSplits.length < maxPartitions) { prevSplits.zipWithIndex.map{ case (s, idx) => new CoalescedRDDSplit(idx, Array(s)) } } else { @@ -30,18 +33,22 @@ class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int) } } + // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ override def compute(split: Split): Iterator[T] = { split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { - parentSplit => prev.iterator(parentSplit) + parentSplit => firstParent[T].iterator(parentSplit) } } - val dependencies = List( - new NarrowDependency(prev) { + // TODO: make this null when finishing checkpoint + var deps = List( + new NarrowDependency(firstParent) { def getParents(id: Int): Seq[Int] = splits(id).asInstanceOf[CoalescedRDDSplit].parents.map(_.index) } ) + + override def dependencies = deps } diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala index dfe9dc73f3..cc2a3acd3a 100644 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala @@ -3,10 +3,14 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference private[spark] -class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = prev.iterator(split).filter(f) +class FilteredRDD[T: ClassManifest]( + @transient prev: WeakReference[RDD[T]], + f: T => Boolean) + extends RDD[T](prev.get) { + + override def splits = firstParent[T].splits + override def compute(split: Split) = firstParent[T].iterator(split).filter(f) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala index 3534dc8057..34bd784c13 100644 --- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala +++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala @@ -3,14 +3,14 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference private[spark] class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], + @transient prev: WeakReference[RDD[T]], f: T => TraversableOnce[U]) - extends RDD[U](prev.context) { + extends RDD[U](prev.get) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = prev.iterator(split).flatMap(f) + override def splits = firstParent[T].splits + override def compute(split: Split) = firstParent[T].iterator(split).flatMap(f) } diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala index e30564f2da..9321e89dcd 100644 --- a/core/src/main/scala/spark/rdd/GlommedRDD.scala +++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala @@ -3,10 +3,11 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference private[spark] -class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = Array(prev.iterator(split).toArray).iterator +class GlommedRDD[T: ClassManifest](@transient prev: WeakReference[RDD[T]]) + extends RDD[Array[T]](prev.get) { + override def splits = firstParent[T].splits + override def compute(split: Split) = Array(firstParent[T].iterator(split).toArray).iterator } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index bf29a1f075..a12531ea89 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -46,7 +46,7 @@ class HadoopRDD[K, V]( keyClass: Class[K], valueClass: Class[V], minSplits: Int) - extends RDD[(K, V)](sc) { + extends RDD[(K, V)](sc, Nil) { // A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it val confBroadcast = sc.broadcast(new SerializableWritable(conf)) @@ -115,6 +115,4 @@ class HadoopRDD[K, V]( val hadoopSplit = split.asInstanceOf[HadoopSplit] hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") } - - override val dependencies: List[Dependency[_]] = Nil } diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala index a904ef62c3..bad872c430 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala @@ -3,17 +3,17 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference private[spark] class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], + @transient prev: WeakReference[RDD[T]], f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false) - extends RDD[U](prev.context) { + extends RDD[U](prev.get) { - override val partitioner = if (preservesPartitioning) prev.partitioner else None + override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = f(prev.iterator(split)) + override def splits = firstParent[T].splits + override def compute(split: Split) = f(firstParent[T].iterator(split)) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala index adc541694e..d7b238b05d 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala @@ -3,6 +3,7 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference /** * A variant of the MapPartitionsRDD that passes the split index into the @@ -11,11 +12,10 @@ import spark.Split */ private[spark] class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], + @transient prev: WeakReference[RDD[T]], f: (Int, Iterator[T]) => Iterator[U]) - extends RDD[U](prev.context) { + extends RDD[U](prev.get) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = f(split.index, prev.iterator(split)) + override def splits = firstParent[T].splits + override def compute(split: Split) = f(split.index, firstParent[T].iterator(split)) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala index 59bedad8ef..126c6f332b 100644 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/spark/rdd/MappedRDD.scala @@ -3,14 +3,14 @@ package spark.rdd import spark.OneToOneDependency import spark.RDD import spark.Split +import java.lang.ref.WeakReference private[spark] class MappedRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], + @transient prev: WeakReference[RDD[T]], f: T => U) - extends RDD[U](prev.context) { - - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = prev.iterator(split).map(f) + extends RDD[U](prev.get) { + + override def splits = firstParent[T].splits + override def compute(split: Split) = firstParent[T].iterator(split).map(f) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index 7a1a0fb87d..c12df5839e 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -23,11 +23,12 @@ class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit } class NewHadoopRDD[K, V]( - sc: SparkContext, + sc : SparkContext, inputFormatClass: Class[_ <: InputFormat[K, V]], - keyClass: Class[K], valueClass: Class[V], + keyClass: Class[K], + valueClass: Class[V], @transient conf: Configuration) - extends RDD[(K, V)](sc) + extends RDD[(K, V)](sc, Nil) with HadoopMapReduceUtil { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it @@ -92,6 +93,4 @@ class NewHadoopRDD[K, V]( val theSplit = split.asInstanceOf[NewHadoopSplit] theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") } - - override val dependencies: List[Dependency[_]] = Nil } diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index 98ea0c92d6..d54579d6d1 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -19,18 +19,18 @@ import spark.Split * (printing them one per line) and returns the output as a collection of strings. */ class PipedRDD[T: ClassManifest]( - parent: RDD[T], command: Seq[String], envVars: Map[String, String]) - extends RDD[String](parent.context) { + @transient prev: RDD[T], + command: Seq[String], + envVars: Map[String, String]) + extends RDD[String](prev) { - def this(parent: RDD[T], command: Seq[String]) = this(parent, command, Map()) + def this(@transient prev: RDD[T], command: Seq[String]) = this(prev, command, Map()) // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String) = this(parent, PipedRDD.tokenize(command)) + def this(@transient prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command)) - override def splits = parent.splits - - override val dependencies = List(new OneToOneDependency(parent)) + override def splits = firstParent[T].splits override def compute(split: Split): Iterator[String] = { val pb = new ProcessBuilder(command) @@ -55,7 +55,7 @@ class PipedRDD[T: ClassManifest]( override def run() { SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) - for (elem <- parent.iterator(split)) { + for (elem <- firstParent[T].iterator(split)) { out.println(elem) } out.close() diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index 87a5268f27..00b521b130 100644 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -7,6 +7,7 @@ import cern.jet.random.engine.DRand import spark.RDD import spark.OneToOneDependency import spark.Split +import java.lang.ref.WeakReference private[spark] class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable { @@ -14,24 +15,22 @@ class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Seriali } class SampledRDD[T: ClassManifest]( - prev: RDD[T], + @transient prev: WeakReference[RDD[T]], withReplacement: Boolean, frac: Double, seed: Int) - extends RDD[T](prev.context) { + extends RDD[T](prev.get) { @transient val splits_ = { val rg = new Random(seed) - prev.splits.map(x => new SampledRDDSplit(x, rg.nextInt)) + firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt)) } override def splits = splits_.asInstanceOf[Array[Split]] - override val dependencies = List(new OneToOneDependency(prev)) - override def preferredLocations(split: Split) = - prev.preferredLocations(split.asInstanceOf[SampledRDDSplit].prev) + firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev) override def compute(splitIn: Split) = { val split = splitIn.asInstanceOf[SampledRDDSplit] @@ -39,7 +38,7 @@ class SampledRDD[T: ClassManifest]( // For large datasets, the expected number of occurrences of each element in a sample with // replacement is Poisson(frac). We use that to get a count for each element. val poisson = new Poisson(frac, new DRand(split.seed)) - prev.iterator(split.prev).flatMap { element => + firstParent[T].iterator(split.prev).flatMap { element => val count = poisson.nextInt() if (count == 0) { Iterator.empty // Avoid object allocation when we return 0 items, which is quite often @@ -49,7 +48,7 @@ class SampledRDD[T: ClassManifest]( } } else { // Sampling without replacement val rand = new Random(split.seed) - prev.iterator(split.prev).filter(x => (rand.nextDouble <= frac)) + firstParent[T].iterator(split.prev).filter(x => (rand.nextDouble <= frac)) } } } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 145e419c53..62867dab4f 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -5,6 +5,7 @@ import spark.RDD import spark.ShuffleDependency import spark.SparkEnv import spark.Split +import java.lang.ref.WeakReference private[spark] class ShuffledRDDSplit(val idx: Int) extends Split { override val index = idx @@ -19,8 +20,9 @@ private[spark] class ShuffledRDDSplit(val idx: Int) extends Split { * @tparam V the value class. */ class ShuffledRDD[K, V]( - @transient parent: RDD[(K, V)], - part: Partitioner) extends RDD[(K, V)](parent.context) { + @transient prev: WeakReference[RDD[(K, V)]], + part: Partitioner) + extends RDD[(K, V)](prev.get.context, List(new ShuffleDependency(prev.get, part))) { override val partitioner = Some(part) @@ -31,10 +33,11 @@ class ShuffledRDD[K, V]( override def preferredLocations(split: Split) = Nil - val dep = new ShuffleDependency(parent, part) - override val dependencies = List(dep) + //val dep = new ShuffleDependency(parent, part) + //override val dependencies = List(dep) override def compute(split: Split): Iterator[(K, V)] = { - SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index) + val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId + SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index) } } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index f0b9225f7c..0a61a2d1f5 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -7,6 +7,7 @@ import spark.RangeDependency import spark.RDD import spark.SparkContext import spark.Split +import java.lang.ref.WeakReference private[spark] class UnionSplit[T: ClassManifest]( idx: Int, @@ -22,10 +23,10 @@ private[spark] class UnionSplit[T: ClassManifest]( class UnionRDD[T: ClassManifest]( sc: SparkContext, - @transient rdds: Seq[RDD[T]]) - extends RDD[T](sc) - with Serializable { - + @transient rdds: Seq[RDD[T]]) // TODO: Make this a weak reference + extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs + + // TODO: make this null when finishing checkpoint @transient val splits_ : Array[Split] = { val array = new Array[Split](rdds.map(_.splits.size).sum) @@ -37,19 +38,22 @@ class UnionRDD[T: ClassManifest]( array } + // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ - @transient - override val dependencies = { + // TODO: make this null when finishing checkpoint + @transient var deps = { val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for (rdd <- rdds) { - deps += new RangeDependency(rdd, 0, pos, rdd.splits.size) + deps += new RangeDependency(rdd, 0, pos, rdd.splits.size) pos += rdd.splits.size } deps.toList } - + + override def dependencies = deps + override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator() override def preferredLocations(s: Split): Seq[String] = -- cgit v1.2.3 From 0dcd770fdc4d558972b635b6770ed0120280ef22 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 30 Oct 2012 16:09:37 -0700 Subject: Added checkpointing support to all RDDs, along with CheckpointSuite to test checkpointing in them. --- core/src/main/scala/spark/PairRDDFunctions.scala | 4 +- core/src/main/scala/spark/ParallelCollection.scala | 4 +- core/src/main/scala/spark/RDD.scala | 129 +++++++++++++++++---- core/src/main/scala/spark/SparkContext.scala | 21 ++++ core/src/main/scala/spark/rdd/BlockRDD.scala | 13 ++- core/src/main/scala/spark/rdd/CartesianRDD.scala | 38 +++--- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 19 +-- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 26 +++-- core/src/main/scala/spark/rdd/FilteredRDD.scala | 2 +- core/src/main/scala/spark/rdd/FlatMappedRDD.scala | 2 +- core/src/main/scala/spark/rdd/GlommedRDD.scala | 2 +- core/src/main/scala/spark/rdd/HadoopRDD.scala | 2 + .../main/scala/spark/rdd/MapPartitionsRDD.scala | 2 +- .../spark/rdd/MapPartitionsWithSplitRDD.scala | 2 +- core/src/main/scala/spark/rdd/MappedRDD.scala | 2 +- core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 2 + core/src/main/scala/spark/rdd/PipedRDD.scala | 9 +- core/src/main/scala/spark/rdd/SampledRDD.scala | 2 +- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 5 - core/src/main/scala/spark/rdd/UnionRDD.scala | 32 ++--- core/src/test/scala/spark/CheckpointSuite.scala | 116 ++++++++++++++++++ core/src/test/scala/spark/RDDSuite.scala | 25 +++- 22 files changed, 352 insertions(+), 107 deletions(-) create mode 100644 core/src/test/scala/spark/CheckpointSuite.scala diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index f52af08125..1f82bd3ab8 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -625,7 +625,7 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest]( } private[spark] -class MappedValuesRDD[K, V, U](@transient prev: WeakReference[RDD[(K, V)]], f: V => U) +class MappedValuesRDD[K, V, U](prev: WeakReference[RDD[(K, V)]], f: V => U) extends RDD[(K, U)](prev.get) { override def splits = firstParent[(K, V)].splits @@ -634,7 +634,7 @@ class MappedValuesRDD[K, V, U](@transient prev: WeakReference[RDD[(K, V)]], f: V } private[spark] -class FlatMappedValuesRDD[K, V, U](@transient prev: WeakReference[RDD[(K, V)]], f: V => TraversableOnce[U]) +class FlatMappedValuesRDD[K, V, U](prev: WeakReference[RDD[(K, V)]], f: V => TraversableOnce[U]) extends RDD[(K, U)](prev.get) { override def splits = firstParent[(K, V)].splits diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index ad06ee9736..9725017b61 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -22,10 +22,10 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest]( } private[spark] class ParallelCollection[T: ClassManifest]( - @transient sc_ : SparkContext, + @transient sc : SparkContext, @transient data: Seq[T], numSlices: Int) - extends RDD[T](sc_, Nil) { + extends RDD[T](sc, Nil) { // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split // instead. UPDATE: With the new changes to enable checkpointing, this an be done. diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index c9f3763f73..e272a0ede9 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -13,6 +13,7 @@ import scala.collection.Map import scala.collection.mutable.HashMap import scala.collection.JavaConversions.mapAsScalaMap +import org.apache.hadoop.fs.Path import org.apache.hadoop.io.BytesWritable import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.Text @@ -74,7 +75,7 @@ import SparkContext._ */ abstract class RDD[T: ClassManifest]( @transient var sc: SparkContext, - @transient var dependencies_ : List[Dependency[_]] = Nil + var dependencies_ : List[Dependency[_]] ) extends Serializable { @@ -91,7 +92,6 @@ abstract class RDD[T: ClassManifest]( /** How this RDD depends on any parent RDDs. */ def dependencies: List[Dependency[_]] = dependencies_ - //var dependencies: List[Dependency[_]] = dependencies_ /** Record user function generating this RDD. */ private[spark] val origin = Utils.getSparkCallSite @@ -100,7 +100,13 @@ abstract class RDD[T: ClassManifest]( val partitioner: Option[Partitioner] = None /** Optionally overridden by subclasses to specify placement preferences. */ - def preferredLocations(split: Split): Seq[String] = Nil + def preferredLocations(split: Split): Seq[String] = { + if (isCheckpointed) { + checkpointRDD.preferredLocations(split) + } else { + Nil + } + } /** The [[spark.SparkContext]] that this RDD was created on. */ def context = sc @@ -113,8 +119,23 @@ abstract class RDD[T: ClassManifest]( // Variables relating to persistence private var storageLevel: StorageLevel = StorageLevel.NONE - private[spark] def firstParent[U: ClassManifest] = dependencies.head.rdd.asInstanceOf[RDD[U]] - private[spark] def parent[U: ClassManifest](id: Int) = dependencies(id).rdd.asInstanceOf[RDD[U]] + /** Returns the first parent RDD */ + private[spark] def firstParent[U: ClassManifest] = { + dependencies.head.rdd.asInstanceOf[RDD[U]] + } + + /** Returns the `i` th parent RDD */ + private[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]] + + // Variables relating to checkpointing + val isCheckpointable = true // override to set this to false to avoid checkpointing an RDD + var shouldCheckpoint = false // set to true when an RDD is marked for checkpointing + var isCheckpointInProgress = false // set to true when checkpointing is in progress + var isCheckpointed = false // set to true after checkpointing is completed + + var checkpointFile: String = null // set to the checkpoint file after checkpointing is completed + var checkpointRDD: RDD[T] = null // set to the HadoopRDD of the checkpoint file + var checkpointRDDSplits: Seq[Split] = null // set to the splits of the Hadoop RDD // Methods available on all RDDs: @@ -141,32 +162,94 @@ abstract class RDD[T: ClassManifest]( /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel - private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = { - if (!level.useDisk && level.replication < 2) { - throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")") - } - - // This is a hack. Ideally this should re-use the code used by the CacheTracker - // to generate the key. - def getSplitKey(split: Split) = "rdd_%d_%d".format(this.id, split.index) - - persist(level) - sc.runJob(this, (iter: Iterator[T]) => {} ) - - val p = this.partitioner - - new BlockRDD[T](sc, splits.map(getSplitKey).toArray) { - override val partitioner = p + /** + * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir` + * (set using setCheckpointDir()) and all references to its parent RDDs will be removed. + * This is used to truncate very long lineages. In the current implementation, Spark will save + * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. + * Hence, it is strongly recommended to use checkpoint() on RDDs when + * (i) Checkpoint() is called before the any job has been executed on this RDD. + * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will + * require recomputation. + */ + protected[spark] def checkpoint() { + synchronized { + if (isCheckpointed || shouldCheckpoint || isCheckpointInProgress) { + // do nothing + } else if (isCheckpointable) { + shouldCheckpoint = true + } else { + throw new Exception(this + " cannot be checkpointed") + } } } - + + /** + * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler after a job + * using this RDD has completed (therefore the RDD has been materialized and + * potentially stored in memory). In case this RDD is not marked for checkpointing, + * doCheckpoint() is called recursively on the parent RDDs. + */ + private[spark] def doCheckpoint() { + val startCheckpoint = synchronized { + if (isCheckpointable && shouldCheckpoint && !isCheckpointInProgress) { + isCheckpointInProgress = true + true + } else { + false + } + } + + if (startCheckpoint) { + val rdd = this + val env = SparkEnv.get + + // Spawn a new thread to do the checkpoint as it takes sometime to write the RDD to file + val th = new Thread() { + override def run() { + // Save the RDD to a file, create a new HadoopRDD from it, + // and change the dependencies from the original parents to the new RDD + SparkEnv.set(env) + rdd.checkpointFile = new Path(context.checkpointDir, "rdd-" + id).toString + rdd.saveAsObjectFile(checkpointFile) + rdd.synchronized { + rdd.checkpointRDD = context.objectFile[T](checkpointFile) + rdd.checkpointRDDSplits = rdd.checkpointRDD.splits + rdd.changeDependencies(rdd.checkpointRDD) + rdd.shouldCheckpoint = false + rdd.isCheckpointInProgress = false + rdd.isCheckpointed = true + } + } + } + th.start() + } else { + // Recursively call doCheckpoint() to perform checkpointing on parent RDD if they are marked + dependencies.foreach(_.rdd.doCheckpoint()) + } + } + + /** + * Changes the dependencies of this RDD from its original parents to the new [[spark.rdd.HadoopRDD]] + * (`newRDD`) created from the checkpoint file. This method must ensure that all references + * to the original parent RDDs must be removed to enable the parent RDDs to be garbage + * collected. Subclasses of RDD may override this method for implementing their own changing + * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea. + */ + protected def changeDependencies(newRDD: RDD[_]) { + dependencies_ = List(new OneToOneDependency(newRDD)) + } + /** * Internal method to this RDD; will read from cache if applicable, or otherwise compute it. * This should ''not'' be called by users directly, but is available for implementors of custom * subclasses of RDD. */ final def iterator(split: Split): Iterator[T] = { - if (storageLevel != StorageLevel.NONE) { + if (isCheckpointed) { + // ASSUMPTION: Checkpoint Hadoop RDD will have same number of splits as original + checkpointRDD.iterator(checkpointRDDSplits(split.index)) + } else if (storageLevel != StorageLevel.NONE) { SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel) } else { compute(split) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 6b957a6356..79ceab5f4f 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -188,6 +188,8 @@ class SparkContext( private var dagScheduler = new DAGScheduler(taskScheduler) + private[spark] var checkpointDir: String = null + // Methods for creating RDDs /** Distribute a local Scala collection to form an RDD. */ @@ -519,6 +521,7 @@ class SparkContext( val start = System.nanoTime val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal) logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") + rdd.doCheckpoint() result } @@ -575,6 +578,24 @@ class SparkContext( return f } + /** + * Set the directory under which RDDs are going to be checkpointed. This method will + * create this directory and will throw an exception of the path already exists (to avoid + * overwriting existing files may be overwritten). The directory will be deleted on exit + * if indicated. + */ + def setCheckpointDir(dir: String, deleteOnExit: Boolean = false) { + val path = new Path(dir) + val fs = path.getFileSystem(new Configuration()) + if (fs.exists(path)) { + throw new Exception("Checkpoint directory '" + path + "' already exists.") + } else { + fs.mkdirs(path) + if (deleteOnExit) fs.deleteOnExit(path) + } + checkpointDir = dir + } + /** Default level of parallelism to use when not given by user (e.g. for reduce tasks) */ def defaultParallelism: Int = taskScheduler.defaultParallelism diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index cb73976aed..f4c3f99011 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -14,7 +14,7 @@ private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split private[spark] class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String]) - extends RDD[T](sc) { + extends RDD[T](sc, Nil) { @transient val splits_ = (0 until blockIds.size).map(i => { @@ -41,9 +41,12 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St } } - override def preferredLocations(split: Split) = - locations_(split.asInstanceOf[BlockRDDSplit].blockId) - - override val dependencies: List[Dependency[_]] = Nil + override def preferredLocations(split: Split) = { + if (isCheckpointed) { + checkpointRDD.preferredLocations(split) + } else { + locations_(split.asInstanceOf[BlockRDDSplit].blockId) + } + } } diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index c97b835630..458ad38d55 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -1,9 +1,6 @@ package spark.rdd -import spark.NarrowDependency -import spark.RDD -import spark.SparkContext -import spark.Split +import spark._ import java.lang.ref.WeakReference private[spark] @@ -14,19 +11,15 @@ class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with private[spark] class CartesianRDD[T: ClassManifest, U:ClassManifest]( sc: SparkContext, - rdd1_ : WeakReference[RDD[T]], - rdd2_ : WeakReference[RDD[U]]) - extends RDD[Pair[T, U]](sc) + var rdd1 : RDD[T], + var rdd2 : RDD[U]) + extends RDD[Pair[T, U]](sc, Nil) with Serializable { - def rdd1 = rdd1_.get - def rdd2 = rdd2_.get - val numSplitsInRdd2 = rdd2.splits.size - // TODO: make this null when finishing checkpoint @transient - val splits_ = { + var splits_ = { // create the cross product split val array = new Array[Split](rdd1.splits.size * rdd2.splits.size) for (s1 <- rdd1.splits; s2 <- rdd2.splits) { @@ -36,12 +29,15 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( array } - // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ override def preferredLocations(split: Split) = { - val currSplit = split.asInstanceOf[CartesianSplit] - rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) + if (isCheckpointed) { + checkpointRDD.preferredLocations(split) + } else { + val currSplit = split.asInstanceOf[CartesianSplit] + rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) + } } override def compute(split: Split) = { @@ -49,8 +45,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( for (x <- rdd1.iterator(currSplit.s1); y <- rdd2.iterator(currSplit.s2)) yield (x, y) } - // TODO: make this null when finishing checkpoint - var deps = List( + var deps_ = List( new NarrowDependency(rdd1) { def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2) }, @@ -59,5 +54,12 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( } ) - override def dependencies = deps + override def dependencies = deps_ + + override protected def changeDependencies(newRDD: RDD[_]) { + deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) + splits_ = newRDD.splits + rdd1 = null + rdd2 = null + } } diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index af54ac2fa0..a313ebcbe8 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -30,14 +30,13 @@ private[spark] class CoGroupAggregator { (b1, b2) => b1 ++ b2 }) with Serializable -class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) +class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) with Logging { val aggr = new CoGroupAggregator - // TODO: make this null when finishing checkpoint @transient - var deps = { + var deps_ = { val deps = new ArrayBuffer[Dependency[_]] for ((rdd, index) <- rdds.zipWithIndex) { val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true) @@ -52,11 +51,10 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) deps.toList } - override def dependencies = deps + override def dependencies = deps_ - // TODO: make this null when finishing checkpoint @transient - val splits_ : Array[Split] = { + var splits_ : Array[Split] = { val firstRdd = rdds.head val array = new Array[Split](part.numPartitions) for (i <- 0 until array.size) { @@ -72,13 +70,10 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) array } - // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ override val partitioner = Some(part) - override def preferredLocations(s: Split) = Nil - override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = { val split = s.asInstanceOf[CoGroupSplit] val numRdds = split.deps.size @@ -106,4 +101,10 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) } map.iterator } + + override protected def changeDependencies(newRDD: RDD[_]) { + deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) + splits_ = newRDD.splits + rdds = null + } } diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 573acf8893..5b5f72ddeb 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -1,8 +1,7 @@ package spark.rdd -import spark.NarrowDependency -import spark.RDD -import spark.Split +import spark._ +import java.lang.ref.WeakReference private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split @@ -15,13 +14,12 @@ private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) exten * or to avoid having a large number of small tasks when processing a directory with many files. */ class CoalescedRDD[T: ClassManifest]( - @transient prev: RDD[T], // TODO: Make this a weak reference + var prev: RDD[T], maxPartitions: Int) extends RDD[T](prev.context, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs - // TODO: make this null when finishing checkpoint - @transient val splits_ : Array[Split] = { - val prevSplits = firstParent[T].splits + @transient var splits_ : Array[Split] = { + val prevSplits = prev.splits if (prevSplits.length < maxPartitions) { prevSplits.zipWithIndex.map{ case (s, idx) => new CoalescedRDDSplit(idx, Array(s)) } } else { @@ -33,7 +31,6 @@ class CoalescedRDD[T: ClassManifest]( } } - // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ override def compute(split: Split): Iterator[T] = { @@ -42,13 +39,18 @@ class CoalescedRDD[T: ClassManifest]( } } - // TODO: make this null when finishing checkpoint - var deps = List( - new NarrowDependency(firstParent) { + var deps_ : List[Dependency[_]] = List( + new NarrowDependency(prev) { def getParents(id: Int): Seq[Int] = splits(id).asInstanceOf[CoalescedRDDSplit].parents.map(_.index) } ) - override def dependencies = deps + override def dependencies = deps_ + + override protected def changeDependencies(newRDD: RDD[_]) { + deps_ = List(new OneToOneDependency(newRDD)) + splits_ = newRDD.splits + prev = null + } } diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala index cc2a3acd3a..1370cf6faf 100644 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala @@ -7,7 +7,7 @@ import java.lang.ref.WeakReference private[spark] class FilteredRDD[T: ClassManifest]( - @transient prev: WeakReference[RDD[T]], + prev: WeakReference[RDD[T]], f: T => Boolean) extends RDD[T](prev.get) { diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala index 34bd784c13..6b2cc67568 100644 --- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala +++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala @@ -7,7 +7,7 @@ import java.lang.ref.WeakReference private[spark] class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( - @transient prev: WeakReference[RDD[T]], + prev: WeakReference[RDD[T]], f: T => TraversableOnce[U]) extends RDD[U](prev.get) { diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala index 9321e89dcd..0f0b6ab0ff 100644 --- a/core/src/main/scala/spark/rdd/GlommedRDD.scala +++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala @@ -6,7 +6,7 @@ import spark.Split import java.lang.ref.WeakReference private[spark] -class GlommedRDD[T: ClassManifest](@transient prev: WeakReference[RDD[T]]) +class GlommedRDD[T: ClassManifest](prev: WeakReference[RDD[T]]) extends RDD[Array[T]](prev.get) { override def splits = firstParent[T].splits override def compute(split: Split) = Array(firstParent[T].iterator(split).toArray).iterator diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index a12531ea89..19ed56d9c0 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -115,4 +115,6 @@ class HadoopRDD[K, V]( val hadoopSplit = split.asInstanceOf[HadoopSplit] hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") } + + override val isCheckpointable = false } diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala index bad872c430..b04f56cfcc 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala @@ -7,7 +7,7 @@ import java.lang.ref.WeakReference private[spark] class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( - @transient prev: WeakReference[RDD[T]], + prev: WeakReference[RDD[T]], f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false) extends RDD[U](prev.get) { diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala index d7b238b05d..7a4b6ffb03 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala @@ -12,7 +12,7 @@ import java.lang.ref.WeakReference */ private[spark] class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest]( - @transient prev: WeakReference[RDD[T]], + prev: WeakReference[RDD[T]], f: (Int, Iterator[T]) => Iterator[U]) extends RDD[U](prev.get) { diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala index 126c6f332b..8fa1872e0a 100644 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/spark/rdd/MappedRDD.scala @@ -7,7 +7,7 @@ import java.lang.ref.WeakReference private[spark] class MappedRDD[U: ClassManifest, T: ClassManifest]( - @transient prev: WeakReference[RDD[T]], + prev: WeakReference[RDD[T]], f: T => U) extends RDD[U](prev.get) { diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index c12df5839e..2875abb2db 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -93,4 +93,6 @@ class NewHadoopRDD[K, V]( val theSplit = split.asInstanceOf[NewHadoopSplit] theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") } + + override val isCheckpointable = false } diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index d54579d6d1..d9293a9d1a 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -12,6 +12,7 @@ import spark.OneToOneDependency import spark.RDD import spark.SparkEnv import spark.Split +import java.lang.ref.WeakReference /** @@ -19,16 +20,16 @@ import spark.Split * (printing them one per line) and returns the output as a collection of strings. */ class PipedRDD[T: ClassManifest]( - @transient prev: RDD[T], + prev: WeakReference[RDD[T]], command: Seq[String], envVars: Map[String, String]) - extends RDD[String](prev) { + extends RDD[String](prev.get) { - def this(@transient prev: RDD[T], command: Seq[String]) = this(prev, command, Map()) + def this(prev: WeakReference[RDD[T]], command: Seq[String]) = this(prev, command, Map()) // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(@transient prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command)) + def this(prev: WeakReference[RDD[T]], command: String) = this(prev, PipedRDD.tokenize(command)) override def splits = firstParent[T].splits diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index 00b521b130..f273f257f8 100644 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -15,7 +15,7 @@ class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Seriali } class SampledRDD[T: ClassManifest]( - @transient prev: WeakReference[RDD[T]], + prev: WeakReference[RDD[T]], withReplacement: Boolean, frac: Double, seed: Int) diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 62867dab4f..b7d843c26d 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -31,11 +31,6 @@ class ShuffledRDD[K, V]( override def splits = splits_ - override def preferredLocations(split: Split) = Nil - - //val dep = new ShuffleDependency(parent, part) - //override val dependencies = List(dep) - override def compute(split: Split): Iterator[(K, V)] = { val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index) diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index 0a61a2d1f5..643a174160 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -2,11 +2,7 @@ package spark.rdd import scala.collection.mutable.ArrayBuffer -import spark.Dependency -import spark.RangeDependency -import spark.RDD -import spark.SparkContext -import spark.Split +import spark._ import java.lang.ref.WeakReference private[spark] class UnionSplit[T: ClassManifest]( @@ -23,12 +19,11 @@ private[spark] class UnionSplit[T: ClassManifest]( class UnionRDD[T: ClassManifest]( sc: SparkContext, - @transient rdds: Seq[RDD[T]]) // TODO: Make this a weak reference + @transient var rdds: Seq[RDD[T]]) extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs - // TODO: make this null when finishing checkpoint @transient - val splits_ : Array[Split] = { + var splits_ : Array[Split] = { val array = new Array[Split](rdds.map(_.splits.size).sum) var pos = 0 for (rdd <- rdds; split <- rdd.splits) { @@ -38,11 +33,9 @@ class UnionRDD[T: ClassManifest]( array } - // TODO: make this return checkpoint Hadoop RDDs split when checkpointed override def splits = splits_ - // TODO: make this null when finishing checkpoint - @transient var deps = { + @transient var deps_ = { val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for (rdd <- rdds) { @@ -52,10 +45,21 @@ class UnionRDD[T: ClassManifest]( deps.toList } - override def dependencies = deps + override def dependencies = deps_ override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator() - override def preferredLocations(s: Split): Seq[String] = - s.asInstanceOf[UnionSplit[T]].preferredLocations() + override def preferredLocations(s: Split): Seq[String] = { + if (isCheckpointed) { + checkpointRDD.preferredLocations(s) + } else { + s.asInstanceOf[UnionSplit[T]].preferredLocations() + } + } + + override protected def changeDependencies(newRDD: RDD[_]) { + deps_ = List(new OneToOneDependency(newRDD)) + splits_ = newRDD.splits + rdds = null + } } diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala new file mode 100644 index 0000000000..0e5ca7dc21 --- /dev/null +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -0,0 +1,116 @@ +package spark + +import org.scalatest.{BeforeAndAfter, FunSuite} +import java.io.File +import rdd.{BlockRDD, CoalescedRDD, MapPartitionsWithSplitRDD} +import spark.SparkContext._ +import storage.StorageLevel + +class CheckpointSuite extends FunSuite with BeforeAndAfter { + + var sc: SparkContext = _ + var checkpointDir: File = _ + + before { + checkpointDir = File.createTempFile("temp", "") + checkpointDir.delete() + + sc = new SparkContext("local", "test") + sc.setCheckpointDir(checkpointDir.toString) + } + + after { + if (sc != null) { + sc.stop() + sc = null + } + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") + + if (checkpointDir != null) { + checkpointDir.delete() + } + } + + test("ParallelCollection") { + val parCollection = sc.makeRDD(1 to 4) + parCollection.checkpoint() + assert(parCollection.dependencies === Nil) + val result = parCollection.collect() + sleep(parCollection) // slightly extra time as loading classes for the first can take some time + assert(sc.objectFile[Int](parCollection.checkpointFile).collect() === result) + assert(parCollection.dependencies != Nil) + assert(parCollection.collect() === result) + } + + test("BlockRDD") { + val blockId = "id" + val blockManager = SparkEnv.get.blockManager + blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY) + val blockRDD = new BlockRDD[String](sc, Array(blockId)) + blockRDD.checkpoint() + val result = blockRDD.collect() + sleep(blockRDD) + assert(sc.objectFile[String](blockRDD.checkpointFile).collect() === result) + assert(blockRDD.dependencies != Nil) + assert(blockRDD.collect() === result) + } + + test("RDDs with one-to-one dependencies") { + testCheckpointing(_.map(x => x.toString)) + testCheckpointing(_.flatMap(x => 1 to x)) + testCheckpointing(_.filter(_ % 2 == 0)) + testCheckpointing(_.sample(false, 0.5, 0)) + testCheckpointing(_.glom()) + testCheckpointing(_.mapPartitions(_.map(_.toString))) + testCheckpointing(r => new MapPartitionsWithSplitRDD(r, + (i: Int, iter: Iterator[Int]) => iter.map(_.toString) )) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString), 1000) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x), 1000) + testCheckpointing(_.pipe(Seq("cat"))) + } + + test("ShuffledRDD") { + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _)) + } + + test("UnionRDD") { + testCheckpointing(_.union(sc.makeRDD(5 to 6, 4))) + } + + test("CartesianRDD") { + testCheckpointing(_.cartesian(sc.makeRDD(5 to 6, 4)), 1000) + } + + test("CoalescedRDD") { + testCheckpointing(new CoalescedRDD(_, 2)) + } + + test("CoGroupedRDD") { + val rdd2 = sc.makeRDD(5 to 6, 4).map(x => (x % 2, 1)) + testCheckpointing(rdd1 => rdd1.map(x => (x % 2, 1)).cogroup(rdd2)) + testCheckpointing(rdd1 => rdd1.map(x => (x % 2, x)).join(rdd2)) + } + + def testCheckpointing[U: ClassManifest](op: (RDD[Int]) => RDD[U], sleepTime: Long = 500) { + val parCollection = sc.makeRDD(1 to 4, 4) + val operatedRDD = op(parCollection) + operatedRDD.checkpoint() + val parentRDD = operatedRDD.dependencies.head.rdd + val result = operatedRDD.collect() + sleep(operatedRDD) + //println(parentRDD + ", " + operatedRDD.dependencies.head.rdd ) + assert(sc.objectFile[U](operatedRDD.checkpointFile).collect() === result) + assert(operatedRDD.dependencies.head.rdd != parentRDD) + assert(operatedRDD.collect() === result) + } + + def sleep(rdd: RDD[_]) { + val startTime = System.currentTimeMillis() + val maxWaitTime = 5000 + while(rdd.isCheckpointed == false && System.currentTimeMillis() < startTime + maxWaitTime) { + Thread.sleep(50) + } + assert(rdd.isCheckpointed === true, "Waiting for checkpoint to complete took more than " + maxWaitTime + " ms") + } +} diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 37a0ff0947..8ac7c8451a 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -19,7 +19,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.master.port") } - + test("basic operations") { sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) @@ -70,10 +70,23 @@ class RDDSuite extends FunSuite with BeforeAndAfter { assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5))) } - test("checkpointing") { + test("basic checkpointing") { + import java.io.File + val checkpointDir = File.createTempFile("temp", "") + checkpointDir.delete() + sc = new SparkContext("local", "test") - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).flatMap(x => 1 to x).checkpoint() - assert(rdd.collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) + sc.setCheckpointDir(checkpointDir.toString) + val parCollection = sc.makeRDD(1 to 4) + val flatMappedRDD = parCollection.flatMap(x => 1 to x) + flatMappedRDD.checkpoint() + assert(flatMappedRDD.dependencies.head.rdd == parCollection) + val result = flatMappedRDD.collect() + Thread.sleep(1000) + assert(flatMappedRDD.dependencies.head.rdd != parCollection) + assert(flatMappedRDD.collect() === result) + + checkpointDir.deleteOnExit() } test("basic caching") { @@ -94,8 +107,8 @@ class RDDSuite extends FunSuite with BeforeAndAfter { List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10))) // Check that the narrow dependency is also specified correctly - assert(coalesced1.dependencies.head.getParents(0).toList === List(0, 1, 2, 3, 4)) - assert(coalesced1.dependencies.head.getParents(1).toList === List(5, 6, 7, 8, 9)) + assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === List(0, 1, 2, 3, 4)) + assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === List(5, 6, 7, 8, 9)) val coalesced2 = new CoalescedRDD(data, 3) assert(coalesced2.collect().toList === (1 to 10).toList) -- cgit v1.2.3 From 34e569f40e184a6a4f21e9d79b0e8979d8f9541f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 31 Oct 2012 00:56:40 -0700 Subject: Added 'synchronized' to RDD serialization to ensure checkpoint-related changes are reflected atomically in the task closure. Added to tests to ensure that jobs running on an RDD on which checkpointing is in progress does hurt the result of the job. --- core/src/main/scala/spark/RDD.scala | 18 ++++++- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 7 ++- core/src/test/scala/spark/CheckpointSuite.scala | 71 ++++++++++++++++++++++++- 3 files changed, 92 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index e272a0ede9..7b59a6f09e 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -1,8 +1,7 @@ package spark -import java.io.EOFException +import java.io.{ObjectOutputStream, IOException, EOFException, ObjectInputStream} import java.net.URL -import java.io.ObjectInputStream import java.util.concurrent.atomic.AtomicLong import java.util.Random import java.util.Date @@ -589,4 +588,19 @@ abstract class RDD[T: ClassManifest]( private[spark] def collectPartitions(): Array[Array[T]] = { sc.runJob(this, (iter: Iterator[T]) => iter.toArray) } + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + synchronized { + oos.defaultWriteObject() + } + } + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + synchronized { + ois.defaultReadObject() + } + } + } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index b7d843c26d..31774585f4 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -27,7 +27,7 @@ class ShuffledRDD[K, V]( override val partitioner = Some(part) @transient - val splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) + var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) override def splits = splits_ @@ -35,4 +35,9 @@ class ShuffledRDD[K, V]( val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index) } + + override def changeDependencies(newRDD: RDD[_]) { + dependencies_ = Nil + splits_ = newRDD.splits + } } diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 0e5ca7dc21..57dc43ddac 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -5,8 +5,10 @@ import java.io.File import rdd.{BlockRDD, CoalescedRDD, MapPartitionsWithSplitRDD} import spark.SparkContext._ import storage.StorageLevel +import java.util.concurrent.Semaphore -class CheckpointSuite extends FunSuite with BeforeAndAfter { +class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { + initLogging() var sc: SparkContext = _ var checkpointDir: File = _ @@ -92,6 +94,35 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter { testCheckpointing(rdd1 => rdd1.map(x => (x % 2, x)).join(rdd2)) } + /** + * This test forces two ResultTasks of the same job to be launched before and after + * the checkpointing of job's RDD is completed. + */ + test("Threading - ResultTasks") { + val op1 = (parCollection: RDD[Int]) => { + parCollection.map(x => { println("1st map running on " + x); Thread.sleep(500); (x % 2, x) }) + } + val op2 = (firstRDD: RDD[(Int, Int)]) => { + firstRDD.map(x => { println("2nd map running on " + x); Thread.sleep(500); x }) + } + testThreading(op1, op2) + } + + /** + * This test forces two ShuffleMapTasks of the same job to be launched before and after + * the checkpointing of job's RDD is completed. + */ + test("Threading - ShuffleMapTasks") { + val op1 = (parCollection: RDD[Int]) => { + parCollection.map(x => { println("1st map running on " + x); Thread.sleep(500); (x % 2, x) }) + } + val op2 = (firstRDD: RDD[(Int, Int)]) => { + firstRDD.groupByKey(2).map(x => { println("2nd map running on " + x); Thread.sleep(500); x }) + } + testThreading(op1, op2) + } + + def testCheckpointing[U: ClassManifest](op: (RDD[Int]) => RDD[U], sleepTime: Long = 500) { val parCollection = sc.makeRDD(1 to 4, 4) val operatedRDD = op(parCollection) @@ -105,6 +136,44 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter { assert(operatedRDD.collect() === result) } + def testThreading[U: ClassManifest, V: ClassManifest](op1: (RDD[Int]) => RDD[U], op2: (RDD[U]) => RDD[V]) { + + val parCollection = sc.makeRDD(1 to 2, 2) + + // This is the RDD that is to be checkpointed + val firstRDD = op1(parCollection) + val parentRDD = firstRDD.dependencies.head.rdd + firstRDD.checkpoint() + + // This the RDD that uses firstRDD. This is designed to launch a + // ShuffleMapTask that uses firstRDD. + val secondRDD = op2(firstRDD) + + // Starting first job, to initiate the checkpointing + logInfo("\nLaunching 1st job to initiate checkpointing\n") + firstRDD.collect() + + // Checkpointing has started but not completed yet + Thread.sleep(100) + assert(firstRDD.dependencies.head.rdd === parentRDD) + + // Starting second job; first task of this job will be + // launched _before_ firstRDD is marked as checkpointed + // and the second task will be launched _after_ firstRDD + // is marked as checkpointed + logInfo("\nLaunching 2nd job that is designed to launch tasks " + + "before and after checkpointing is complete\n") + val result = secondRDD.collect() + + // Check whether firstRDD has been successfully checkpointed + assert(firstRDD.dependencies.head.rdd != parentRDD) + + logInfo("\nRecomputing 2nd job to verify the results of the previous computation\n") + // Check whether the result in the previous job was correct or not + val correctResult = secondRDD.collect() + assert(result === correctResult) + } + def sleep(rdd: RDD[_]) { val startTime = System.currentTimeMillis() val maxWaitTime = 5000 -- cgit v1.2.3 From 3fb5c9ee24302edf02df130bd0dfd0463cf6c0a4 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 2 Nov 2012 12:12:25 -0700 Subject: Fixed serialization bug in countByWindow, added countByKey and countByKeyAndWindow, and added testcases for them. --- .../src/main/scala/spark/streaming/DStream.scala | 4 +- .../spark/streaming/PairDStreamFunctions.scala | 23 ++- .../scala/spark/streaming/InputStreamsSuite.scala | 18 ++ .../test/scala/spark/streaming/TestSuiteBase.scala | 8 +- .../spark/streaming/WindowOperationsSuite.scala | 181 ++++++++++++++++----- 5 files changed, 186 insertions(+), 48 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 175ebf104f..a4921bb1a2 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -338,9 +338,7 @@ extends Serializable with Logging { } def countByWindow(windowTime: Time, slideTime: Time): DStream[Int] = { - def add(v1: Int, v2: Int) = (v1 + v2) - def subtract(v1: Int, v2: Int) = (v1 - v2) - this.map(_ => 1).reduceByWindow(add _, subtract _, windowTime, slideTime) + this.map(_ => 1).reduceByWindow(_ + _, _ - _, windowTime, slideTime) } def union(that: DStream[T]): DStream[T] = new UnionDStream[T](Array(this, that)) diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index f88247708b..e09d27d34f 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -53,14 +53,18 @@ extends Serializable { combineByKey((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, partitioner) } - private def combineByKey[C: ClassManifest]( + def combineByKey[C: ClassManifest]( createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiner: (C, C) => C, - partitioner: Partitioner) : ShuffledDStream[K, V, C] = { + partitioner: Partitioner) : DStream[(K, C)] = { new ShuffledDStream[K, V, C](self, createCombiner, mergeValue, mergeCombiner, partitioner) } + def countByKey(numPartitions: Int = self.ssc.sc.defaultParallelism): DStream[(K, Long)] = { + self.map(x => (x._1, 1L)).reduceByKey((x: Long, y: Long) => x + y, numPartitions) + } + def groupByKeyAndWindow(windowTime: Time, slideTime: Time): DStream[(K, Seq[V])] = { groupByKeyAndWindow(windowTime, slideTime, defaultPartitioner()) } @@ -157,6 +161,21 @@ extends Serializable { self, cleanedReduceFunc, cleanedInvReduceFunc, windowTime, slideTime, partitioner) } + def countByKeyAndWindow( + windowTime: Time, + slideTime: Time, + numPartitions: Int = self.ssc.sc.defaultParallelism + ): DStream[(K, Long)] = { + + self.map(x => (x._1, 1L)).reduceByKeyAndWindow( + (x: Long, y: Long) => x + y, + (x: Long, y: Long) => x - y, + windowTime, + slideTime, + numPartitions + ) + } + // TODO: // // diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index 6f6b18a790..c17254b809 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -51,6 +51,15 @@ class InputStreamsSuite extends TestSuiteBase { ssc.stop() // Verify whether data received by Spark Streaming was as expected + logInfo("--------------------------------") + logInfo("output.size = " + outputBuffer.size) + logInfo("output") + outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("expected output.size = " + expectedOutput.size) + logInfo("expected output") + expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("--------------------------------") + assert(outputBuffer.size === expectedOutput.size) for (i <- 0 until outputBuffer.size) { assert(outputBuffer(i).size === 1) @@ -101,6 +110,15 @@ class InputStreamsSuite extends TestSuiteBase { ssc.stop() // Verify whether data received by Spark Streaming was as expected + logInfo("--------------------------------") + logInfo("output.size = " + outputBuffer.size) + logInfo("output") + outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("expected output.size = " + expectedOutput.size) + logInfo("expected output") + expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("--------------------------------") + assert(outputBuffer.size === expectedOutput.size) for (i <- 0 until outputBuffer.size) { assert(outputBuffer(i).size === 1) diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index c1b7772e7b..c9bc454f91 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -16,13 +16,14 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[ def compute(validTime: Time): Option[RDD[T]] = { logInfo("Computing RDD for time " + validTime) - val rdd = if (currentIndex < input.size) { - ssc.sc.makeRDD(input(currentIndex), numPartitions) + val index = ((validTime - zeroTime) / slideTime - 1).toInt + val rdd = if (index < input.size) { + ssc.sc.makeRDD(input(index), numPartitions) } else { ssc.sc.makeRDD(Seq[T](), numPartitions) } logInfo("Created RDD " + rdd.id) - currentIndex += 1 + //currentIndex += 1 Some(rdd) } } @@ -96,7 +97,6 @@ trait TestSuiteBase extends FunSuite with Logging { ssc } - def runStreams[V: ClassManifest]( ssc: StreamingContext, numBatches: Int, diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index 90d67844bb..d7d8d5bd36 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -1,6 +1,7 @@ package spark.streaming import spark.streaming.StreamingContext._ +import collection.mutable.ArrayBuffer class WindowOperationsSuite extends TestSuiteBase { @@ -8,6 +9,8 @@ class WindowOperationsSuite extends TestSuiteBase { override def maxWaitTimeMillis() = 20000 + override def batchDuration() = Seconds(1) + val largerSlideInput = Seq( Seq(("a", 1)), Seq(("a", 2)), // 1st window from here @@ -19,7 +22,7 @@ class WindowOperationsSuite extends TestSuiteBase { Seq() // 4th window from here ) - val largerSlideOutput = Seq( + val largerSlideReduceOutput = Seq( Seq(("a", 3)), Seq(("a", 10)), Seq(("a", 18)), @@ -42,7 +45,23 @@ class WindowOperationsSuite extends TestSuiteBase { Seq() ) - val bigOutput = Seq( + val bigGroupByOutput = Seq( + Seq(("a", Seq(1))), + Seq(("a", Seq(1, 1)), ("b", Seq(1))), + Seq(("a", Seq(1, 1)), ("b", Seq(1, 1)), ("c", Seq(1))), + Seq(("a", Seq(1, 1)), ("b", Seq(1, 1)), ("c", Seq(1))), + Seq(("a", Seq(1, 1)), ("b", Seq(1))), + Seq(("a", Seq(1))), + Seq(("a", Seq(1))), + Seq(("a", Seq(1, 1)), ("b", Seq(1))), + Seq(("a", Seq(1, 1)), ("b", Seq(1, 1)), ("c", Seq(1))), + Seq(("a", Seq(1, 1)), ("b", Seq(1, 1)), ("c", Seq(1))), + Seq(("a", Seq(1, 1)), ("b", Seq(1))), + Seq(("a", Seq(1))) + ) + + + val bigReduceOutput = Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)), Seq(("a", 2), ("b", 2), ("c", 1)), @@ -59,13 +78,14 @@ class WindowOperationsSuite extends TestSuiteBase { /* The output of the reduceByKeyAndWindow with inverse reduce function is - difference from the naive reduceByKeyAndWindow. Even if the count of a + different from the naive reduceByKeyAndWindow. Even if the count of a particular key is 0, the key does not get eliminated from the RDDs of ReducedWindowedDStream. This causes the number of keys in these RDDs to increase forever. A more generalized version that allows elimination of keys should be considered. */ - val bigOutputInv = Seq( + + val bigReduceInvOutput = Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)), Seq(("a", 2), ("b", 2), ("c", 1)), @@ -80,38 +100,37 @@ class WindowOperationsSuite extends TestSuiteBase { Seq(("a", 1), ("b", 0), ("c", 0)) ) - def testReduceByKeyAndWindow( - name: String, - input: Seq[Seq[(String, Int)]], - expectedOutput: Seq[Seq[(String, Int)]], - windowTime: Time = batchDuration * 2, - slideTime: Time = batchDuration - ) { - test("reduceByKeyAndWindow - " + name) { - val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt - val operation = (s: DStream[(String, Int)]) => { - s.reduceByKeyAndWindow(_ + _, windowTime, slideTime).persist() - } - testOperation(input, operation, expectedOutput, numBatches, true) - } - } + // Testing window operation - def testReduceByKeyAndWindowInv( - name: String, - input: Seq[Seq[(String, Int)]], - expectedOutput: Seq[Seq[(String, Int)]], - windowTime: Time = batchDuration * 2, - slideTime: Time = batchDuration - ) { - test("reduceByKeyAndWindowInv - " + name) { - val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt - val operation = (s: DStream[(String, Int)]) => { - s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist() - } - testOperation(input, operation, expectedOutput, numBatches, true) - } - } + testWindow( + "basic window", + Seq( Seq(0), Seq(1), Seq(2), Seq(3), Seq(4), Seq(5)), + Seq( Seq(0), Seq(0, 1), Seq(1, 2), Seq(2, 3), Seq(3, 4), Seq(4, 5)) + ) + testWindow( + "tumbling window", + Seq( Seq(0), Seq(1), Seq(2), Seq(3), Seq(4), Seq(5)), + Seq( Seq(0, 1), Seq(2, 3), Seq(4, 5)), + Seconds(2), + Seconds(2) + ) + + testWindow( + "larger window", + Seq( Seq(0), Seq(1), Seq(2), Seq(3), Seq(4), Seq(5)), + Seq( Seq(0, 1), Seq(0, 1, 2, 3), Seq(2, 3, 4, 5), Seq(4, 5)), + Seconds(4), + Seconds(2) + ) + + testWindow( + "non-overlapping window", + Seq( Seq(0), Seq(1), Seq(2), Seq(3), Seq(4), Seq(5)), + Seq( Seq(1, 2), Seq(4, 5)), + Seconds(2), + Seconds(3) + ) // Testing naive reduceByKeyAndWindow (without invertible function) @@ -142,13 +161,12 @@ class WindowOperationsSuite extends TestSuiteBase { testReduceByKeyAndWindow( "larger slide time", largerSlideInput, - largerSlideOutput, + largerSlideReduceOutput, Seconds(4), Seconds(2) ) - testReduceByKeyAndWindow("big test", bigInput, bigOutput) - + testReduceByKeyAndWindow("big test", bigInput, bigReduceOutput) // Testing reduceByKeyAndWindow (with invertible reduce function) @@ -179,10 +197,95 @@ class WindowOperationsSuite extends TestSuiteBase { testReduceByKeyAndWindowInv( "larger slide time", largerSlideInput, - largerSlideOutput, + largerSlideReduceOutput, Seconds(4), Seconds(2) ) - testReduceByKeyAndWindowInv("big test", bigInput, bigOutputInv) + testReduceByKeyAndWindowInv("big test", bigInput, bigReduceInvOutput) + + test("groupByKeyAndWindow") { + val input = bigInput + val expectedOutput = bigGroupByOutput.map(_.map(x => (x._1, x._2.toSet))) + val windowTime = Seconds(2) + val slideTime = Seconds(1) + val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val operation = (s: DStream[(String, Int)]) => { + s.groupByKeyAndWindow(windowTime, slideTime) + .map(x => (x._1, x._2.toSet)) + .persist() + } + testOperation(input, operation, expectedOutput, numBatches, true) + } + + test("countByWindow") { + val input = Seq(Seq(1), Seq(1), Seq(1, 2), Seq(0), Seq(), Seq() ) + val expectedOutput = Seq( Seq(1), Seq(2), Seq(3), Seq(3), Seq(1), Seq(0)) + val windowTime = Seconds(2) + val slideTime = Seconds(1) + val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val operation = (s: DStream[Int]) => s.countByWindow(windowTime, slideTime) + testOperation(input, operation, expectedOutput, numBatches, true) + } + + test("countByKeyAndWindow") { + val input = Seq(Seq(("a", 1)), Seq(("b", 1), ("b", 2)), Seq(("a", 10), ("b", 20))) + val expectedOutput = Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 2)), Seq(("a", 1), ("b", 3))) + val windowTime = Seconds(2) + val slideTime = Seconds(1) + val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val operation = (s: DStream[(String, Int)]) => { + s.countByKeyAndWindow(windowTime, slideTime).map(x => (x._1, x._2.toInt)) + } + testOperation(input, operation, expectedOutput, numBatches, true) + } + + + // Helper functions + + def testWindow( + name: String, + input: Seq[Seq[Int]], + expectedOutput: Seq[Seq[Int]], + windowTime: Time = Seconds(2), + slideTime: Time = Seconds(1) + ) { + test("window - " + name) { + val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val operation = (s: DStream[Int]) => s.window(windowTime, slideTime) + testOperation(input, operation, expectedOutput, numBatches, true) + } + } + + def testReduceByKeyAndWindow( + name: String, + input: Seq[Seq[(String, Int)]], + expectedOutput: Seq[Seq[(String, Int)]], + windowTime: Time = Seconds(2), + slideTime: Time = Seconds(1) + ) { + test("reduceByKeyAndWindow - " + name) { + val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val operation = (s: DStream[(String, Int)]) => { + s.reduceByKeyAndWindow(_ + _, windowTime, slideTime).persist() + } + testOperation(input, operation, expectedOutput, numBatches, true) + } + } + + def testReduceByKeyAndWindowInv( + name: String, + input: Seq[Seq[(String, Int)]], + expectedOutput: Seq[Seq[(String, Int)]], + windowTime: Time = Seconds(2), + slideTime: Time = Seconds(1) + ) { + test("reduceByKeyAndWindowInv - " + name) { + val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val operation = (s: DStream[(String, Int)]) => { + s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist() + } + testOperation(input, operation, expectedOutput, numBatches, true) + } + } } -- cgit v1.2.3 From d1542387891018914fdd6b647f17f0b05acdd40e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 4 Nov 2012 12:12:06 -0800 Subject: Made checkpointing of dstream graph to work with checkpointing of RDDs. For streams requiring checkpointing of its RDD, the default checkpoint interval is set to 10 seconds. --- core/src/main/scala/spark/RDD.scala | 32 +++-- core/src/main/scala/spark/SparkContext.scala | 13 +- .../main/scala/spark/streaming/Checkpoint.scala | 83 ++++++----- .../src/main/scala/spark/streaming/DStream.scala | 156 ++++++++++++++++----- .../main/scala/spark/streaming/DStreamGraph.scala | 7 +- .../spark/streaming/ReducedWindowedDStream.scala | 36 +++-- .../main/scala/spark/streaming/StateDStream.scala | 45 +----- .../scala/spark/streaming/StreamingContext.scala | 38 +++-- .../src/main/scala/spark/streaming/Time.scala | 4 + .../examples/FileStreamWithCheckpoint.scala | 10 +- .../streaming/examples/TopKWordCountRaw.scala | 5 +- .../spark/streaming/examples/WordCount2.scala | 7 +- .../spark/streaming/examples/WordCountRaw.scala | 6 +- .../scala/spark/streaming/examples/WordMax2.scala | 10 +- .../scala/spark/streaming/CheckpointSuite.scala | 77 +++++++--- .../test/scala/spark/streaming/TestSuiteBase.scala | 37 +++-- .../spark/streaming/WindowOperationsSuite.scala | 4 +- 17 files changed, 367 insertions(+), 203 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 7b59a6f09e..63048d5df0 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -119,22 +119,23 @@ abstract class RDD[T: ClassManifest]( private var storageLevel: StorageLevel = StorageLevel.NONE /** Returns the first parent RDD */ - private[spark] def firstParent[U: ClassManifest] = { + protected[spark] def firstParent[U: ClassManifest] = { dependencies.head.rdd.asInstanceOf[RDD[U]] } /** Returns the `i` th parent RDD */ - private[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]] + protected[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]] // Variables relating to checkpointing - val isCheckpointable = true // override to set this to false to avoid checkpointing an RDD - var shouldCheckpoint = false // set to true when an RDD is marked for checkpointing - var isCheckpointInProgress = false // set to true when checkpointing is in progress - var isCheckpointed = false // set to true after checkpointing is completed + protected val isCheckpointable = true // override to set this to false to avoid checkpointing an RDD - var checkpointFile: String = null // set to the checkpoint file after checkpointing is completed - var checkpointRDD: RDD[T] = null // set to the HadoopRDD of the checkpoint file - var checkpointRDDSplits: Seq[Split] = null // set to the splits of the Hadoop RDD + protected var shouldCheckpoint = false // set to true when an RDD is marked for checkpointing + protected var isCheckpointInProgress = false // set to true when checkpointing is in progress + protected[spark] var isCheckpointed = false // set to true after checkpointing is completed + + protected[spark] var checkpointFile: String = null // set to the checkpoint file after checkpointing is completed + protected var checkpointRDD: RDD[T] = null // set to the HadoopRDD of the checkpoint file + protected var checkpointRDDSplits: Seq[Split] = null // set to the splits of the Hadoop RDD // Methods available on all RDDs: @@ -176,6 +177,9 @@ abstract class RDD[T: ClassManifest]( if (isCheckpointed || shouldCheckpoint || isCheckpointInProgress) { // do nothing } else if (isCheckpointable) { + if (sc.checkpointDir == null) { + throw new Exception("Checkpoint directory has not been set in the SparkContext.") + } shouldCheckpoint = true } else { throw new Exception(this + " cannot be checkpointed") @@ -183,6 +187,16 @@ abstract class RDD[T: ClassManifest]( } } + def getCheckpointData(): Any = { + synchronized { + if (isCheckpointed) { + checkpointFile + } else { + null + } + } + } + /** * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler after a job * using this RDD has completed (therefore the RDD has been materialized and diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 79ceab5f4f..d7326971a9 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -584,14 +584,15 @@ class SparkContext( * overwriting existing files may be overwritten). The directory will be deleted on exit * if indicated. */ - def setCheckpointDir(dir: String, deleteOnExit: Boolean = false) { + def setCheckpointDir(dir: String, useExisting: Boolean = false) { val path = new Path(dir) val fs = path.getFileSystem(new Configuration()) - if (fs.exists(path)) { - throw new Exception("Checkpoint directory '" + path + "' already exists.") - } else { - fs.mkdirs(path) - if (deleteOnExit) fs.deleteOnExit(path) + if (!useExisting) { + if (fs.exists(path)) { + throw new Exception("Checkpoint directory '" + path + "' already exists.") + } else { + fs.mkdirs(path) + } } checkpointDir = dir } diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index 83a43d15cb..cf04c7031e 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -1,6 +1,6 @@ package spark.streaming -import spark.Utils +import spark.{Logging, Utils} import org.apache.hadoop.fs.{FileUtil, Path} import org.apache.hadoop.conf.Configuration @@ -8,13 +8,14 @@ import org.apache.hadoop.conf.Configuration import java.io.{InputStream, ObjectStreamClass, ObjectInputStream, ObjectOutputStream} -class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) extends Serializable { +class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) + extends Logging with Serializable { val master = ssc.sc.master val framework = ssc.sc.jobName val sparkHome = ssc.sc.sparkHome val jars = ssc.sc.jars val graph = ssc.graph - val checkpointFile = ssc.checkpointFile + val checkpointDir = ssc.checkpointDir val checkpointInterval = ssc.checkpointInterval validate() @@ -24,22 +25,25 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) ext assert(framework != null, "Checkpoint.framework is null") assert(graph != null, "Checkpoint.graph is null") assert(checkpointTime != null, "Checkpoint.checkpointTime is null") + logInfo("Checkpoint for time " + checkpointTime + " validated") } - def saveToFile(file: String = checkpointFile) { - val path = new Path(file) + def save(path: String) { + val file = new Path(path, "graph") val conf = new Configuration() - val fs = path.getFileSystem(conf) - if (fs.exists(path)) { - val bkPath = new Path(path.getParent, path.getName + ".bk") - FileUtil.copy(fs, path, fs, bkPath, true, true, conf) - //logInfo("Moved existing checkpoint file to " + bkPath) + val fs = file.getFileSystem(conf) + logDebug("Saved checkpoint for time " + checkpointTime + " to file '" + file + "'") + if (fs.exists(file)) { + val bkFile = new Path(file.getParent, file.getName + ".bk") + FileUtil.copy(fs, file, fs, bkFile, true, true, conf) + logDebug("Moved existing checkpoint file to " + bkFile) } - val fos = fs.create(path) + val fos = fs.create(file) val oos = new ObjectOutputStream(fos) oos.writeObject(this) oos.close() fs.close() + logInfo("Saved checkpoint for time " + checkpointTime + " to file '" + file + "'") } def toBytes(): Array[Byte] = { @@ -50,30 +54,41 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) ext object Checkpoint { - def loadFromFile(file: String): Checkpoint = { - try { - val path = new Path(file) - val conf = new Configuration() - val fs = path.getFileSystem(conf) - if (!fs.exists(path)) { - throw new Exception("Checkpoint file '" + file + "' does not exist") + def load(path: String): Checkpoint = { + + val fs = new Path(path).getFileSystem(new Configuration()) + val attempts = Seq(new Path(path), new Path(path, "graph"), new Path(path, "graph.bk")) + var lastException: Exception = null + var lastExceptionFile: String = null + + attempts.foreach(file => { + if (fs.exists(file)) { + try { + val fis = fs.open(file) + // ObjectInputStream uses the last defined user-defined class loader in the stack + // to find classes, which maybe the wrong class loader. Hence, a inherited version + // of ObjectInputStream is used to explicitly use the current thread's default class + // loader to find and load classes. This is a well know Java issue and has popped up + // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) + val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader) + val cp = ois.readObject.asInstanceOf[Checkpoint] + ois.close() + fs.close() + cp.validate() + println("Checkpoint successfully loaded from file " + file) + return cp + } catch { + case e: Exception => + lastException = e + lastExceptionFile = file.toString + } } - val fis = fs.open(path) - // ObjectInputStream uses the last defined user-defined class loader in the stack - // to find classes, which maybe the wrong class loader. Hence, a inherited version - // of ObjectInputStream is used to explicitly use the current thread's default class - // loader to find and load classes. This is a well know Java issue and has popped up - // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) - val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader) - val cp = ois.readObject.asInstanceOf[Checkpoint] - ois.close() - fs.close() - cp.validate() - cp - } catch { - case e: Exception => - e.printStackTrace() - throw new Exception("Could not load checkpoint file '" + file + "'", e) + }) + + if (lastException == null) { + throw new Exception("Could not load checkpoint from path '" + path + "'") + } else { + throw new Exception("Error loading checkpoint from path '" + lastExceptionFile + "'", lastException) } } diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index a4921bb1a2..de51c5d34a 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -13,6 +13,7 @@ import scala.collection.mutable.HashMap import java.util.concurrent.ArrayBlockingQueue import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import scala.Some +import collection.mutable abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext) extends Serializable with Logging { @@ -41,53 +42,55 @@ extends Serializable with Logging { */ // RDDs generated, marked as protected[streaming] so that testsuites can access it - protected[streaming] val generatedRDDs = new HashMap[Time, RDD[T]] () + protected[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] () // Time zero for the DStream - protected var zeroTime: Time = null + protected[streaming] var zeroTime: Time = null // Duration for which the DStream will remember each RDD created - protected var rememberDuration: Time = null + protected[streaming] var rememberDuration: Time = null // Storage level of the RDDs in the stream - protected var storageLevel: StorageLevel = StorageLevel.NONE + protected[streaming] var storageLevel: StorageLevel = StorageLevel.NONE - // Checkpoint level and checkpoint interval - protected var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint - protected var checkpointInterval: Time = null + // Checkpoint details + protected[streaming] val mustCheckpoint = false + protected[streaming] var checkpointInterval: Time = null + protected[streaming] val checkpointData = new HashMap[Time, Any]() // Reference to whole DStream graph - protected var graph: DStreamGraph = null + protected[streaming] var graph: DStreamGraph = null def isInitialized = (zeroTime != null) // Duration for which the DStream requires its parent DStream to remember each RDD created def parentRememberDuration = rememberDuration - // Change this RDD's storage level - def persist( - storageLevel: StorageLevel, - checkpointLevel: StorageLevel, - checkpointInterval: Time): DStream[T] = { - if (this.storageLevel != StorageLevel.NONE && this.storageLevel != storageLevel) { - // TODO: not sure this is necessary for DStreams + // Set caching level for the RDDs created by this DStream + def persist(level: StorageLevel): DStream[T] = { + if (this.isInitialized) { throw new UnsupportedOperationException( - "Cannot change storage level of an DStream after it was already assigned a level") + "Cannot change storage level of an DStream after streaming context has started") } - this.storageLevel = storageLevel - this.checkpointLevel = checkpointLevel - this.checkpointInterval = checkpointInterval + this.storageLevel = level this } - // Set caching level for the RDDs created by this DStream - def persist(newLevel: StorageLevel): DStream[T] = persist(newLevel, StorageLevel.NONE, null) - def persist(): DStream[T] = persist(StorageLevel.MEMORY_ONLY) // Turn on the default caching level for this RDD def cache(): DStream[T] = persist() + def checkpoint(interval: Time): DStream[T] = { + if (isInitialized) { + throw new UnsupportedOperationException( + "Cannot change checkpoint interval of an DStream after streaming context has started") + } + persist() + checkpointInterval = interval + this + } + /** * This method initializes the DStream by setting the "zero" time, based on which * the validity of future times is calculated. This method also recursively initializes @@ -99,7 +102,67 @@ extends Serializable with Logging { + ", cannot initialize it again to " + time) } zeroTime = time + + // Set the checkpoint interval to be slideTime or 10 seconds, which ever is larger + if (mustCheckpoint && checkpointInterval == null) { + checkpointInterval = slideTime.max(Seconds(10)) + logInfo("Checkpoint interval automatically set to " + checkpointInterval) + } + + // Set the minimum value of the rememberDuration if not already set + var minRememberDuration = slideTime + if (checkpointInterval != null && minRememberDuration <= checkpointInterval) { + minRememberDuration = checkpointInterval + slideTime + } + if (rememberDuration == null || rememberDuration < minRememberDuration) { + rememberDuration = minRememberDuration + } + + // Initialize the dependencies dependencies.foreach(_.initialize(zeroTime)) + } + + protected[streaming] def validate() { + assert( + !mustCheckpoint || checkpointInterval != null, + "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set. " + + " Please use DStream.checkpoint() to set the interval." + ) + + assert( + checkpointInterval == null || checkpointInterval >= slideTime, + "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + + checkpointInterval + " which is lower than its slide time (" + slideTime + "). " + + "Please set it to at least " + slideTime + "." + ) + + assert( + checkpointInterval == null || checkpointInterval.isMultipleOf(slideTime), + "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + + checkpointInterval + " which not a multiple of its slide time (" + slideTime + "). " + + "Please set it to a multiple " + slideTime + "." + ) + + assert( + checkpointInterval == null || storageLevel != StorageLevel.NONE, + "" + this.getClass.getSimpleName + " has been marked for checkpointing but the storage " + + "level has not been set to enable persisting. Please use DStream.persist() to set the " + + "storage level to use memory for better checkpointing performance." + ) + + assert( + checkpointInterval == null || rememberDuration > checkpointInterval, + "The remember duration for " + this.getClass.getSimpleName + " has been set to " + + rememberDuration + " which is not more than the checkpoint interval (" + + checkpointInterval + "). Please set it to higher than " + checkpointInterval + "." + ) + + dependencies.foreach(_.validate()) + + logInfo("Slide time = " + slideTime) + logInfo("Storage level = " + storageLevel) + logInfo("Checkpoint interval = " + checkpointInterval) + logInfo("Remember duration = " + rememberDuration) logInfo("Initialized " + this) } @@ -120,17 +183,12 @@ extends Serializable with Logging { dependencies.foreach(_.setGraph(graph)) } - protected[streaming] def setRememberDuration(duration: Time = slideTime) { - if (duration == null) { - throw new Exception("Duration for remembering RDDs cannot be set to null for " + this) - } else if (rememberDuration != null && duration < rememberDuration) { - logWarning("Duration for remembering RDDs cannot be reduced from " + rememberDuration - + " to " + duration + " for " + this) - } else { + protected[streaming] def setRememberDuration(duration: Time) { + if (duration != null && duration > rememberDuration) { rememberDuration = duration - dependencies.foreach(_.setRememberDuration(parentRememberDuration)) logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this) } + dependencies.foreach(_.setRememberDuration(parentRememberDuration)) } /** This method checks whether the 'time' is valid wrt slideTime for generating RDD */ @@ -163,12 +221,13 @@ extends Serializable with Logging { if (isTimeValid(time)) { compute(time) match { case Some(newRDD) => - if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { - newRDD.persist(checkpointLevel) - logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) - } else if (storageLevel != StorageLevel.NONE) { + if (storageLevel != StorageLevel.NONE) { newRDD.persist(storageLevel) - logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) + logInfo("Persisting RDD for time " + time + " to " + storageLevel + " at time " + time) + } + if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { + newRDD.checkpoint() + logInfo("Marking RDD for time " + time + " for checkpointing at time " + time) } generatedRDDs.put(time, newRDD) Some(newRDD) @@ -199,7 +258,7 @@ extends Serializable with Logging { } } - def forgetOldRDDs(time: Time) { + protected[streaming] def forgetOldRDDs(time: Time) { val keys = generatedRDDs.keys var numForgotten = 0 keys.foreach(t => { @@ -213,12 +272,35 @@ extends Serializable with Logging { dependencies.foreach(_.forgetOldRDDs(time)) } + protected[streaming] def updateCheckpointData() { + checkpointData.clear() + generatedRDDs.foreach { + case(time, rdd) => { + logDebug("Adding checkpointed RDD for time " + time) + val data = rdd.getCheckpointData() + if (data != null) { + checkpointData += ((time, data)) + } + } + } + } + + protected[streaming] def restoreCheckpointData() { + checkpointData.foreach { + case(time, data) => { + logInfo("Restoring checkpointed RDD for time " + time) + generatedRDDs += ((time, ssc.sc.objectFile[T](data.toString))) + } + } + } + @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { logDebug(this.getClass().getSimpleName + ".writeObject used") if (graph != null) { graph.synchronized { if (graph.checkpointInProgress) { + updateCheckpointData() oos.defaultWriteObject() } else { val msg = "Object of " + this.getClass.getName + " is being serialized " + @@ -239,6 +321,8 @@ extends Serializable with Logging { private def readObject(ois: ObjectInputStream) { logDebug(this.getClass().getSimpleName + ".readObject used") ois.defaultReadObject() + generatedRDDs = new HashMap[Time, RDD[T]] () + restoreCheckpointData() } /** diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index ac44d7a2a6..f8922ec790 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -22,11 +22,8 @@ final class DStreamGraph extends Serializable with Logging { } zeroTime = time outputStreams.foreach(_.initialize(zeroTime)) - outputStreams.foreach(_.setRememberDuration()) // first set the rememberDuration to default values - if (rememberDuration != null) { - // if custom rememberDuration has been provided, set the rememberDuration - outputStreams.foreach(_.setRememberDuration(rememberDuration)) - } + outputStreams.foreach(_.setRememberDuration(rememberDuration)) + outputStreams.foreach(_.validate) inputStreams.par.foreach(_.start()) } } diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala index 1c57d5f855..6df82c0df3 100644 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -21,15 +21,19 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( partitioner: Partitioner ) extends DStream[(K,V)](parent.ssc) { - if (!_windowTime.isMultipleOf(parent.slideTime)) - throw new Exception("The window duration of ReducedWindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + assert(_windowTime.isMultipleOf(parent.slideTime), + "The window duration of ReducedWindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" + ) - if (!_slideTime.isMultipleOf(parent.slideTime)) - throw new Exception("The slide duration of ReducedWindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + assert(_slideTime.isMultipleOf(parent.slideTime), + "The slide duration of ReducedWindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" + ) - @transient val reducedStream = parent.reduceByKey(reduceFunc, partitioner) + super.persist(StorageLevel.MEMORY_ONLY) + + val reducedStream = parent.reduceByKey(reduceFunc, partitioner) def windowTime: Time = _windowTime @@ -37,15 +41,19 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( override def slideTime: Time = _slideTime - //TODO: This is wrong. This should depend on the checkpointInterval + override val mustCheckpoint = true + override def parentRememberDuration: Time = rememberDuration + windowTime - override def persist( - storageLevel: StorageLevel, - checkpointLevel: StorageLevel, - checkpointInterval: Time): DStream[(K,V)] = { - super.persist(storageLevel, checkpointLevel, checkpointInterval) - reducedStream.persist(storageLevel, checkpointLevel, checkpointInterval) + override def persist(storageLevel: StorageLevel): DStream[(K,V)] = { + super.persist(storageLevel) + reducedStream.persist(storageLevel) + this + } + + override def checkpoint(interval: Time): DStream[(K, V)] = { + super.checkpoint(interval) + reducedStream.checkpoint(interval) this } diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala index 086752ac55..0211df1343 100644 --- a/streaming/src/main/scala/spark/streaming/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala @@ -23,51 +23,14 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife rememberPartitioner: Boolean ) extends DStream[(K, S)](parent.ssc) { + super.persist(StorageLevel.MEMORY_ONLY) + override def dependencies = List(parent) override def slideTime = parent.slideTime - override def getOrCompute(time: Time): Option[RDD[(K, S)]] = { - generatedRDDs.get(time) match { - case Some(oldRDD) => { - if (checkpointInterval != null && time > zeroTime && (time - zeroTime).isMultipleOf(checkpointInterval) && oldRDD.dependencies.size > 0) { - val r = oldRDD - val oldRDDBlockIds = oldRDD.splits.map(s => "rdd:" + r.id + ":" + s.index) - val checkpointedRDD = new BlockRDD[(K, S)](ssc.sc, oldRDDBlockIds) { - override val partitioner = oldRDD.partitioner - } - generatedRDDs.update(time, checkpointedRDD) - logInfo("Checkpointed RDD " + oldRDD.id + " of time " + time + " with its new RDD " + checkpointedRDD.id) - Some(checkpointedRDD) - } else { - Some(oldRDD) - } - } - case None => { - if (isTimeValid(time)) { - compute(time) match { - case Some(newRDD) => { - if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { - newRDD.persist(checkpointLevel) - logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) - } else if (storageLevel != StorageLevel.NONE) { - newRDD.persist(storageLevel) - logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) - } - generatedRDDs.put(time, newRDD) - Some(newRDD) - } - case None => { - None - } - } - } else { - None - } - } - } - } - + override val mustCheckpoint = true + override def compute(validTime: Time): Option[RDD[(K, S)]] = { // Try to get the previous state RDD diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index b3148eaa97..3838e84113 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -15,6 +15,8 @@ import org.apache.hadoop.io.LongWritable import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat +import org.apache.hadoop.fs.Path +import java.util.UUID class StreamingContext ( sc_ : SparkContext, @@ -26,7 +28,7 @@ class StreamingContext ( def this(master: String, frameworkName: String, sparkHome: String = null, jars: Seq[String] = Nil) = this(new SparkContext(master, frameworkName, sparkHome, jars), null) - def this(file: String) = this(null, Checkpoint.loadFromFile(file)) + def this(path: String) = this(null, Checkpoint.load(path)) def this(cp_ : Checkpoint) = this(null, cp_) @@ -51,7 +53,6 @@ class StreamingContext ( val graph: DStreamGraph = { if (isCheckpointPresent) { - cp_.graph.setContext(this) cp_.graph } else { @@ -62,7 +63,15 @@ class StreamingContext ( val nextNetworkInputStreamId = new AtomicInteger(0) var networkInputTracker: NetworkInputTracker = null - private[streaming] var checkpointFile: String = if (isCheckpointPresent) cp_.checkpointFile else null + private[streaming] var checkpointDir: String = { + if (isCheckpointPresent) { + sc.setCheckpointDir(cp_.checkpointDir, true) + cp_.checkpointDir + } else { + null + } + } + private[streaming] var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null private[streaming] var receiverJobThread: Thread = null private[streaming] var scheduler: Scheduler = null @@ -75,9 +84,15 @@ class StreamingContext ( graph.setRememberDuration(duration) } - def setCheckpointDetails(file: String, interval: Time) { - checkpointFile = file - checkpointInterval = interval + def checkpoint(dir: String, interval: Time) { + if (dir != null) { + sc.setCheckpointDir(new Path(dir, "rdds-" + UUID.randomUUID.toString).toString) + checkpointDir = dir + checkpointInterval = interval + } else { + checkpointDir = null + checkpointInterval = null + } } private[streaming] def getInitialCheckpoint(): Checkpoint = { @@ -170,16 +185,12 @@ class StreamingContext ( graph.addOutputStream(outputStream) } - def validate() { - assert(graph != null, "Graph is null") - graph.validate() - } - /** * This function starts the execution of the streams. */ def start() { - validate() + assert(graph != null, "Graph is null") + graph.validate() val networkInputStreams = graph.getInputStreams().filter(s => s match { case n: NetworkInputDStream[_] => true @@ -216,7 +227,8 @@ class StreamingContext ( } def doCheckpoint(currentTime: Time) { - new Checkpoint(this, currentTime).saveToFile(checkpointFile) + new Checkpoint(this, currentTime).save(checkpointDir) + } } diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index 9ddb65249a..2ba6502971 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -25,6 +25,10 @@ case class Time(millis: Long) { def isMultipleOf(that: Time): Boolean = (this.millis % that.millis == 0) + def min(that: Time): Time = if (this < that) this else that + + def max(that: Time): Time = if (this > that) this else that + def isZero: Boolean = (this.millis == 0) override def toString: String = (millis.toString + " ms") diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala index df96a811da..21a83c0fde 100644 --- a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala +++ b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala @@ -10,20 +10,20 @@ object FileStreamWithCheckpoint { def main(args: Array[String]) { if (args.size != 3) { - println("FileStreamWithCheckpoint ") - println("FileStreamWithCheckpoint restart ") + println("FileStreamWithCheckpoint ") + println("FileStreamWithCheckpoint restart ") System.exit(-1) } val directory = new Path(args(1)) - val checkpointFile = args(2) + val checkpointDir = args(2) val ssc: StreamingContext = { if (args(0) == "restart") { // Recreated streaming context from specified checkpoint file - new StreamingContext(checkpointFile) + new StreamingContext(checkpointDir) } else { @@ -34,7 +34,7 @@ object FileStreamWithCheckpoint { // Create new streaming context val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint") ssc_.setBatchDuration(Seconds(1)) - ssc_.setCheckpointDetails(checkpointFile, Seconds(1)) + ssc_.checkpoint(checkpointDir, Seconds(1)) // Setup the streaming computation val inputStream = ssc_.textFileStream(directory.toString) diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala index 57fd10f0a5..750cb7445f 100644 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala @@ -41,9 +41,8 @@ object TopKWordCountRaw { val windowedCounts = union.mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) - windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, - Milliseconds(chkptMs)) - //windowedCounts.print() // TODO: something else? + windowedCounts.persist().checkpoint(Milliseconds(chkptMs)) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMs)) def topK(data: Iterator[(String, Long)], k: Int): Iterator[(String, Long)] = { val taken = new Array[(String, Long)](k) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index 0d2e62b955..865026033e 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -100,10 +100,9 @@ object WordCount2 { val windowedCounts = sentences .mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), batchDuration, reduceTasks.toInt) - windowedCounts.persist(StorageLevel.MEMORY_ONLY, - StorageLevel.MEMORY_ONLY_2, - //new StorageLevel(false, true, true, 3), - Milliseconds(chkptMillis.toLong)) + + windowedCounts.persist().checkpoint(Milliseconds(chkptMillis.toLong)) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong)) windowedCounts.foreachRDD(r => println("Element count: " + r.count())) ssc.start() diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala index abfd12890f..d1ea9a9cd5 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -41,9 +41,9 @@ object WordCountRaw { val windowedCounts = union.mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) - windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, - Milliseconds(chkptMs)) - //windowedCounts.print() // TODO: something else? + windowedCounts.persist().checkpoint(chkptMs) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMs)) + windowedCounts.foreachRDD(r => println("Element count: " + r.count())) ssc.start() diff --git a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala index 9d44da2b11..6a9c8a9a69 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala @@ -57,11 +57,13 @@ object WordMax2 { val windowedCounts = sentences .mapPartitions(splitAndCountPartitions) .reduceByKey(add _, reduceTasks.toInt) - .persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, - Milliseconds(chkptMillis.toLong)) + .persist() + .checkpoint(Milliseconds(chkptMillis.toLong)) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong)) .reduceByKeyAndWindow(max _, Seconds(10), batchDuration, reduceTasks.toInt) - //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, - // Milliseconds(chkptMillis.toLong)) + .persist() + .checkpoint(Milliseconds(chkptMillis.toLong)) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong)) windowedCounts.foreachRDD(r => println("Element count: " + r.count())) ssc.start() diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index 6dcedcf463..dfe31b5771 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -2,52 +2,95 @@ package spark.streaming import spark.streaming.StreamingContext._ import java.io.File +import collection.mutable.ArrayBuffer +import runtime.RichInt +import org.scalatest.BeforeAndAfter +import org.apache.hadoop.fs.Path +import org.apache.commons.io.FileUtils -class CheckpointSuite extends TestSuiteBase { +class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { + + before { + FileUtils.deleteDirectory(new File(checkpointDir)) + } + + after { + FileUtils.deleteDirectory(new File(checkpointDir)) + } override def framework() = "CheckpointSuite" - override def checkpointFile() = "checkpoint" + override def batchDuration() = Seconds(1) + + override def checkpointDir() = "checkpoint" + + override def checkpointInterval() = batchDuration def testCheckpointedOperation[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], operation: DStream[U] => DStream[V], expectedOutput: Seq[Seq[V]], - useSet: Boolean = false + initialNumBatches: Int ) { // Current code assumes that: // number of inputs = number of outputs = number of batches to be run - val totalNumBatches = input.size - val initialNumBatches = input.size / 2 val nextNumBatches = totalNumBatches - initialNumBatches val initialNumExpectedOutputs = initialNumBatches + val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs // Do half the computation (half the number of batches), create checkpoint file and quit val ssc = setupStreams[U, V](input, operation) val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs) - verifyOutput[V](output, expectedOutput.take(initialNumBatches), useSet) + verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) Thread.sleep(1000) // Restart and complete the computation from checkpoint file - val sscNew = new StreamingContext(checkpointFile) - sscNew.setCheckpointDetails(null, null) - val outputNew = runStreams[V](sscNew, nextNumBatches, expectedOutput.size) - verifyOutput[V](outputNew, expectedOutput, useSet) - - new File(checkpointFile).delete() - new File(checkpointFile + ".bk").delete() - new File("." + checkpointFile + ".crc").delete() - new File("." + checkpointFile + ".bk.crc").delete() + val sscNew = new StreamingContext(checkpointDir) + //sscNew.checkpoint(null, null) + val outputNew = runStreams[V](sscNew, nextNumBatches, nextNumExpectedOutputs) + verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) } - test("simple per-batch operation") { + + test("map and reduceByKey") { testCheckpointedOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ), (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ), - true + 3 ) } + + test("reduceByKeyAndWindowInv") { + val n = 10 + val w = 4 + val input = (1 to n).map(x => Seq("a")).toSeq + val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) + val operation = (st: DStream[String]) => { + st.map(x => (x, 1)).reduceByKeyAndWindow(_ + _, _ - _, Seconds(w), Seconds(1)) + } + for (i <- Seq(3, 5, 7)) { + testCheckpointedOperation(input, operation, output, i) + } + } + + test("updateStateByKey") { + val input = (1 to 10).map(_ => Seq("a")).toSeq + val output = (1 to 10).map(x => Seq(("a", x))).toSeq + val operation = (st: DStream[String]) => { + val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { + Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) + } + st.map(x => (x, 1)) + .updateStateByKey[RichInt](updateFunc) + .checkpoint(Seconds(5)) + .map(t => (t._1, t._2.self)) + } + for (i <- Seq(3, 5, 7)) { + testCheckpointedOperation(input, operation, output, i) + } + } + } \ No newline at end of file diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index c9bc454f91..e441feea19 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -5,10 +5,16 @@ import util.ManualClock import collection.mutable.ArrayBuffer import org.scalatest.FunSuite import collection.mutable.SynchronizedBuffer +import java.io.{ObjectInputStream, IOException} + +/** + * This is a input stream just for the testsuites. This is equivalent to a checkpointable, + * replayable, reliable message queue like Kafka. It requires a sequence as input, and + * returns the i_th element at the i_th batch unde manual clock. + */ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[T]], numPartitions: Int) extends InputDStream[T](ssc_) { - var currentIndex = 0 def start() {} @@ -23,17 +29,32 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[ ssc.sc.makeRDD(Seq[T](), numPartitions) } logInfo("Created RDD " + rdd.id) - //currentIndex += 1 Some(rdd) } } +/** + * This is a output stream just for the testsuites. All the output is collected into a + * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. + */ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]]) extends PerRDDForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.collect() output += collected - }) + }) { + + // This is to clear the output buffer every it is read from a checkpoint + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + ois.defaultReadObject() + output.clear() + } +} +/** + * This is the base trait for Spark Streaming testsuites. This provides basic functionality + * to run user-defined set of input on user-defined stream operations, and verify the output. + */ trait TestSuiteBase extends FunSuite with Logging { System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") @@ -44,7 +65,7 @@ trait TestSuiteBase extends FunSuite with Logging { def batchDuration() = Seconds(1) - def checkpointFile() = null.asInstanceOf[String] + def checkpointDir() = null.asInstanceOf[String] def checkpointInterval() = batchDuration @@ -60,8 +81,8 @@ trait TestSuiteBase extends FunSuite with Logging { // Create StreamingContext val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) - if (checkpointFile != null) { - ssc.setCheckpointDetails(checkpointFile, checkpointInterval()) + if (checkpointDir != null) { + ssc.checkpoint(checkpointDir, checkpointInterval()) } // Setup the stream computation @@ -82,8 +103,8 @@ trait TestSuiteBase extends FunSuite with Logging { // Create StreamingContext val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) - if (checkpointFile != null) { - ssc.setCheckpointDetails(checkpointFile, checkpointInterval()) + if (checkpointDir != null) { + ssc.checkpoint(checkpointDir, checkpointInterval()) } // Setup the stream computation diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index d7d8d5bd36..e282f0fdd5 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -283,7 +283,9 @@ class WindowOperationsSuite extends TestSuiteBase { test("reduceByKeyAndWindowInv - " + name) { val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt val operation = (s: DStream[(String, Int)]) => { - s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist() + s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime) + .persist() + .checkpoint(Seconds(100)) // Large value to avoid effect of RDD checkpointing } testOperation(input, operation, expectedOutput, numBatches, true) } -- cgit v1.2.3 From 72b2303f99bd652fc4bdaa929f37731a7ba8f640 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 5 Nov 2012 11:41:36 -0800 Subject: Fixed major bugs in checkpointing. --- core/src/main/scala/spark/SparkContext.scala | 6 +- .../main/scala/spark/streaming/Checkpoint.scala | 24 ++-- .../src/main/scala/spark/streaming/DStream.scala | 47 +++++-- .../main/scala/spark/streaming/DStreamGraph.scala | 36 ++++-- .../src/main/scala/spark/streaming/Scheduler.scala | 1 - .../scala/spark/streaming/StreamingContext.scala | 8 +- .../scala/spark/streaming/CheckpointSuite.scala | 139 ++++++++++++++++----- .../test/scala/spark/streaming/TestSuiteBase.scala | 37 ++++-- .../spark/streaming/WindowOperationsSuite.scala | 6 +- 9 files changed, 217 insertions(+), 87 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index d7326971a9..d7b46bee38 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -474,8 +474,10 @@ class SparkContext( /** Shut down the SparkContext. */ def stop() { - dagScheduler.stop() - dagScheduler = null + if (dagScheduler != null) { + dagScheduler.stop() + dagScheduler = null + } taskScheduler = null // TODO: Cache.stop()? env.stop() diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index cf04c7031e..6b4b05103f 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -6,6 +6,7 @@ import org.apache.hadoop.fs.{FileUtil, Path} import org.apache.hadoop.conf.Configuration import java.io.{InputStream, ObjectStreamClass, ObjectInputStream, ObjectOutputStream} +import sys.process.processInternal class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) @@ -52,17 +53,17 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) } } -object Checkpoint { +object Checkpoint extends Logging { def load(path: String): Checkpoint = { val fs = new Path(path).getFileSystem(new Configuration()) - val attempts = Seq(new Path(path), new Path(path, "graph"), new Path(path, "graph.bk")) - var lastException: Exception = null - var lastExceptionFile: String = null + val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"), new Path(path), new Path(path + ".bk")) + var detailedLog: String = "" attempts.foreach(file => { if (fs.exists(file)) { + logInfo("Attempting to load checkpoint from file '" + file + "'") try { val fis = fs.open(file) // ObjectInputStream uses the last defined user-defined class loader in the stack @@ -75,21 +76,18 @@ object Checkpoint { ois.close() fs.close() cp.validate() - println("Checkpoint successfully loaded from file " + file) + logInfo("Checkpoint successfully loaded from file '" + file + "'") return cp } catch { case e: Exception => - lastException = e - lastExceptionFile = file.toString + logError("Error loading checkpoint from file '" + file + "'", e) } + } else { + logWarning("Could not load checkpoint from file '" + file + "' as it does not exist") } - }) - if (lastException == null) { - throw new Exception("Could not load checkpoint from path '" + path + "'") - } else { - throw new Exception("Error loading checkpoint from path '" + lastExceptionFile + "'", lastException) - } + }) + throw new Exception("Could not load checkpoint from path '" + path + "'") } def fromBytes(bytes: Array[Byte]): Checkpoint = { diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index de51c5d34a..2fecbe0acf 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -14,6 +14,8 @@ import java.util.concurrent.ArrayBlockingQueue import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import scala.Some import collection.mutable +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext) extends Serializable with Logging { @@ -42,6 +44,7 @@ extends Serializable with Logging { */ // RDDs generated, marked as protected[streaming] so that testsuites can access it + @transient protected[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] () // Time zero for the DStream @@ -112,7 +115,7 @@ extends Serializable with Logging { // Set the minimum value of the rememberDuration if not already set var minRememberDuration = slideTime if (checkpointInterval != null && minRememberDuration <= checkpointInterval) { - minRememberDuration = checkpointInterval + slideTime + minRememberDuration = checkpointInterval * 2 // times 2 just to be sure that the latest checkpoint is not forgetten } if (rememberDuration == null || rememberDuration < minRememberDuration) { rememberDuration = minRememberDuration @@ -265,33 +268,59 @@ extends Serializable with Logging { if (t <= (time - rememberDuration)) { generatedRDDs.remove(t) numForgotten += 1 - //logInfo("Forgot RDD of time " + t + " from " + this) + logInfo("Forgot RDD of time " + t + " from " + this) } }) logInfo("Forgot " + numForgotten + " RDDs from " + this) dependencies.foreach(_.forgetOldRDDs(time)) } + /** + * Refreshes the list of checkpointed RDDs that will be saved along with checkpoint of this stream. + * Along with that it forget old checkpoint files. + */ protected[streaming] def updateCheckpointData() { + + // TODO (tdas): This code can be simplified. Its kept verbose to aid debugging. + val checkpointedRDDs = generatedRDDs.filter(_._2.getCheckpointData() != null) + val removedCheckpointData = checkpointData.filter(x => !generatedRDDs.contains(x._1)) + checkpointData.clear() - generatedRDDs.foreach { - case(time, rdd) => { - logDebug("Adding checkpointed RDD for time " + time) + checkpointedRDDs.foreach { + case (time, rdd) => { val data = rdd.getCheckpointData() - if (data != null) { - checkpointData += ((time, data)) + assert(data != null) + checkpointData += ((time, data)) + logInfo("Added checkpointed RDD " + rdd + " for time " + time + " to stream checkpoint") + } + } + + dependencies.foreach(_.updateCheckpointData()) + // If at least one checkpoint is present, then delete old checkpoints + if (checkpointData.size > 0) { + // Delete the checkpoint RDD files that are not needed any more + removedCheckpointData.foreach { + case (time: Time, file: String) => { + val path = new Path(file) + val fs = path.getFileSystem(new Configuration()) + fs.delete(path, true) + logInfo("Deleted checkpoint file '" + file + "' for time " + time) } } } + + logInfo("Updated checkpoint data") } protected[streaming] def restoreCheckpointData() { + logInfo("Restoring checkpoint data from " + checkpointData.size + " checkpointed RDDs") checkpointData.foreach { case(time, data) => { - logInfo("Restoring checkpointed RDD for time " + time) + logInfo("Restoring checkpointed RDD for time " + time + " from file") generatedRDDs += ((time, ssc.sc.objectFile[T](data.toString))) } } + dependencies.foreach(_.restoreCheckpointData()) } @throws(classOf[IOException]) @@ -300,7 +329,6 @@ extends Serializable with Logging { if (graph != null) { graph.synchronized { if (graph.checkpointInProgress) { - updateCheckpointData() oos.defaultWriteObject() } else { val msg = "Object of " + this.getClass.getName + " is being serialized " + @@ -322,7 +350,6 @@ extends Serializable with Logging { logDebug(this.getClass().getSimpleName + ".readObject used") ois.defaultReadObject() generatedRDDs = new HashMap[Time, RDD[T]] () - restoreCheckpointData() } /** diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index f8922ec790..7437f4402d 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -4,7 +4,7 @@ import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import collection.mutable.ArrayBuffer import spark.Logging -final class DStreamGraph extends Serializable with Logging { +final private[streaming] class DStreamGraph extends Serializable with Logging { initLogging() private val inputStreams = new ArrayBuffer[InputDStream[_]]() @@ -15,7 +15,7 @@ final class DStreamGraph extends Serializable with Logging { private[streaming] var rememberDuration: Time = null private[streaming] var checkpointInProgress = false - def start(time: Time) { + private[streaming] def start(time: Time) { this.synchronized { if (zeroTime != null) { throw new Exception("DStream graph computation already started") @@ -28,7 +28,7 @@ final class DStreamGraph extends Serializable with Logging { } } - def stop() { + private[streaming] def stop() { this.synchronized { inputStreams.par.foreach(_.stop()) } @@ -40,7 +40,7 @@ final class DStreamGraph extends Serializable with Logging { } } - def setBatchDuration(duration: Time) { + private[streaming] def setBatchDuration(duration: Time) { this.synchronized { if (batchDuration != null) { throw new Exception("Batch duration already set as " + batchDuration + @@ -50,7 +50,7 @@ final class DStreamGraph extends Serializable with Logging { batchDuration = duration } - def setRememberDuration(duration: Time) { + private[streaming] def setRememberDuration(duration: Time) { this.synchronized { if (rememberDuration != null) { throw new Exception("Batch duration already set as " + batchDuration + @@ -60,37 +60,49 @@ final class DStreamGraph extends Serializable with Logging { rememberDuration = duration } - def addInputStream(inputStream: InputDStream[_]) { + private[streaming] def addInputStream(inputStream: InputDStream[_]) { this.synchronized { inputStream.setGraph(this) inputStreams += inputStream } } - def addOutputStream(outputStream: DStream[_]) { + private[streaming] def addOutputStream(outputStream: DStream[_]) { this.synchronized { outputStream.setGraph(this) outputStreams += outputStream } } - def getInputStreams() = inputStreams.toArray + private[streaming] def getInputStreams() = this.synchronized { inputStreams.toArray } - def getOutputStreams() = outputStreams.toArray + private[streaming] def getOutputStreams() = this.synchronized { outputStreams.toArray } - def generateRDDs(time: Time): Seq[Job] = { + private[streaming] def generateRDDs(time: Time): Seq[Job] = { this.synchronized { outputStreams.flatMap(outputStream => outputStream.generateJob(time)) } } - def forgetOldRDDs(time: Time) { + private[streaming] def forgetOldRDDs(time: Time) { this.synchronized { outputStreams.foreach(_.forgetOldRDDs(time)) } } - def validate() { + private[streaming] def updateCheckpointData() { + this.synchronized { + outputStreams.foreach(_.updateCheckpointData()) + } + } + + private[streaming] def restoreCheckpointData() { + this.synchronized { + outputStreams.foreach(_.restoreCheckpointData()) + } + } + + private[streaming] def validate() { this.synchronized { assert(batchDuration != null, "Batch duration has not been set") assert(batchDuration > Milliseconds(100), "Batch duration of " + batchDuration + " is very low") diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index 7d52e2eddf..2b3f5a4829 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -58,7 +58,6 @@ extends Logging { graph.forgetOldRDDs(time) if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) { ssc.doCheckpoint(time) - logInfo("Checkpointed at time " + time) } } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 3838e84113..fb36ab9dc9 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -54,6 +54,7 @@ class StreamingContext ( val graph: DStreamGraph = { if (isCheckpointPresent) { cp_.graph.setContext(this) + cp_.graph.restoreCheckpointData() cp_.graph } else { new DStreamGraph() @@ -218,17 +219,16 @@ class StreamingContext ( if (scheduler != null) scheduler.stop() if (networkInputTracker != null) networkInputTracker.stop() if (receiverJobThread != null) receiverJobThread.interrupt() - sc.stop() + sc.stop() + logInfo("StreamingContext stopped successfully") } catch { case e: Exception => logWarning("Error while stopping", e) } - - logInfo("StreamingContext stopped") } def doCheckpoint(currentTime: Time) { + graph.updateCheckpointData() new Checkpoint(this, currentTime).save(checkpointDir) - } } diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index dfe31b5771..aa8ded513c 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -2,11 +2,11 @@ package spark.streaming import spark.streaming.StreamingContext._ import java.io.File -import collection.mutable.ArrayBuffer import runtime.RichInt import org.scalatest.BeforeAndAfter -import org.apache.hadoop.fs.Path import org.apache.commons.io.FileUtils +import collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import util.ManualClock class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { @@ -18,39 +18,83 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { FileUtils.deleteDirectory(new File(checkpointDir)) } - override def framework() = "CheckpointSuite" + override def framework = "CheckpointSuite" - override def batchDuration() = Seconds(1) + override def batchDuration = Milliseconds(500) - override def checkpointDir() = "checkpoint" + override def checkpointDir = "checkpoint" - override def checkpointInterval() = batchDuration + override def checkpointInterval = batchDuration - def testCheckpointedOperation[U: ClassManifest, V: ClassManifest]( - input: Seq[Seq[U]], - operation: DStream[U] => DStream[V], - expectedOutput: Seq[Seq[V]], - initialNumBatches: Int - ) { + override def actuallyWait = true - // Current code assumes that: - // number of inputs = number of outputs = number of batches to be run - val totalNumBatches = input.size - val nextNumBatches = totalNumBatches - initialNumBatches - val initialNumExpectedOutputs = initialNumBatches - val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + test("basic stream+rdd recovery") { - // Do half the computation (half the number of batches), create checkpoint file and quit - val ssc = setupStreams[U, V](input, operation) - val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs) - verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) - Thread.sleep(1000) + assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second") + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") - // Restart and complete the computation from checkpoint file + val checkpointingInterval = Seconds(2) + + // this ensure checkpointing occurs at least once + val firstNumBatches = (checkpointingInterval.millis / batchDuration.millis) * 2 + val secondNumBatches = firstNumBatches + + // Setup the streams + val input = (1 to 10).map(_ => Seq("a")).toSeq + val operation = (st: DStream[String]) => { + val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { + Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) + } + st.map(x => (x, 1)) + .updateStateByKey[RichInt](updateFunc) + .checkpoint(checkpointingInterval) + .map(t => (t._1, t._2.self)) + } + val ssc = setupStreams(input, operation) + val stateStream = ssc.graph.getOutputStreams().head.dependencies.head.dependencies.head + + // Run till a time such that at least one RDD in the stream should have been checkpointed + ssc.start() + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + logInfo("Manual clock before advancing = " + clock.time) + for (i <- 1 to firstNumBatches.toInt) { + clock.addToTime(batchDuration.milliseconds) + Thread.sleep(batchDuration.milliseconds) + } + logInfo("Manual clock after advancing = " + clock.time) + Thread.sleep(batchDuration.milliseconds) + + // Check whether some RDD has been checkpointed or not + logInfo("Checkpoint data of state stream = \n[" + stateStream.checkpointData.mkString(",\n") + "]") + assert(!stateStream.checkpointData.isEmpty, "No checkpointed RDDs in state stream") + stateStream.checkpointData.foreach { + case (time, data) => { + val file = new File(data.toString) + assert(file.exists(), "Checkpoint file '" + file +"' for time " + time + " does not exist") + } + } + val checkpointFiles = stateStream.checkpointData.map(x => new File(x._2.toString)) + + // Run till a further time such that previous checkpoint files in the stream would be deleted + logInfo("Manual clock before advancing = " + clock.time) + for (i <- 1 to secondNumBatches.toInt) { + clock.addToTime(batchDuration.milliseconds) + Thread.sleep(batchDuration.milliseconds) + } + logInfo("Manual clock after advancing = " + clock.time) + Thread.sleep(batchDuration.milliseconds) + + // Check whether the earlier checkpoint files are deleted + checkpointFiles.foreach(file => assert(!file.exists, "Checkpoint file '" + file + "' was not deleted")) + + // Restart stream computation using the checkpoint file and check whether + // checkpointed RDDs have been restored or not + ssc.stop() val sscNew = new StreamingContext(checkpointDir) - //sscNew.checkpoint(null, null) - val outputNew = runStreams[V](sscNew, nextNumBatches, nextNumExpectedOutputs) - verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) + val stateStreamNew = sscNew.graph.getOutputStreams().head.dependencies.head.dependencies.head + logInfo("Restored data of state stream = \n[" + stateStreamNew.generatedRDDs.mkString("\n") + "]") + assert(!stateStreamNew.generatedRDDs.isEmpty, "No restored RDDs in state stream") + sscNew.stop() } @@ -69,9 +113,9 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { val input = (1 to n).map(x => Seq("a")).toSeq val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) val operation = (st: DStream[String]) => { - st.map(x => (x, 1)).reduceByKeyAndWindow(_ + _, _ - _, Seconds(w), Seconds(1)) + st.map(x => (x, 1)).reduceByKeyAndWindow(_ + _, _ - _, batchDuration * 4, batchDuration) } - for (i <- Seq(3, 5, 7)) { + for (i <- Seq(2, 3, 4)) { testCheckpointedOperation(input, operation, output, i) } } @@ -85,12 +129,45 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { } st.map(x => (x, 1)) .updateStateByKey[RichInt](updateFunc) - .checkpoint(Seconds(5)) + .checkpoint(Seconds(2)) .map(t => (t._1, t._2.self)) } - for (i <- Seq(3, 5, 7)) { + for (i <- Seq(2, 3, 4)) { testCheckpointedOperation(input, operation, output, i) } } + + + def testCheckpointedOperation[U: ClassManifest, V: ClassManifest]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]], + initialNumBatches: Int + ) { + + // Current code assumes that: + // number of inputs = number of outputs = number of batches to be run + val totalNumBatches = input.size + val nextNumBatches = totalNumBatches - initialNumBatches + val initialNumExpectedOutputs = initialNumBatches + val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + + // Do half the computation (half the number of batches), create checkpoint file and quit + + val ssc = setupStreams[U, V](input, operation) + val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs) + verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) + Thread.sleep(1000) + + // Restart and complete the computation from checkpoint file + logInfo( + "\n-------------------------------------------\n" + + " Restarting stream computation " + + "\n-------------------------------------------\n" + ) + val sscNew = new StreamingContext(checkpointDir) + val outputNew = runStreams[V](sscNew, nextNumBatches, nextNumExpectedOutputs) + verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) + } } \ No newline at end of file diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index e441feea19..b8c7f99603 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -57,21 +57,21 @@ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBu */ trait TestSuiteBase extends FunSuite with Logging { - System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + def framework = "TestSuiteBase" - def framework() = "TestSuiteBase" + def master = "local[2]" - def master() = "local[2]" + def batchDuration = Seconds(1) - def batchDuration() = Seconds(1) + def checkpointDir = null.asInstanceOf[String] - def checkpointDir() = null.asInstanceOf[String] + def checkpointInterval = batchDuration - def checkpointInterval() = batchDuration + def numInputPartitions = 2 - def numInputPartitions() = 2 + def maxWaitTimeMillis = 10000 - def maxWaitTimeMillis() = 10000 + def actuallyWait = false def setupStreams[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], @@ -82,7 +82,7 @@ trait TestSuiteBase extends FunSuite with Logging { val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) if (checkpointDir != null) { - ssc.checkpoint(checkpointDir, checkpointInterval()) + ssc.checkpoint(checkpointDir, checkpointInterval) } // Setup the stream computation @@ -104,7 +104,7 @@ trait TestSuiteBase extends FunSuite with Logging { val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) if (checkpointDir != null) { - ssc.checkpoint(checkpointDir, checkpointInterval()) + ssc.checkpoint(checkpointDir, checkpointInterval) } // Setup the stream computation @@ -118,12 +118,19 @@ trait TestSuiteBase extends FunSuite with Logging { ssc } + /** + * Runs the streams set up in `ssc` on manual clock for `numBatches` batches and + * returns the collected output. It will wait until `numExpectedOutput` number of + * output data has been collected or timeout (set by `maxWaitTimeMillis`) is reached. + */ def runStreams[V: ClassManifest]( ssc: StreamingContext, numBatches: Int, numExpectedOutput: Int ): Seq[Seq[V]] = { + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + assert(numBatches > 0, "Number of batches to run stream computation is zero") assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero") logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput) @@ -139,7 +146,15 @@ trait TestSuiteBase extends FunSuite with Logging { // Advance manual clock val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] logInfo("Manual clock before advancing = " + clock.time) - clock.addToTime(numBatches * batchDuration.milliseconds) + if (actuallyWait) { + for (i <- 1 to numBatches) { + logInfo("Actually waiting for " + batchDuration) + clock.addToTime(batchDuration.milliseconds) + Thread.sleep(batchDuration.milliseconds) + } + } else { + clock.addToTime(numBatches * batchDuration.milliseconds) + } logInfo("Manual clock after advancing = " + clock.time) // Wait until expected number of output items have been generated diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index e282f0fdd5..3e20e16708 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -5,11 +5,11 @@ import collection.mutable.ArrayBuffer class WindowOperationsSuite extends TestSuiteBase { - override def framework() = "WindowOperationsSuite" + override def framework = "WindowOperationsSuite" - override def maxWaitTimeMillis() = 20000 + override def maxWaitTimeMillis = 20000 - override def batchDuration() = Seconds(1) + override def batchDuration = Seconds(1) val largerSlideInput = Seq( Seq(("a", 1)), -- cgit v1.2.3 From 395167f2b2a1906cde23b1f3ddc2808514bce47b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 5 Nov 2012 16:11:50 -0800 Subject: Made more bug fixes for checkpointing. --- .../main/scala/spark/streaming/Checkpoint.scala | 1 - .../src/main/scala/spark/streaming/DStream.scala | 77 ++++++++++++---------- .../main/scala/spark/streaming/DStreamGraph.scala | 4 +- .../scala/spark/streaming/StreamingContext.scala | 2 +- .../src/main/scala/spark/streaming/Time.scala | 2 +- .../scala/spark/streaming/CheckpointSuite.scala | 77 ++++++++++++++-------- 6 files changed, 97 insertions(+), 66 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index 6b4b05103f..1643f45ffb 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -6,7 +6,6 @@ import org.apache.hadoop.fs.{FileUtil, Path} import org.apache.hadoop.conf.Configuration import java.io.{InputStream, ObjectStreamClass, ObjectInputStream, ObjectOutputStream} -import sys.process.processInternal class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 2fecbe0acf..922ff5088d 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -1,6 +1,7 @@ package spark.streaming -import spark.streaming.StreamingContext._ +import StreamingContext._ +import Time._ import spark._ import spark.SparkContext._ @@ -12,8 +13,7 @@ import scala.collection.mutable.HashMap import java.util.concurrent.ArrayBlockingQueue import java.io.{ObjectInputStream, IOException, ObjectOutputStream} -import scala.Some -import collection.mutable + import org.apache.hadoop.fs.Path import org.apache.hadoop.conf.Configuration @@ -206,10 +206,10 @@ extends Serializable with Logging { } /** - * This method either retrieves a precomputed RDD of this DStream, - * or computes the RDD (if the time is valid) + * Retrieves a precomputed RDD of this DStream, or computes the RDD. This is an internal + * method that should not be called directly. */ - def getOrCompute(time: Time): Option[RDD[T]] = { + protected[streaming] def getOrCompute(time: Time): Option[RDD[T]] = { // If this DStream was not initialized (i.e., zeroTime not set), then do it // If RDD was already generated, then retrieve it from HashMap generatedRDDs.get(time) match { @@ -245,10 +245,12 @@ extends Serializable with Logging { } /** - * This method generates a SparkStreaming job for the given time - * and may required to be overriden by subclasses + * Generates a SparkStreaming job for the given time. This is an internal method that + * should not be called directly. This default implementation creates a job + * that materializes the corresponding RDD. Subclasses of DStream may override this + * (eg. PerRDDForEachDStream). */ - def generateJob(time: Time): Option[Job] = { + protected[streaming] def generateJob(time: Time): Option[Job] = { getOrCompute(time) match { case Some(rdd) => { val jobFunc = () => { @@ -261,6 +263,9 @@ extends Serializable with Logging { } } + /** + * Dereferences RDDs that are older than rememberDuration. + */ protected[streaming] def forgetOldRDDs(time: Time) { val keys = generatedRDDs.keys var numForgotten = 0 @@ -276,42 +281,46 @@ extends Serializable with Logging { } /** - * Refreshes the list of checkpointed RDDs that will be saved along with checkpoint of this stream. - * Along with that it forget old checkpoint files. + * Refreshes the list of checkpointed RDDs that will be saved along with checkpoint of + * this stream. This is an internal method that should not be called directly. This is + * a default implementation that saves only the file names of the checkpointed RDDs to + * checkpointData. Subclasses of DStream (especially those of InputDStream) may override + * this method to save custom checkpoint data. */ - protected[streaming] def updateCheckpointData() { - - // TODO (tdas): This code can be simplified. Its kept verbose to aid debugging. - val checkpointedRDDs = generatedRDDs.filter(_._2.getCheckpointData() != null) - val removedCheckpointData = checkpointData.filter(x => !generatedRDDs.contains(x._1)) - - checkpointData.clear() - checkpointedRDDs.foreach { - case (time, rdd) => { - val data = rdd.getCheckpointData() - assert(data != null) - checkpointData += ((time, data)) - logInfo("Added checkpointed RDD " + rdd + " for time " + time + " to stream checkpoint") - } + protected[streaming] def updateCheckpointData(currentTime: Time) { + val newCheckpointData = generatedRDDs.filter(_._2.getCheckpointData() != null) + .map(x => (x._1, x._2.getCheckpointData())) + val oldCheckpointData = checkpointData.clone() + if (newCheckpointData.size > 0) { + checkpointData.clear() + checkpointData ++= newCheckpointData + } + + dependencies.foreach(_.updateCheckpointData(currentTime)) + + newCheckpointData.foreach { + case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") } } - dependencies.foreach(_.updateCheckpointData()) - // If at least one checkpoint is present, then delete old checkpoints - if (checkpointData.size > 0) { - // Delete the checkpoint RDD files that are not needed any more - removedCheckpointData.foreach { - case (time: Time, file: String) => { - val path = new Path(file) + if (newCheckpointData.size > 0) { + (oldCheckpointData -- newCheckpointData.keySet).foreach { + case (time, data) => { + val path = new Path(data.toString) val fs = path.getFileSystem(new Configuration()) fs.delete(path, true) - logInfo("Deleted checkpoint file '" + file + "' for time " + time) + logInfo("Deleted checkpoint file '" + path + "' for time " + time) } } } - logInfo("Updated checkpoint data") } + /** + * Restores the RDDs in generatedRDDs from the checkpointData. This is an internal method + * that should not be called directly. This is a default implementation that recreates RDDs + * from the checkpoint file names stored in checkpointData. Subclasses of DStream that + * override the updateCheckpointData() method would also need to override this method. + */ protected[streaming] def restoreCheckpointData() { logInfo("Restoring checkpoint data from " + checkpointData.size + " checkpointed RDDs") checkpointData.foreach { diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index 7437f4402d..246522838a 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -90,9 +90,9 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { } } - private[streaming] def updateCheckpointData() { + private[streaming] def updateCheckpointData(time: Time) { this.synchronized { - outputStreams.foreach(_.updateCheckpointData()) + outputStreams.foreach(_.updateCheckpointData(time)) } } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index fb36ab9dc9..25caaf7d39 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -227,7 +227,7 @@ class StreamingContext ( } def doCheckpoint(currentTime: Time) { - graph.updateCheckpointData() + graph.updateCheckpointData(currentTime) new Checkpoint(this, currentTime).save(checkpointDir) } } diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index 2ba6502971..480d292d7c 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -43,7 +43,7 @@ object Time { implicit def toTime(long: Long) = Time(long) - implicit def toLong(time: Time) = time.milliseconds + implicit def toLong(time: Time) = time.milliseconds } object Milliseconds { diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index aa8ded513c..9fdfd50be2 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -6,7 +6,7 @@ import runtime.RichInt import org.scalatest.BeforeAndAfter import org.apache.commons.io.FileUtils import collection.mutable.{SynchronizedBuffer, ArrayBuffer} -import util.ManualClock +import util.{Clock, ManualClock} class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { @@ -31,12 +31,14 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { test("basic stream+rdd recovery") { assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second") + assert(checkpointInterval === batchDuration, "checkpointInterval for this test much be same as batchDuration") + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") - val checkpointingInterval = Seconds(2) + val stateStreamCheckpointInterval = Seconds(2) // this ensure checkpointing occurs at least once - val firstNumBatches = (checkpointingInterval.millis / batchDuration.millis) * 2 + val firstNumBatches = (stateStreamCheckpointInterval.millis / batchDuration.millis) * 2 val secondNumBatches = firstNumBatches // Setup the streams @@ -47,7 +49,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { } st.map(x => (x, 1)) .updateStateByKey[RichInt](updateFunc) - .checkpoint(checkpointingInterval) + .checkpoint(stateStreamCheckpointInterval) .map(t => (t._1, t._2.self)) } val ssc = setupStreams(input, operation) @@ -56,35 +58,22 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // Run till a time such that at least one RDD in the stream should have been checkpointed ssc.start() val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - logInfo("Manual clock before advancing = " + clock.time) - for (i <- 1 to firstNumBatches.toInt) { - clock.addToTime(batchDuration.milliseconds) - Thread.sleep(batchDuration.milliseconds) - } - logInfo("Manual clock after advancing = " + clock.time) - Thread.sleep(batchDuration.milliseconds) + advanceClock(clock, firstNumBatches) // Check whether some RDD has been checkpointed or not logInfo("Checkpoint data of state stream = \n[" + stateStream.checkpointData.mkString(",\n") + "]") - assert(!stateStream.checkpointData.isEmpty, "No checkpointed RDDs in state stream") + assert(!stateStream.checkpointData.isEmpty, "No checkpointed RDDs in state stream before first failure") stateStream.checkpointData.foreach { case (time, data) => { val file = new File(data.toString) - assert(file.exists(), "Checkpoint file '" + file +"' for time " + time + " does not exist") + assert(file.exists(), "Checkpoint file '" + file +"' for time " + time + " for state stream before first failure does not exist") } } - val checkpointFiles = stateStream.checkpointData.map(x => new File(x._2.toString)) // Run till a further time such that previous checkpoint files in the stream would be deleted - logInfo("Manual clock before advancing = " + clock.time) - for (i <- 1 to secondNumBatches.toInt) { - clock.addToTime(batchDuration.milliseconds) - Thread.sleep(batchDuration.milliseconds) - } - logInfo("Manual clock after advancing = " + clock.time) - Thread.sleep(batchDuration.milliseconds) - - // Check whether the earlier checkpoint files are deleted + // and check whether the earlier checkpoint files are deleted + val checkpointFiles = stateStream.checkpointData.map(x => new File(x._2.toString)) + advanceClock(clock, secondNumBatches) checkpointFiles.foreach(file => assert(!file.exists, "Checkpoint file '" + file + "' was not deleted")) // Restart stream computation using the checkpoint file and check whether @@ -93,11 +82,35 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { val sscNew = new StreamingContext(checkpointDir) val stateStreamNew = sscNew.graph.getOutputStreams().head.dependencies.head.dependencies.head logInfo("Restored data of state stream = \n[" + stateStreamNew.generatedRDDs.mkString("\n") + "]") - assert(!stateStreamNew.generatedRDDs.isEmpty, "No restored RDDs in state stream") + assert(!stateStreamNew.generatedRDDs.isEmpty, "No restored RDDs in state stream after recovery from first failure") + + + // Run one batch to generate a new checkpoint file + sscNew.start() + val clockNew = sscNew.scheduler.clock.asInstanceOf[ManualClock] + advanceClock(clockNew, 1) + + // Check whether some RDD is present in the checkpoint data or not + assert(!stateStreamNew.checkpointData.isEmpty, "No checkpointed RDDs in state stream before second failure") + stateStream.checkpointData.foreach { + case (time, data) => { + val file = new File(data.toString) + assert(file.exists(), "Checkpoint file '" + file +"' for time " + time + " for state stream before seconds failure does not exist") + } + } + + // Restart stream computation from the new checkpoint file to see whether that file has + // correct checkpoint data sscNew.stop() + val sscNewNew = new StreamingContext(checkpointDir) + val stateStreamNewNew = sscNew.graph.getOutputStreams().head.dependencies.head.dependencies.head + logInfo("Restored data of state stream = \n[" + stateStreamNew.generatedRDDs.mkString("\n") + "]") + assert(!stateStreamNewNew.generatedRDDs.isEmpty, "No restored RDDs in state stream after recovery from second failure") + sscNewNew.start() + advanceClock(sscNewNew.scheduler.clock.asInstanceOf[ManualClock], 1) + sscNewNew.stop() } - test("map and reduceByKey") { testCheckpointedOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ), @@ -163,11 +176,21 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // Restart and complete the computation from checkpoint file logInfo( "\n-------------------------------------------\n" + - " Restarting stream computation " + - "\n-------------------------------------------\n" + " Restarting stream computation " + + "\n-------------------------------------------\n" ) val sscNew = new StreamingContext(checkpointDir) val outputNew = runStreams[V](sscNew, nextNumBatches, nextNumExpectedOutputs) verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) } + + def advanceClock(clock: ManualClock, numBatches: Long) { + logInfo("Manual clock before advancing = " + clock.time) + for (i <- 1 to numBatches.toInt) { + clock.addToTime(batchDuration.milliseconds) + Thread.sleep(batchDuration.milliseconds) + } + logInfo("Manual clock after advancing = " + clock.time) + Thread.sleep(batchDuration.milliseconds) + } } \ No newline at end of file -- cgit v1.2.3 From f8bb719cd212f7e7f821c3f69b897985f47a2f83 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 5 Nov 2012 17:53:56 -0800 Subject: Added a few more comments to the checkpoint-related functions. --- streaming/src/main/scala/spark/streaming/DStream.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 922ff5088d..40744eac19 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -288,20 +288,27 @@ extends Serializable with Logging { * this method to save custom checkpoint data. */ protected[streaming] def updateCheckpointData(currentTime: Time) { + // Get the checkpointed RDDs from the generated RDDs val newCheckpointData = generatedRDDs.filter(_._2.getCheckpointData() != null) .map(x => (x._1, x._2.getCheckpointData())) + // Make a copy of the existing checkpoint data val oldCheckpointData = checkpointData.clone() + + // If the new checkpoint has checkpoints then replace existing with the new one if (newCheckpointData.size > 0) { checkpointData.clear() checkpointData ++= newCheckpointData } + // Make dependencies update their checkpoint data dependencies.foreach(_.updateCheckpointData(currentTime)) + // TODO: remove this, this is just for debugging newCheckpointData.foreach { case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") } } + // If old checkpoint files have been removed from checkpoint data, then remove the files if (newCheckpointData.size > 0) { (oldCheckpointData -- newCheckpointData.keySet).foreach { case (time, data) => { @@ -322,6 +329,7 @@ extends Serializable with Logging { * override the updateCheckpointData() method would also need to override this method. */ protected[streaming] def restoreCheckpointData() { + // Create RDDs from the checkpoint data logInfo("Restoring checkpoint data from " + checkpointData.size + " checkpointed RDDs") checkpointData.foreach { case(time, data) => { -- cgit v1.2.3 From 0c1de43fc7a9fea8629907d5b331e466f18be418 Mon Sep 17 00:00:00 2001 From: Denny Date: Tue, 6 Nov 2012 09:41:42 -0800 Subject: Working on kafka. --- project/SparkBuild.scala | 4 +- .../scala/spark/streaming/StreamingContext.scala | 11 ++ .../spark/streaming/examples/KafkaWordCount.scala | 27 +++++ .../spark/streaming/input/KafkaInputDStream.scala | 121 +++++++++++++++++++++ 4 files changed, 162 insertions(+), 1 deletion(-) create mode 100644 streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala create mode 100644 streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 688bb16a03..f34736b1c4 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -156,7 +156,9 @@ object SparkBuild extends Build { def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel") def streamingSettings = sharedSettings ++ Seq( - name := "spark-streaming" + name := "spark-streaming", + libraryDependencies ++= Seq( + "kafka" % "core-kafka_2.9.1" % "0.7.2") ) ++ assemblySettings ++ extraAssemblySettings def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq( diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index b3148eaa97..4a78090597 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -86,6 +86,17 @@ class StreamingContext ( private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() + def kafkaStream[T: ClassManifest]( + hostname: String, + port: Int, + groupId: String, + storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 + ): DStream[T] = { + val inputStream = new KafkaInputDStream[T](this, hostname, port, groupId, storageLevel) + graph.addInputStream(inputStream) + inputStream + } + def networkTextStream( hostname: String, port: Int, diff --git a/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala new file mode 100644 index 0000000000..3f637150d1 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala @@ -0,0 +1,27 @@ +package spark.streaming.examples + +import spark.streaming.{Seconds, StreamingContext, KafkaInputDStream} +import spark.streaming.StreamingContext._ +import spark.storage.StorageLevel + +object KafkaWordCount { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: WordCountNetwork ") + System.exit(1) + } + + // Create the context and set the batch size + val ssc = new StreamingContext(args(0), "WordCountNetwork") + ssc.setBatchDuration(Seconds(2)) + + // Create a NetworkInputDStream on target ip:port and count the + // words in input stream of \n delimited test (eg. generated by 'nc') + val lines = ssc.kafkaStream[String](args(1), args(2).toInt, "test_group") + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + + } +} diff --git a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala new file mode 100644 index 0000000000..427f398237 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala @@ -0,0 +1,121 @@ +package spark.streaming + +import java.nio.ByteBuffer +import java.util.Properties +import java.util.concurrent.{ArrayBlockingQueue, Executors} +import kafka.api.{FetchRequest} +import kafka.consumer.{Consumer, ConsumerConfig, KafkaStream} +import kafka.javaapi.consumer.SimpleConsumer +import kafka.javaapi.message.ByteBufferMessageSet +import kafka.message.{Message, MessageSet, MessageAndMetadata} +import kafka.utils.Utils +import scala.collection.JavaConversions._ +import spark._ +import spark.RDD +import spark.storage.StorageLevel + + +/** + * An input stream that pulls messages form a Kafka Broker. + */ +class KafkaInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + host: String, + port: Int, + groupId: String, + storageLevel: StorageLevel, + timeout: Int = 10000, + bufferSize: Int = 1024000 + ) extends NetworkInputDStream[T](ssc_ ) with Logging { + + def createReceiver(): NetworkReceiver[T] = { + new KafkaReceiver(id, host, port, storageLevel, groupId, timeout).asInstanceOf[NetworkReceiver[T]] + } +} + +class KafkaReceiver(streamId: Int, host: String, port: Int, storageLevel: StorageLevel, groupId: String, timeout: Int) + extends NetworkReceiver[Any](streamId) { + + //var executorPool : = null + var blockPushingThread : Thread = null + + def onStop() { + blockPushingThread.interrupt() + } + + def onStart() { + + val executorPool = Executors.newFixedThreadPool(2) + + logInfo("Starting Kafka Consumer with groupId " + groupId) + + val zooKeeperEndPoint = host + ":" + port + logInfo("Connecting to " + zooKeeperEndPoint) + + // Specify some consumer properties + val props = new Properties() + props.put("zk.connect", zooKeeperEndPoint) + props.put("zk.connectiontimeout.ms", timeout.toString) + props.put("groupid", groupId) + + // Create the connection to the cluster + val consumerConfig = new ConsumerConfig(props) + val consumerConnector = Consumer.create(consumerConfig) + logInfo("Connected to " + zooKeeperEndPoint) + logInfo("") + logInfo("") + + // Specify which topics we are listening to + val topicCountMap = Map("test" -> 2) + val topicMessageStreams = consumerConnector.createMessageStreams(topicCountMap) + val streams = topicMessageStreams.get("test") + + // Queue that holds the blocks + val queue = new ArrayBlockingQueue[ByteBuffer](2) + + streams.getOrElse(Nil).foreach { stream => + executorPool.submit(new MessageHandler(stream, queue)) + } + + blockPushingThread = new DaemonThread { + override def run() { + logInfo("Starting BlockPushingThread.") + var nextBlockNumber = 0 + while (true) { + val buffer = queue.take() + val blockId = "input-" + streamId + "-" + nextBlockNumber + nextBlockNumber += 1 + pushBlock(blockId, buffer, storageLevel) + } + } + } + blockPushingThread.start() + + // while (true) { + // // Create a fetch request for topic “test”, partition 0, current offset, and fetch size of 1MB + // val fetchRequest = new FetchRequest("test", 0, offset, 1000000) + + // // get the message set from the consumer and print them out + // val messages = consumer.fetch(fetchRequest) + // for(msg <- messages.iterator) { + // logInfo("consumed: " + Utils.toString(msg.message.payload, "UTF-8")) + // // advance the offset after consuming each message + // offset = msg.offset + // queue.put(msg.message.payload) + // } + // } + } + + class MessageHandler(stream: KafkaStream[Message], queue: ArrayBlockingQueue[ByteBuffer]) extends Runnable { + def run() { + logInfo("Starting MessageHandler.") + while(true) { + stream.foreach { msgAndMetadata => + logInfo("Consumed: " + Utils.toString(msgAndMetadata.message.payload, "UTF-8")) + queue.put(msgAndMetadata.message.payload) + } + } + } + } + +} -- cgit v1.2.3 From fc3d0b602a08fdd182c2138506d1cd9952631f95 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 6 Nov 2012 17:23:31 -0800 Subject: Added FailureTestsuite for testing multiple, repeated master failures. --- .../main/scala/spark/streaming/Checkpoint.scala | 6 +- .../src/main/scala/spark/streaming/DStream.scala | 4 +- .../src/main/scala/spark/streaming/Scheduler.scala | 6 +- .../scala/spark/streaming/StreamingContext.scala | 17 +- .../scala/spark/streaming/CheckpointSuite.scala | 75 ++++---- .../test/scala/spark/streaming/FailureSuite.scala | 188 +++++++++++++++++++++ .../test/scala/spark/streaming/TestSuiteBase.scala | 9 +- 7 files changed, 256 insertions(+), 49 deletions(-) create mode 100644 streaming/src/test/scala/spark/streaming/FailureSuite.scala diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index 1643f45ffb..a70fb8f73a 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -32,7 +32,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val file = new Path(path, "graph") val conf = new Configuration() val fs = file.getFileSystem(conf) - logDebug("Saved checkpoint for time " + checkpointTime + " to file '" + file + "'") + logDebug("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'") if (fs.exists(file)) { val bkFile = new Path(file.getParent, file.getName + ".bk") FileUtil.copy(fs, file, fs, bkFile, true, true, conf) @@ -43,7 +43,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) oos.writeObject(this) oos.close() fs.close() - logInfo("Saved checkpoint for time " + checkpointTime + " to file '" + file + "'") + logInfo("Checkpoint of streaming context for time " + checkpointTime + " saved successfully to file '" + file + "'") } def toBytes(): Array[Byte] = { @@ -58,7 +58,6 @@ object Checkpoint extends Logging { val fs = new Path(path).getFileSystem(new Configuration()) val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"), new Path(path), new Path(path + ".bk")) - var detailedLog: String = "" attempts.foreach(file => { if (fs.exists(file)) { @@ -76,6 +75,7 @@ object Checkpoint extends Logging { fs.close() cp.validate() logInfo("Checkpoint successfully loaded from file '" + file + "'") + logInfo("Checkpoint was generated at time " + cp.checkpointTime) return cp } catch { case e: Exception => diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 40744eac19..73096edec5 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -288,6 +288,7 @@ extends Serializable with Logging { * this method to save custom checkpoint data. */ protected[streaming] def updateCheckpointData(currentTime: Time) { + logInfo("Updating checkpoint data for time " + currentTime) // Get the checkpointed RDDs from the generated RDDs val newCheckpointData = generatedRDDs.filter(_._2.getCheckpointData() != null) .map(x => (x._1, x._2.getCheckpointData())) @@ -319,7 +320,7 @@ extends Serializable with Logging { } } } - logInfo("Updated checkpoint data") + logInfo("Updated checkpoint data for time " + currentTime) } /** @@ -338,6 +339,7 @@ extends Serializable with Logging { } } dependencies.foreach(_.restoreCheckpointData()) + logInfo("Restored checkpoint data") } @throws(classOf[IOException]) diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index 2b3f5a4829..de0fb1f3ad 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -29,10 +29,12 @@ extends Logging { // on this first trigger time of the timer. if (ssc.isCheckpointPresent) { // If manual clock is being used for testing, then - // set manual clock to the last checkpointed time + // either set the manual clock to the last checkpointed time, + // or if the property is defined set it to that time if (clock.isInstanceOf[ManualClock]) { val lastTime = ssc.getInitialCheckpoint.checkpointTime.milliseconds - clock.asInstanceOf[ManualClock].setTime(lastTime) + val jumpTime = System.getProperty("spark.streaming.manualClock.jump", "0").toLong + clock.asInstanceOf[ManualClock].setTime(lastTime + jumpTime) } timer.restart(graph.zeroTime.milliseconds) logInfo("Scheduler's timer restarted") diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 25caaf7d39..eb83aaee7a 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -18,7 +18,7 @@ import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.hadoop.fs.Path import java.util.UUID -class StreamingContext ( +final class StreamingContext ( sc_ : SparkContext, cp_ : Checkpoint ) extends Logging { @@ -61,12 +61,12 @@ class StreamingContext ( } } - val nextNetworkInputStreamId = new AtomicInteger(0) - var networkInputTracker: NetworkInputTracker = null + private[streaming] val nextNetworkInputStreamId = new AtomicInteger(0) + private[streaming] var networkInputTracker: NetworkInputTracker = null private[streaming] var checkpointDir: String = { if (isCheckpointPresent) { - sc.setCheckpointDir(cp_.checkpointDir, true) + sc.setCheckpointDir(StreamingContext.getSparkCheckpointDir(cp_.checkpointDir), true) cp_.checkpointDir } else { null @@ -87,7 +87,7 @@ class StreamingContext ( def checkpoint(dir: String, interval: Time) { if (dir != null) { - sc.setCheckpointDir(new Path(dir, "rdds-" + UUID.randomUUID.toString).toString) + sc.setCheckpointDir(StreamingContext.getSparkCheckpointDir(dir)) checkpointDir = dir checkpointInterval = interval } else { @@ -227,8 +227,11 @@ class StreamingContext ( } def doCheckpoint(currentTime: Time) { + val startTime = System.currentTimeMillis() graph.updateCheckpointData(currentTime) new Checkpoint(this, currentTime).save(checkpointDir) + val stopTime = System.currentTimeMillis() + logInfo("Checkpointing the graph took " + (stopTime - startTime) + " ms") } } @@ -247,5 +250,9 @@ object StreamingContext { prefix + "-" + time.milliseconds + "." + suffix } } + + def getSparkCheckpointDir(sscCheckpointDir: String): String = { + new Path(sscCheckpointDir, UUID.randomUUID.toString).toString + } } diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index 9fdfd50be2..038827ddb0 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -52,15 +52,13 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { .checkpoint(stateStreamCheckpointInterval) .map(t => (t._1, t._2.self)) } - val ssc = setupStreams(input, operation) - val stateStream = ssc.graph.getOutputStreams().head.dependencies.head.dependencies.head + var ssc = setupStreams(input, operation) + var stateStream = ssc.graph.getOutputStreams().head.dependencies.head.dependencies.head - // Run till a time such that at least one RDD in the stream should have been checkpointed + // Run till a time such that at least one RDD in the stream should have been checkpointed, + // then check whether some RDD has been checkpointed or not ssc.start() - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - advanceClock(clock, firstNumBatches) - - // Check whether some RDD has been checkpointed or not + runStreamsWithRealDelay(ssc, firstNumBatches) logInfo("Checkpoint data of state stream = \n[" + stateStream.checkpointData.mkString(",\n") + "]") assert(!stateStream.checkpointData.isEmpty, "No checkpointed RDDs in state stream before first failure") stateStream.checkpointData.foreach { @@ -73,42 +71,45 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // Run till a further time such that previous checkpoint files in the stream would be deleted // and check whether the earlier checkpoint files are deleted val checkpointFiles = stateStream.checkpointData.map(x => new File(x._2.toString)) - advanceClock(clock, secondNumBatches) + runStreamsWithRealDelay(ssc, secondNumBatches) checkpointFiles.foreach(file => assert(!file.exists, "Checkpoint file '" + file + "' was not deleted")) + ssc.stop() // Restart stream computation using the checkpoint file and check whether // checkpointed RDDs have been restored or not - ssc.stop() - val sscNew = new StreamingContext(checkpointDir) - val stateStreamNew = sscNew.graph.getOutputStreams().head.dependencies.head.dependencies.head - logInfo("Restored data of state stream = \n[" + stateStreamNew.generatedRDDs.mkString("\n") + "]") - assert(!stateStreamNew.generatedRDDs.isEmpty, "No restored RDDs in state stream after recovery from first failure") + ssc = new StreamingContext(checkpointDir) + stateStream = ssc.graph.getOutputStreams().head.dependencies.head.dependencies.head + logInfo("Restored data of state stream = \n[" + stateStream.generatedRDDs.mkString("\n") + "]") + assert(!stateStream.generatedRDDs.isEmpty, "No restored RDDs in state stream after recovery from first failure") - // Run one batch to generate a new checkpoint file - sscNew.start() - val clockNew = sscNew.scheduler.clock.asInstanceOf[ManualClock] - advanceClock(clockNew, 1) - - // Check whether some RDD is present in the checkpoint data or not - assert(!stateStreamNew.checkpointData.isEmpty, "No checkpointed RDDs in state stream before second failure") + // Run one batch to generate a new checkpoint file and check whether some RDD + // is present in the checkpoint data or not + ssc.start() + runStreamsWithRealDelay(ssc, 1) + assert(!stateStream.checkpointData.isEmpty, "No checkpointed RDDs in state stream before second failure") stateStream.checkpointData.foreach { case (time, data) => { val file = new File(data.toString) - assert(file.exists(), "Checkpoint file '" + file +"' for time " + time + " for state stream before seconds failure does not exist") + assert(file.exists(), + "Checkpoint file '" + file +"' for time " + time + " for state stream before seconds failure does not exist") } } + ssc.stop() // Restart stream computation from the new checkpoint file to see whether that file has // correct checkpoint data - sscNew.stop() - val sscNewNew = new StreamingContext(checkpointDir) - val stateStreamNewNew = sscNew.graph.getOutputStreams().head.dependencies.head.dependencies.head - logInfo("Restored data of state stream = \n[" + stateStreamNew.generatedRDDs.mkString("\n") + "]") - assert(!stateStreamNewNew.generatedRDDs.isEmpty, "No restored RDDs in state stream after recovery from second failure") - sscNewNew.start() - advanceClock(sscNewNew.scheduler.clock.asInstanceOf[ManualClock], 1) - sscNewNew.stop() + ssc = new StreamingContext(checkpointDir) + stateStream = ssc.graph.getOutputStreams().head.dependencies.head.dependencies.head + logInfo("Restored data of state stream = \n[" + stateStream.generatedRDDs.mkString("\n") + "]") + assert(!stateStream.generatedRDDs.isEmpty, "No restored RDDs in state stream after recovery from second failure") + + // Adjust manual clock time as if it is being restarted after a delay + System.setProperty("spark.streaming.manualClock.jump", (batchDuration.milliseconds * 7).toString) + ssc.start() + runStreamsWithRealDelay(ssc, 4) + ssc.stop() + System.clearProperty("spark.streaming.manualClock.jump") } test("map and reduceByKey") { @@ -123,10 +124,12 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { test("reduceByKeyAndWindowInv") { val n = 10 val w = 4 - val input = (1 to n).map(x => Seq("a")).toSeq + val input = (1 to n).map(_ => Seq("a")).toSeq val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) val operation = (st: DStream[String]) => { - st.map(x => (x, 1)).reduceByKeyAndWindow(_ + _, _ - _, batchDuration * 4, batchDuration) + st.map(x => (x, 1)) + .reduceByKeyAndWindow(_ + _, _ - _, batchDuration * w, batchDuration) + .checkpoint(Seconds(2)) } for (i <- Seq(2, 3, 4)) { testCheckpointedOperation(input, operation, output, i) @@ -184,7 +187,14 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) } - def advanceClock(clock: ManualClock, numBatches: Long) { + /** + * Advances the manual clock on the streaming scheduler by given number of batches. + * It also wait for the expected amount of time for each batch. + */ + + + def runStreamsWithRealDelay(ssc: StreamingContext, numBatches: Long) { + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] logInfo("Manual clock before advancing = " + clock.time) for (i <- 1 to numBatches.toInt) { clock.addToTime(batchDuration.milliseconds) @@ -193,4 +203,5 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { logInfo("Manual clock after advancing = " + clock.time) Thread.sleep(batchDuration.milliseconds) } + } \ No newline at end of file diff --git a/streaming/src/test/scala/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/spark/streaming/FailureSuite.scala new file mode 100644 index 0000000000..5b414117fc --- /dev/null +++ b/streaming/src/test/scala/spark/streaming/FailureSuite.scala @@ -0,0 +1,188 @@ +package spark.streaming + +import org.scalatest.BeforeAndAfter +import org.apache.commons.io.FileUtils +import java.io.File +import scala.runtime.RichInt +import scala.util.Random +import spark.streaming.StreamingContext._ +import collection.mutable.ArrayBuffer +import spark.Logging + +/** + * This testsuite tests master failures at random times while the stream is running using + * the real clock. + */ +class FailureSuite extends TestSuiteBase with BeforeAndAfter { + + before { + FileUtils.deleteDirectory(new File(checkpointDir)) + } + + after { + FailureSuite.reset() + FileUtils.deleteDirectory(new File(checkpointDir)) + } + + override def framework = "CheckpointSuite" + + override def batchDuration = Milliseconds(500) + + override def checkpointDir = "checkpoint" + + override def checkpointInterval = batchDuration + + test("multiple failures with updateStateByKey") { + val n = 30 + // Input: time=1 ==> [ a ] , time=2 ==> [ a, a ] , time=3 ==> [ a, a, a ] , ... + val input = (1 to n).map(i => (1 to i).map(_ =>"a").toSeq).toSeq + // Last output: [ (a, 465) ] for n=30 + val lastOutput = Seq( ("a", (1 to n).reduce(_ + _)) ) + + val operation = (st: DStream[String]) => { + val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { + Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) + } + st.map(x => (x, 1)) + .updateStateByKey[RichInt](updateFunc) + .checkpoint(Seconds(2)) + .map(t => (t._1, t._2.self)) + } + + testOperationWithMultipleFailures(input, operation, lastOutput, n, n) + } + + test("multiple failures with reduceByKeyAndWindow") { + val n = 30 + val w = 100 + assert(w > n, "Window should be much larger than the number of input sets in this test") + // Input: time=1 ==> [ a ] , time=2 ==> [ a, a ] , time=3 ==> [ a, a, a ] , ... + val input = (1 to n).map(i => (1 to i).map(_ =>"a").toSeq).toSeq + // Last output: [ (a, 465) ] + val lastOutput = Seq( ("a", (1 to n).reduce(_ + _)) ) + + val operation = (st: DStream[String]) => { + st.map(x => (x, 1)) + .reduceByKeyAndWindow(_ + _, _ - _, batchDuration * w, batchDuration) + .checkpoint(Seconds(2)) + } + + testOperationWithMultipleFailures(input, operation, lastOutput, n, n) + } + + + /** + * Tests stream operation with multiple master failures, and verifies whether the + * final set of output values is as expected or not. Checking the final value is + * proof that no intermediate data was lost due to master failures. + */ + def testOperationWithMultipleFailures[U: ClassManifest, V: ClassManifest]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + lastExpectedOutput: Seq[V], + numBatches: Int, + numExpectedOutput: Int + ) { + var ssc = setupStreams[U, V](input, operation) + val mergedOutput = new ArrayBuffer[Seq[V]]() + + var totalTimeRan = 0L + while(totalTimeRan <= numBatches * batchDuration.milliseconds * 2) { + new KillingThread(ssc, numBatches * batchDuration.milliseconds.toInt / 4).start() + val (output, timeRan) = runStreamsWithRealClock[V](ssc, numBatches, numExpectedOutput) + + mergedOutput ++= output + totalTimeRan += timeRan + logInfo("New output = " + output) + logInfo("Merged output = " + mergedOutput) + logInfo("Total time spent = " + totalTimeRan) + val sleepTime = Random.nextInt(numBatches * batchDuration.milliseconds.toInt / 8) + logInfo( + "\n-------------------------------------------\n" + + " Restarting stream computation in " + sleepTime + " ms " + + "\n-------------------------------------------\n" + ) + Thread.sleep(sleepTime) + FailureSuite.failed = false + ssc = new StreamingContext(checkpointDir) + } + ssc.stop() + ssc = null + + // Verify whether the last output is the expected one + val lastOutput = mergedOutput(mergedOutput.lastIndexWhere(!_.isEmpty)) + assert(lastOutput.toSet === lastExpectedOutput.toSet) + logInfo("Finished computation after " + FailureSuite.failureCount + " failures") + } + + /** + * Runs the streams set up in `ssc` on real clock until the expected max number of + */ + def runStreamsWithRealClock[V: ClassManifest]( + ssc: StreamingContext, + numBatches: Int, + maxExpectedOutput: Int + ): (Seq[Seq[V]], Long) = { + + System.clearProperty("spark.streaming.clock") + + assert(numBatches > 0, "Number of batches to run stream computation is zero") + assert(maxExpectedOutput > 0, "Max expected outputs after " + numBatches + " is zero") + logInfo("numBatches = " + numBatches + ", maxExpectedOutput = " + maxExpectedOutput) + + // Get the output buffer + val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]] + val output = outputStream.output + val waitTime = (batchDuration.millis * (numBatches.toDouble + 0.5)).toLong + val startTime = System.currentTimeMillis() + + try { + // Start computation + ssc.start() + + // Wait until expected number of output items have been generated + while (output.size < maxExpectedOutput && System.currentTimeMillis() - startTime < waitTime && !FailureSuite.failed) { + logInfo("output.size = " + output.size + ", maxExpectedOutput = " + maxExpectedOutput) + Thread.sleep(100) + } + } catch { + case e: Exception => logInfo("Exception while running streams: " + e) + } finally { + ssc.stop() + } + val timeTaken = System.currentTimeMillis() - startTime + logInfo("" + output.size + " sets of output generated in " + timeTaken + " ms") + (output, timeTaken) + } + + +} + +object FailureSuite { + var failed = false + var failureCount = 0 + + def reset() { + failed = false + failureCount = 0 + } +} + +class KillingThread(ssc: StreamingContext, maxKillWaitTime: Int) extends Thread with Logging { + initLogging() + + override def run() { + var minKillWaitTime = if (FailureSuite.failureCount == 0) 3000 else 1000 // to allow the first checkpoint + val killWaitTime = minKillWaitTime + Random.nextInt(maxKillWaitTime) + logInfo("Kill wait time = " + killWaitTime) + Thread.sleep(killWaitTime.toLong) + logInfo( + "\n---------------------------------------\n" + + "Killing streaming context after " + killWaitTime + " ms" + + "\n---------------------------------------\n" + ) + if (ssc != null) ssc.stop() + FailureSuite.failed = true + FailureSuite.failureCount += 1 + } +} diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index b8c7f99603..5fb5cc504c 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -23,12 +23,9 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[ def compute(validTime: Time): Option[RDD[T]] = { logInfo("Computing RDD for time " + validTime) val index = ((validTime - zeroTime) / slideTime - 1).toInt - val rdd = if (index < input.size) { - ssc.sc.makeRDD(input(index), numPartitions) - } else { - ssc.sc.makeRDD(Seq[T](), numPartitions) - } - logInfo("Created RDD " + rdd.id) + val selectedInput = if (index < input.size) input(index) else Seq[T]() + val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) + logInfo("Created RDD " + rdd.id + " with " + selectedInput) Some(rdd) } } -- cgit v1.2.3 From cc2a65f54715ff0990d5873d50eec0dedf64d409 Mon Sep 17 00:00:00 2001 From: tdas Date: Thu, 8 Nov 2012 11:17:57 +0000 Subject: Fixed bug in InputStreamsSuite --- streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index c17254b809..8f892baab1 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -12,6 +12,8 @@ import org.apache.commons.io.FileUtils class InputStreamsSuite extends TestSuiteBase { + + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") test("network input stream") { // Start the server -- cgit v1.2.3 From 52d21cb682d1c4ca05e6823f8049ccedc3c5530c Mon Sep 17 00:00:00 2001 From: tdas Date: Thu, 8 Nov 2012 11:35:40 +0000 Subject: Removed unnecessary files. --- .../spark/streaming/util/ConnectionHandler.scala | 157 -------- .../spark/streaming/util/SenderReceiverTest.scala | 67 ---- .../streaming/util/SentenceFileGenerator.scala | 92 ----- .../scala/spark/streaming/util/ShuffleTest.scala | 23 -- .../scala/spark/streaming/util/TestGenerator.scala | 107 ------ .../spark/streaming/util/TestGenerator2.scala | 119 ------ .../spark/streaming/util/TestGenerator4.scala | 244 ------------ .../streaming/util/TestStreamCoordinator.scala | 39 -- .../spark/streaming/util/TestStreamReceiver3.scala | 421 --------------------- .../spark/streaming/util/TestStreamReceiver4.scala | 374 ------------------ 10 files changed, 1643 deletions(-) delete mode 100644 streaming/src/main/scala/spark/streaming/util/ConnectionHandler.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/TestGenerator.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/TestGenerator2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/TestGenerator4.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/TestStreamCoordinator.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/TestStreamReceiver3.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/TestStreamReceiver4.scala diff --git a/streaming/src/main/scala/spark/streaming/util/ConnectionHandler.scala b/streaming/src/main/scala/spark/streaming/util/ConnectionHandler.scala deleted file mode 100644 index cde868a0c9..0000000000 --- a/streaming/src/main/scala/spark/streaming/util/ConnectionHandler.scala +++ /dev/null @@ -1,157 +0,0 @@ -package spark.streaming.util - -import spark.Logging - -import scala.collection.mutable.{ArrayBuffer, SynchronizedQueue} - -import java.net._ -import java.io._ -import java.nio._ -import java.nio.charset._ -import java.nio.channels._ -import java.nio.channels.spi._ - -abstract class ConnectionHandler(host: String, port: Int, connect: Boolean) -extends Thread with Logging { - - val selector = SelectorProvider.provider.openSelector() - val interestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] - - initLogging() - - override def run() { - try { - if (connect) { - connect() - } else { - listen() - } - - var interrupted = false - while(!interrupted) { - - preSelect() - - while(!interestChangeRequests.isEmpty) { - val (key, ops) = interestChangeRequests.dequeue - val lastOps = key.interestOps() - key.interestOps(ops) - - def intToOpStr(op: Int): String = { - val opStrs = new ArrayBuffer[String]() - if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" - if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" - if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" - if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" - if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " - } - - logTrace("Changed ops from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") - } - - selector.select() - interrupted = Thread.currentThread.isInterrupted - - val selectedKeys = selector.selectedKeys().iterator() - while (selectedKeys.hasNext) { - val key = selectedKeys.next.asInstanceOf[SelectionKey] - selectedKeys.remove() - if (key.isValid) { - if (key.isAcceptable) { - accept(key) - } else if (key.isConnectable) { - finishConnect(key) - } else if (key.isReadable) { - read(key) - } else if (key.isWritable) { - write(key) - } - } - } - } - } catch { - case e: Exception => { - logError("Error in select loop", e) - } - } - } - - def connect() { - val socketAddress = new InetSocketAddress(host, port) - val channel = SocketChannel.open() - channel.configureBlocking(false) - channel.socket.setReuseAddress(true) - channel.socket.setTcpNoDelay(true) - channel.connect(socketAddress) - channel.register(selector, SelectionKey.OP_CONNECT) - logInfo("Initiating connection to [" + socketAddress + "]") - } - - def listen() { - val channel = ServerSocketChannel.open() - channel.configureBlocking(false) - channel.socket.setReuseAddress(true) - channel.socket.setReceiveBufferSize(256 * 1024) - channel.socket.bind(new InetSocketAddress(port)) - channel.register(selector, SelectionKey.OP_ACCEPT) - logInfo("Listening on port " + port) - } - - def finishConnect(key: SelectionKey) { - try { - val channel = key.channel.asInstanceOf[SocketChannel] - val address = channel.socket.getRemoteSocketAddress - channel.finishConnect() - logInfo("Connected to [" + host + ":" + port + "]") - ready(key) - } catch { - case e: IOException => { - logError("Error finishing connect to " + host + ":" + port) - close(key) - } - } - } - - def accept(key: SelectionKey) { - try { - val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] - val channel = serverChannel.accept() - val address = channel.socket.getRemoteSocketAddress - channel.configureBlocking(false) - logInfo("Accepted connection from [" + address + "]") - ready(channel.register(selector, 0)) - } catch { - case e: IOException => { - logError("Error accepting connection", e) - } - } - } - - def changeInterest(key: SelectionKey, ops: Int) { - logTrace("Added request to change ops to " + ops) - interestChangeRequests += ((key, ops)) - } - - def ready(key: SelectionKey) - - def preSelect() { - } - - def read(key: SelectionKey) { - throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString) - } - - def write(key: SelectionKey) { - throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString) - } - - def close(key: SelectionKey) { - try { - key.channel.close() - key.cancel() - Thread.currentThread.interrupt - } catch { - case e: Exception => logError("Error closing connection", e) - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala b/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala deleted file mode 100644 index 3922dfbad6..0000000000 --- a/streaming/src/main/scala/spark/streaming/util/SenderReceiverTest.scala +++ /dev/null @@ -1,67 +0,0 @@ -package spark.streaming.util - -import java.net.{Socket, ServerSocket} -import java.io.{ByteArrayOutputStream, DataOutputStream, DataInputStream, BufferedInputStream} - -object Receiver { - def main(args: Array[String]) { - val port = args(0).toInt - val lsocket = new ServerSocket(port) - println("Listening on port " + port ) - while(true) { - val socket = lsocket.accept() - (new Thread() { - override def run() { - val buffer = new Array[Byte](100000) - var count = 0 - val time = System.currentTimeMillis - try { - val is = new DataInputStream(new BufferedInputStream(socket.getInputStream)) - var loop = true - var string: String = null - do { - string = is.readUTF() - if (string != null) { - count += 28 - } - } while (string != null) - } catch { - case e: Exception => e.printStackTrace() - } - val timeTaken = System.currentTimeMillis - time - val tput = (count / 1024.0) / (timeTaken / 1000.0) - println("Data = " + count + " bytes\nTime = " + timeTaken + " ms\nTput = " + tput + " KB/s") - } - }).start() - } - } - -} - -object Sender { - - def main(args: Array[String]) { - try { - val host = args(0) - val port = args(1).toInt - val size = args(2).toInt - - val byteStream = new ByteArrayOutputStream() - val stringDataStream = new DataOutputStream(byteStream) - (0 until size).foreach(_ => stringDataStream.writeUTF("abcdedfghijklmnopqrstuvwxy")) - val bytes = byteStream.toByteArray() - println("Generated array of " + bytes.length + " bytes") - - /*val bytes = new Array[Byte](size)*/ - val socket = new Socket(host, port) - val os = socket.getOutputStream - os.write(bytes) - os.flush - socket.close() - - } catch { - case e: Exception => e.printStackTrace - } - } -} - diff --git a/streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala b/streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala deleted file mode 100644 index 94e8f7a849..0000000000 --- a/streaming/src/main/scala/spark/streaming/util/SentenceFileGenerator.scala +++ /dev/null @@ -1,92 +0,0 @@ -package spark.streaming.util - -import spark._ - -import scala.collection.mutable.ArrayBuffer -import scala.util.Random -import scala.io.Source - -import java.net.InetSocketAddress - -import org.apache.hadoop.fs._ -import org.apache.hadoop.conf._ -import org.apache.hadoop.io._ -import org.apache.hadoop.mapred._ -import org.apache.hadoop.util._ - -object SentenceFileGenerator { - - def printUsage () { - println ("Usage: SentenceFileGenerator <# partitions> []") - System.exit(0) - } - - def main (args: Array[String]) { - if (args.length < 4) { - printUsage - } - - val master = args(0) - val fs = new Path(args(1)).getFileSystem(new Configuration()) - val targetDirectory = new Path(args(1)).makeQualified(fs) - val numPartitions = args(2).toInt - val sentenceFile = args(3) - val sentencesPerSecond = { - if (args.length > 4) args(4).toInt - else 10 - } - - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n").toArray - source.close () - println("Read " + lines.length + " lines from file " + sentenceFile) - - val sentences = { - val buffer = ArrayBuffer[String]() - val random = new Random() - var i = 0 - while (i < sentencesPerSecond) { - buffer += lines(random.nextInt(lines.length)) - i += 1 - } - buffer.toArray - } - println("Generated " + sentences.length + " sentences") - - val sc = new SparkContext(master, "SentenceFileGenerator") - val sentencesRDD = sc.parallelize(sentences, numPartitions) - - val tempDirectory = new Path(targetDirectory, "_tmp") - - fs.mkdirs(targetDirectory) - fs.mkdirs(tempDirectory) - - var saveTimeMillis = System.currentTimeMillis - try { - while (true) { - val newDir = new Path(targetDirectory, "Sentences-" + saveTimeMillis) - val tmpNewDir = new Path(tempDirectory, "Sentences-" + saveTimeMillis) - println("Writing to file " + newDir) - sentencesRDD.saveAsTextFile(tmpNewDir.toString) - fs.rename(tmpNewDir, newDir) - saveTimeMillis += 1000 - val sleepTimeMillis = { - val currentTimeMillis = System.currentTimeMillis - if (saveTimeMillis < currentTimeMillis) { - 0 - } else { - saveTimeMillis - currentTimeMillis - } - } - println("Sleeping for " + sleepTimeMillis + " ms") - Thread.sleep(sleepTimeMillis) - } - } catch { - case e: Exception => - } - } -} - - - - diff --git a/streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala b/streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala deleted file mode 100644 index 60085f4f88..0000000000 --- a/streaming/src/main/scala/spark/streaming/util/ShuffleTest.scala +++ /dev/null @@ -1,23 +0,0 @@ -package spark.streaming.util - -import spark.SparkContext -import SparkContext._ - -object ShuffleTest { - def main(args: Array[String]) { - - if (args.length < 1) { - println ("Usage: ShuffleTest ") - System.exit(1) - } - - val sc = new spark.SparkContext(args(0), "ShuffleTest") - val rdd = sc.parallelize(1 to 1000, 500).cache - - def time(f: => Unit) { val start = System.nanoTime; f; println((System.nanoTime - start) * 1.0e-6) } - - time { for (i <- 0 until 50) time { rdd.map(x => (x % 100, x)).reduceByKey(_ + _, 10).count } } - System.exit(0) - } -} - diff --git a/streaming/src/main/scala/spark/streaming/util/TestGenerator.scala b/streaming/src/main/scala/spark/streaming/util/TestGenerator.scala deleted file mode 100644 index 23e9235c60..0000000000 --- a/streaming/src/main/scala/spark/streaming/util/TestGenerator.scala +++ /dev/null @@ -1,107 +0,0 @@ -package spark.streaming.util - -import scala.util.Random -import scala.io.Source -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import java.net.InetSocketAddress - - -object TestGenerator { - - def printUsage { - println ("Usage: SentenceGenerator []") - System.exit(0) - } - /* - def generateRandomSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { - val sleepBetweenSentences = 1000.0 / sentencesPerSecond.toDouble - 1 - val random = new Random () - - try { - var lastPrintTime = System.currentTimeMillis() - var count = 0 - while(true) { - streamReceiver ! lines(random.nextInt(lines.length)) - count += 1 - if (System.currentTimeMillis - lastPrintTime >= 1000) { - println (count + " sentences sent last second") - count = 0 - lastPrintTime = System.currentTimeMillis - } - Thread.sleep(sleepBetweenSentences.toLong) - } - } catch { - case e: Exception => - } - }*/ - - def generateSameSentences(lines: Array[String], sentencesPerSecond: Int, streamReceiver: AbstractActor) { - try { - val numSentences = if (sentencesPerSecond <= 0) { - lines.length - } else { - sentencesPerSecond - } - val sentences = lines.take(numSentences).toArray - - var nextSendingTime = System.currentTimeMillis() - val sendAsArray = true - while(true) { - if (sendAsArray) { - println("Sending as array") - streamReceiver !? sentences - } else { - println("Sending individually") - sentences.foreach(sentence => { - streamReceiver !? sentence - }) - } - println ("Sent " + numSentences + " sentences in " + (System.currentTimeMillis - nextSendingTime) + " ms") - nextSendingTime += 1000 - val sleepTime = nextSendingTime - System.currentTimeMillis - if (sleepTime > 0) { - println ("Sleeping for " + sleepTime + " ms") - Thread.sleep(sleepTime) - } - } - } catch { - case e: Exception => - } - } - - def main(args: Array[String]) { - if (args.length < 3) { - printUsage - } - - val generateRandomly = false - - val streamReceiverIP = args(0) - val streamReceiverPort = args(1).toInt - val sentenceFile = args(2) - val sentencesPerSecond = if (args.length > 3) args(3).toInt else 10 - val sentenceInputName = if (args.length > 4) args(4) else "Sentences" - - println("Sending " + sentencesPerSecond + " sentences per second to " + - streamReceiverIP + ":" + streamReceiverPort + "/NetworkStreamReceiver-" + sentenceInputName) - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n") - source.close () - - val streamReceiver = select( - Node(streamReceiverIP, streamReceiverPort), - Symbol("NetworkStreamReceiver-" + sentenceInputName)) - if (generateRandomly) { - /*generateRandomSentences(lines, sentencesPerSecond, streamReceiver)*/ - } else { - generateSameSentences(lines, sentencesPerSecond, streamReceiver) - } - } -} - - - diff --git a/streaming/src/main/scala/spark/streaming/util/TestGenerator2.scala b/streaming/src/main/scala/spark/streaming/util/TestGenerator2.scala deleted file mode 100644 index ff840d084f..0000000000 --- a/streaming/src/main/scala/spark/streaming/util/TestGenerator2.scala +++ /dev/null @@ -1,119 +0,0 @@ -package spark.streaming.util - -import scala.util.Random -import scala.io.Source -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ - -import java.io.{DataOutputStream, ByteArrayOutputStream, DataInputStream} -import java.net.Socket - -object TestGenerator2 { - - def printUsage { - println ("Usage: SentenceGenerator []") - System.exit(0) - } - - def sendSentences(streamReceiverHost: String, streamReceiverPort: Int, numSentences: Int, bytes: Array[Byte], intervalTime: Long){ - try { - println("Connecting to " + streamReceiverHost + ":" + streamReceiverPort) - val socket = new Socket(streamReceiverHost, streamReceiverPort) - - println("Sending " + numSentences+ " sentences / " + (bytes.length / 1024.0 / 1024.0) + " MB per " + intervalTime + " ms to " + streamReceiverHost + ":" + streamReceiverPort ) - val currentTime = System.currentTimeMillis - var targetTime = (currentTime / intervalTime + 1).toLong * intervalTime - Thread.sleep(targetTime - currentTime) - - while(true) { - val startTime = System.currentTimeMillis() - println("Sending at " + startTime + " ms with delay of " + (startTime - targetTime) + " ms") - val socketOutputStream = socket.getOutputStream - val parts = 10 - (0 until parts).foreach(i => { - val partStartTime = System.currentTimeMillis - - val offset = (i * bytes.length / parts).toInt - val len = math.min(((i + 1) * bytes.length / parts).toInt - offset, bytes.length) - socketOutputStream.write(bytes, offset, len) - socketOutputStream.flush() - val partFinishTime = System.currentTimeMillis - println("Sending part " + i + " of " + len + " bytes took " + (partFinishTime - partStartTime) + " ms") - val sleepTime = math.max(0, 1000 / parts - (partFinishTime - partStartTime) - 1) - Thread.sleep(sleepTime) - }) - - socketOutputStream.flush() - /*val socketInputStream = new DataInputStream(socket.getInputStream)*/ - /*val reply = socketInputStream.readUTF()*/ - val finishTime = System.currentTimeMillis() - println ("Sent " + bytes.length + " bytes in " + (finishTime - startTime) + " ms for interval [" + targetTime + ", " + (targetTime + intervalTime) + "]") - /*println("Received = " + reply)*/ - targetTime = targetTime + intervalTime - val sleepTime = (targetTime - finishTime) + 10 - if (sleepTime > 0) { - println("Sleeping for " + sleepTime + " ms") - Thread.sleep(sleepTime) - } else { - println("############################") - println("###### Skipping sleep ######") - println("############################") - } - } - } catch { - case e: Exception => println(e) - } - println("Stopped sending") - } - - def main(args: Array[String]) { - if (args.length < 4) { - printUsage - } - - val streamReceiverHost = args(0) - val streamReceiverPort = args(1).toInt - val sentenceFile = args(2) - val intervalTime = args(3).toLong - val sentencesPerInterval = if (args.length > 4) args(4).toInt else 0 - - println("Reading the file " + sentenceFile) - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n") - source.close() - - val numSentences = if (sentencesPerInterval <= 0) { - lines.length - } else { - sentencesPerInterval - } - - println("Generating sentences") - val sentences: Array[String] = if (numSentences <= lines.length) { - lines.take(numSentences).toArray - } else { - (0 until numSentences).map(i => lines(i % lines.length)).toArray - } - - println("Converting to byte array") - val byteStream = new ByteArrayOutputStream() - val stringDataStream = new DataOutputStream(byteStream) - /*stringDataStream.writeInt(sentences.size)*/ - sentences.foreach(stringDataStream.writeUTF) - val bytes = byteStream.toByteArray() - stringDataStream.close() - println("Generated array of " + bytes.length + " bytes") - - /*while(true) { */ - sendSentences(streamReceiverHost, streamReceiverPort, numSentences, bytes, intervalTime) - /*println("Sleeping for 5 seconds")*/ - /*Thread.sleep(5000)*/ - /*System.gc()*/ - /*}*/ - } -} - - - diff --git a/streaming/src/main/scala/spark/streaming/util/TestGenerator4.scala b/streaming/src/main/scala/spark/streaming/util/TestGenerator4.scala deleted file mode 100644 index 9c39ef3e12..0000000000 --- a/streaming/src/main/scala/spark/streaming/util/TestGenerator4.scala +++ /dev/null @@ -1,244 +0,0 @@ -package spark.streaming.util - -import spark.Logging - -import scala.util.Random -import scala.io.Source -import scala.collection.mutable.{ArrayBuffer, Queue} - -import java.net._ -import java.io._ -import java.nio._ -import java.nio.charset._ -import java.nio.channels._ - -import it.unimi.dsi.fastutil.io._ - -class TestGenerator4(targetHost: String, targetPort: Int, sentenceFile: String, intervalDuration: Long, sentencesPerInterval: Int) -extends Logging { - - class SendingConnectionHandler(host: String, port: Int, generator: TestGenerator4) - extends ConnectionHandler(host, port, true) { - - val buffers = new ArrayBuffer[ByteBuffer] - val newBuffers = new Queue[ByteBuffer] - var activeKey: SelectionKey = null - - def send(buffer: ByteBuffer) { - logDebug("Sending: " + buffer) - newBuffers.synchronized { - newBuffers.enqueue(buffer) - } - selector.wakeup() - buffer.synchronized { - buffer.wait() - } - } - - override def ready(key: SelectionKey) { - logDebug("Ready") - activeKey = key - val channel = key.channel.asInstanceOf[SocketChannel] - channel.register(selector, SelectionKey.OP_WRITE) - generator.startSending() - } - - override def preSelect() { - newBuffers.synchronized { - while(!newBuffers.isEmpty) { - val buffer = newBuffers.dequeue - buffers += buffer - logDebug("Added: " + buffer) - changeInterest(activeKey, SelectionKey.OP_WRITE) - } - } - } - - override def write(key: SelectionKey) { - try { - /*while(true) {*/ - val channel = key.channel.asInstanceOf[SocketChannel] - if (buffers.size > 0) { - val buffer = buffers(0) - val newBuffer = buffer.slice() - newBuffer.limit(math.min(newBuffer.remaining, 32768)) - val bytesWritten = channel.write(newBuffer) - buffer.position(buffer.position + bytesWritten) - if (bytesWritten == 0) return - if (buffer.remaining == 0) { - buffers -= buffer - buffer.synchronized { - buffer.notify() - } - } - /*changeInterest(key, SelectionKey.OP_WRITE)*/ - } else { - changeInterest(key, 0) - } - /*}*/ - } catch { - case e: IOException => { - if (e.toString.contains("pipe") || e.toString.contains("reset")) { - logError("Connection broken") - } else { - logError("Connection error", e) - } - close(key) - } - } - } - - override def close(key: SelectionKey) { - buffers.clear() - super.close(key) - } - } - - initLogging() - - val connectionHandler = new SendingConnectionHandler(targetHost, targetPort, this) - var sendingThread: Thread = null - var sendCount = 0 - val sendBatches = 5 - - def run() { - logInfo("Connection handler started") - connectionHandler.start() - connectionHandler.join() - if (sendingThread != null && !sendingThread.isInterrupted) { - sendingThread.interrupt - } - logInfo("Connection handler stopped") - } - - def startSending() { - sendingThread = new Thread() { - override def run() { - logInfo("STARTING TO SEND") - sendSentences() - logInfo("SENDING STOPPED AFTER " + sendCount) - connectionHandler.interrupt() - } - } - sendingThread.start() - } - - def stopSending() { - sendingThread.interrupt() - } - - def sendSentences() { - logInfo("Reading the file " + sentenceFile) - val source = Source.fromFile(sentenceFile) - val lines = source.mkString.split ("\n") - source.close() - - val numSentences = if (sentencesPerInterval <= 0) { - lines.length - } else { - sentencesPerInterval - } - - logInfo("Generating sentence buffer") - val sentences: Array[String] = if (numSentences <= lines.length) { - lines.take(numSentences).toArray - } else { - (0 until numSentences).map(i => lines(i % lines.length)).toArray - } - - /* - val sentences: Array[String] = if (numSentences <= lines.length) { - lines.take((numSentences / sendBatches).toInt).toArray - } else { - (0 until (numSentences/sendBatches)).map(i => lines(i % lines.length)).toArray - }*/ - - - val serializer = new spark.KryoSerializer().newInstance() - val byteStream = new FastByteArrayOutputStream(100 * 1024 * 1024) - serializer.serializeStream(byteStream).writeAll(sentences.toIterator.asInstanceOf[Iterator[Any]]).close() - byteStream.trim() - val sentenceBuffer = ByteBuffer.wrap(byteStream.array) - - logInfo("Sending " + numSentences+ " sentences / " + sentenceBuffer.limit + " bytes per " + intervalDuration + " ms to " + targetHost + ":" + targetPort ) - val currentTime = System.currentTimeMillis - var targetTime = (currentTime / intervalDuration + 1).toLong * intervalDuration - Thread.sleep(targetTime - currentTime) - - val totalBytes = sentenceBuffer.limit - - while(true) { - val batchesInCurrentInterval = sendBatches // if (sendCount < 10) 1 else sendBatches - - val startTime = System.currentTimeMillis() - logDebug("Sending # " + sendCount + " at " + startTime + " ms with delay of " + (startTime - targetTime) + " ms") - - (0 until batchesInCurrentInterval).foreach(i => { - try { - val position = (i * totalBytes / sendBatches).toInt - val limit = if (i == sendBatches - 1) { - totalBytes - } else { - ((i + 1) * totalBytes / sendBatches).toInt - 1 - } - - val partStartTime = System.currentTimeMillis - sentenceBuffer.limit(limit) - connectionHandler.send(sentenceBuffer) - val partFinishTime = System.currentTimeMillis - val sleepTime = math.max(0, intervalDuration / sendBatches - (partFinishTime - partStartTime) - 1) - Thread.sleep(sleepTime) - - } catch { - case ie: InterruptedException => return - case e: Exception => e.printStackTrace() - } - }) - sentenceBuffer.rewind() - - val finishTime = System.currentTimeMillis() - /*logInfo ("Sent " + sentenceBuffer.limit + " bytes in " + (finishTime - startTime) + " ms")*/ - targetTime = targetTime + intervalDuration //+ (if (sendCount < 3) 1000 else 0) - - val sleepTime = (targetTime - finishTime) + 20 - if (sleepTime > 0) { - logInfo("Sleeping for " + sleepTime + " ms") - Thread.sleep(sleepTime) - } else { - logInfo("###### Skipping sleep ######") - } - if (Thread.currentThread.isInterrupted) { - return - } - sendCount += 1 - } - } -} - -object TestGenerator4 { - def printUsage { - println("Usage: TestGenerator4 []") - System.exit(0) - } - - def main(args: Array[String]) { - println("GENERATOR STARTED") - if (args.length < 4) { - printUsage - } - - - val streamReceiverHost = args(0) - val streamReceiverPort = args(1).toInt - val sentenceFile = args(2) - val intervalDuration = args(3).toLong - val sentencesPerInterval = if (args.length > 4) args(4).toInt else 0 - - while(true) { - val generator = new TestGenerator4(streamReceiverHost, streamReceiverPort, sentenceFile, intervalDuration, sentencesPerInterval) - generator.run() - Thread.sleep(2000) - } - println("GENERATOR STOPPED") - } -} diff --git a/streaming/src/main/scala/spark/streaming/util/TestStreamCoordinator.scala b/streaming/src/main/scala/spark/streaming/util/TestStreamCoordinator.scala deleted file mode 100644 index f584f772bb..0000000000 --- a/streaming/src/main/scala/spark/streaming/util/TestStreamCoordinator.scala +++ /dev/null @@ -1,39 +0,0 @@ -package spark.streaming.util - -import spark.streaming._ -import spark.Logging - -import akka.actor._ -import akka.actor.Actor -import akka.actor.Actor._ - -sealed trait TestStreamCoordinatorMessage -case class GetStreamDetails extends TestStreamCoordinatorMessage -case class GotStreamDetails(name: String, duration: Long) extends TestStreamCoordinatorMessage -case class TestStarted extends TestStreamCoordinatorMessage - -class TestStreamCoordinator(streamDetails: Array[(String, Long)]) extends Actor with Logging { - - var index = 0 - - initLogging() - - logInfo("Created") - - def receive = { - case TestStarted => { - sender ! "OK" - } - - case GetStreamDetails => { - val streamDetail = if (index >= streamDetails.length) null else streamDetails(index) - sender ! GotStreamDetails(streamDetail._1, streamDetail._2) - index += 1 - if (streamDetail != null) { - logInfo("Allocated " + streamDetail._1 + " (" + index + "/" + streamDetails.length + ")" ) - } - } - } - -} - diff --git a/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver3.scala b/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver3.scala deleted file mode 100644 index 80ad924dd8..0000000000 --- a/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver3.scala +++ /dev/null @@ -1,421 +0,0 @@ -package spark.streaming.util - -import spark._ -import spark.storage._ -import spark.util.AkkaUtils -import spark.streaming._ - -import scala.math._ -import scala.collection.mutable.{Queue, HashMap, ArrayBuffer, SynchronizedMap} - -import akka.actor._ -import akka.actor.Actor -import akka.dispatch._ -import akka.pattern.ask -import akka.util.duration._ - -import java.io.DataInputStream -import java.io.BufferedInputStream -import java.net.Socket -import java.net.ServerSocket -import java.util.LinkedHashMap - -import org.apache.hadoop.fs._ -import org.apache.hadoop.conf._ -import org.apache.hadoop.io._ -import org.apache.hadoop.mapred._ -import org.apache.hadoop.util._ - -import spark.Utils - - -class TestStreamReceiver3(actorSystem: ActorSystem, blockManager: BlockManager) -extends Thread with Logging { - - - class DataHandler( - inputName: String, - longIntervalDuration: Time, - shortIntervalDuration: Time, - blockManager: BlockManager - ) - extends Logging { - - class Block(var id: String, var shortInterval: Interval) { - val data = ArrayBuffer[String]() - var pushed = false - def longInterval = getLongInterval(shortInterval) - def empty() = (data.size == 0) - def += (str: String) = (data += str) - override def toString() = "Block " + id - } - - class Bucket(val longInterval: Interval) { - val blocks = new ArrayBuffer[Block]() - var filled = false - def += (block: Block) = blocks += block - def empty() = (blocks.size == 0) - def ready() = (filled && !blocks.exists(! _.pushed)) - def blockIds() = blocks.map(_.id).toArray - override def toString() = "Bucket [" + longInterval + ", " + blocks.size + " blocks]" - } - - initLogging() - - val shortIntervalDurationMillis = shortIntervalDuration.toLong - val longIntervalDurationMillis = longIntervalDuration.toLong - - var currentBlock: Block = null - var currentBucket: Bucket = null - - val blocksForPushing = new Queue[Block]() - val buckets = new HashMap[Interval, Bucket]() with SynchronizedMap[Interval, Bucket] - - val blockUpdatingThread = new Thread() { override def run() { keepUpdatingCurrentBlock() } } - val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } - - def start() { - blockUpdatingThread.start() - blockPushingThread.start() - } - - def += (data: String) = addData(data) - - def addData(data: String) { - if (currentBlock == null) { - updateCurrentBlock() - } - currentBlock.synchronized { - currentBlock += data - } - } - - def getShortInterval(time: Time): Interval = { - val intervalBegin = time.floor(shortIntervalDuration) - Interval(intervalBegin, intervalBegin + shortIntervalDuration) - } - - def getLongInterval(shortInterval: Interval): Interval = { - val intervalBegin = shortInterval.beginTime.floor(longIntervalDuration) - Interval(intervalBegin, intervalBegin + longIntervalDuration) - } - - def updateCurrentBlock() { - /*logInfo("Updating current block")*/ - val currentTime = Time(System.currentTimeMillis) - val shortInterval = getShortInterval(currentTime) - val longInterval = getLongInterval(shortInterval) - - def createBlock(reuseCurrentBlock: Boolean = false) { - val newBlockId = inputName + "-" + longInterval.toFormattedString + "-" + currentBucket.blocks.size - if (!reuseCurrentBlock) { - val newBlock = new Block(newBlockId, shortInterval) - /*logInfo("Created " + currentBlock)*/ - currentBlock = newBlock - } else { - currentBlock.shortInterval = shortInterval - currentBlock.id = newBlockId - } - } - - def createBucket() { - val newBucket = new Bucket(longInterval) - buckets += ((longInterval, newBucket)) - currentBucket = newBucket - /*logInfo("Created " + currentBucket + ", " + buckets.size + " buckets")*/ - } - - if (currentBlock == null || currentBucket == null) { - createBucket() - currentBucket.synchronized { - createBlock() - } - return - } - - currentBlock.synchronized { - var reuseCurrentBlock = false - - if (shortInterval != currentBlock.shortInterval) { - if (!currentBlock.empty) { - blocksForPushing.synchronized { - blocksForPushing += currentBlock - blocksForPushing.notifyAll() - } - } - - currentBucket.synchronized { - if (currentBlock.empty) { - reuseCurrentBlock = true - } else { - currentBucket += currentBlock - } - - if (longInterval != currentBucket.longInterval) { - currentBucket.filled = true - if (currentBucket.ready) { - currentBucket.notifyAll() - } - createBucket() - } - } - - createBlock(reuseCurrentBlock) - } - } - } - - def pushBlock(block: Block) { - try{ - if (blockManager != null) { - logInfo("Pushing block") - val startTime = System.currentTimeMillis - - val bytes = blockManager.dataSerialize("rdd_", block.data.toIterator) // TODO: Will this be an RDD block? - val finishTime = System.currentTimeMillis - logInfo(block + " serialization delay is " + (finishTime - startTime) / 1000.0 + " s") - - blockManager.putBytes(block.id.toString, bytes, StorageLevel.MEMORY_AND_DISK_SER_2) - /*blockManager.putBytes(block.id.toString, bytes, StorageLevel.DISK_AND_MEMORY_DESER_2)*/ - /*blockManager.put(block.id.toString, block.data.toIterator, StorageLevel.DISK_AND_MEMORY_DESER)*/ - /*blockManager.put(block.id.toString, block.data.toIterator, StorageLevel.DISK_AND_MEMORY)*/ - val finishTime1 = System.currentTimeMillis - logInfo(block + " put delay is " + (finishTime1 - startTime) / 1000.0 + " s") - } else { - logWarning(block + " not put as block manager is null") - } - } catch { - case e: Exception => logError("Exception writing " + block + " to blockmanager" , e) - } - } - - def getBucket(longInterval: Interval): Option[Bucket] = { - buckets.get(longInterval) - } - - def clearBucket(longInterval: Interval) { - buckets.remove(longInterval) - } - - def keepUpdatingCurrentBlock() { - logInfo("Thread to update current block started") - while(true) { - updateCurrentBlock() - val currentTimeMillis = System.currentTimeMillis - val sleepTimeMillis = (currentTimeMillis / shortIntervalDurationMillis + 1) * - shortIntervalDurationMillis - currentTimeMillis + 1 - Thread.sleep(sleepTimeMillis) - } - } - - def keepPushingBlocks() { - var loop = true - logInfo("Thread to push blocks started") - while(loop) { - val block = blocksForPushing.synchronized { - if (blocksForPushing.size == 0) { - blocksForPushing.wait() - } - blocksForPushing.dequeue - } - pushBlock(block) - block.pushed = true - block.data.clear() - - val bucket = buckets(block.longInterval) - bucket.synchronized { - if (bucket.ready) { - bucket.notifyAll() - } - } - } - } - } - - - class ConnectionListener(port: Int, dataHandler: DataHandler) - extends Thread with Logging { - initLogging() - override def run { - try { - val listener = new ServerSocket(port) - logInfo("Listening on port " + port) - while (true) { - new ConnectionHandler(listener.accept(), dataHandler).start(); - } - listener.close() - } catch { - case e: Exception => logError("", e); - } - } - } - - class ConnectionHandler(socket: Socket, dataHandler: DataHandler) extends Thread with Logging { - initLogging() - override def run { - logInfo("New connection from " + socket.getInetAddress() + ":" + socket.getPort) - val bytes = new Array[Byte](100 * 1024 * 1024) - try { - - val inputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream, 1024 * 1024)) - /*val inputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream))*/ - var str: String = null - str = inputStream.readUTF - while(str != null) { - dataHandler += str - str = inputStream.readUTF() - } - - /* - var loop = true - while(loop) { - val numRead = inputStream.read(bytes) - if (numRead < 0) { - loop = false - } - inbox += ((LongTime(SystemTime.currentTimeMillis), "test")) - }*/ - - inputStream.close() - } catch { - case e => logError("Error receiving data", e) - } - socket.close() - } - } - - initLogging() - - val masterHost = System.getProperty("spark.master.host") - val masterPort = System.getProperty("spark.master.port").toInt - - val akkaPath = "akka://spark@%s:%s/user/".format(masterHost, masterPort) - val sparkstreamScheduler = actorSystem.actorFor(akkaPath + "/SparkStreamScheduler") - val testStreamCoordinator = actorSystem.actorFor(akkaPath + "/TestStreamCoordinator") - - logInfo("Getting stream details from master " + masterHost + ":" + masterPort) - - val timeout = 50 millis - - var started = false - while (!started) { - askActor[String](testStreamCoordinator, TestStarted) match { - case Some(str) => { - started = true - logInfo("TestStreamCoordinator started") - } - case None => { - logInfo("TestStreamCoordinator not started yet") - Thread.sleep(200) - } - } - } - - val streamDetails = askActor[GotStreamDetails](testStreamCoordinator, GetStreamDetails) match { - case Some(details) => details - case None => throw new Exception("Could not get stream details") - } - logInfo("Stream details received: " + streamDetails) - - val inputName = streamDetails.name - val intervalDurationMillis = streamDetails.duration - val intervalDuration = Time(intervalDurationMillis) - - val dataHandler = new DataHandler( - inputName, - intervalDuration, - Time(TestStreamReceiver3.SHORT_INTERVAL_MILLIS), - blockManager) - - val connListener = new ConnectionListener(TestStreamReceiver3.PORT, dataHandler) - - // Send a message to an actor and return an option with its reply, or None if this times out - def askActor[T](actor: ActorRef, message: Any): Option[T] = { - try { - val future = actor.ask(message)(timeout) - return Some(Await.result(future, timeout).asInstanceOf[T]) - } catch { - case e: Exception => - logInfo("Error communicating with " + actor, e) - return None - } - } - - override def run() { - connListener.start() - dataHandler.start() - - var interval = Interval.currentInterval(intervalDuration) - var dataStarted = false - - while(true) { - waitFor(interval.endTime) - logInfo("Woken up at " + System.currentTimeMillis + " for " + interval) - dataHandler.getBucket(interval) match { - case Some(bucket) => { - logInfo("Found " + bucket + " for " + interval) - bucket.synchronized { - if (!bucket.ready) { - logInfo("Waiting for " + bucket) - bucket.wait() - logInfo("Wait over for " + bucket) - } - if (dataStarted || !bucket.empty) { - logInfo("Notifying " + bucket) - notifyScheduler(interval, bucket.blockIds) - dataStarted = true - } - bucket.blocks.clear() - dataHandler.clearBucket(interval) - } - } - case None => { - logInfo("Found none for " + interval) - if (dataStarted) { - logInfo("Notifying none") - notifyScheduler(interval, Array[String]()) - } - } - } - interval = interval.next - } - } - - def waitFor(time: Time) { - val currentTimeMillis = System.currentTimeMillis - val targetTimeMillis = time.milliseconds - if (currentTimeMillis < targetTimeMillis) { - val sleepTime = (targetTimeMillis - currentTimeMillis) - Thread.sleep(sleepTime + 1) - } - } - - def notifyScheduler(interval: Interval, blockIds: Array[String]) { - try { - sparkstreamScheduler ! InputGenerated(inputName, interval, blockIds.toArray) - val time = interval.endTime - val delay = (System.currentTimeMillis - time.milliseconds) / 1000.0 - logInfo("Pushing delay for " + time + " is " + delay + " s") - } catch { - case _ => logError("Exception notifying scheduler at interval " + interval) - } - } -} - -object TestStreamReceiver3 { - - val PORT = 9999 - val SHORT_INTERVAL_MILLIS = 100 - - def main(args: Array[String]) { - System.setProperty("spark.master.host", Utils.localHostName) - System.setProperty("spark.master.port", "7078") - val details = Array(("Sentences", 2000L)) - val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localHostName, 7078) - actorSystem.actorOf(Props(new TestStreamCoordinator(details)), name = "TestStreamCoordinator") - new TestStreamReceiver3(actorSystem, null).start() - } -} - - - diff --git a/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver4.scala b/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver4.scala deleted file mode 100644 index 31754870dd..0000000000 --- a/streaming/src/main/scala/spark/streaming/util/TestStreamReceiver4.scala +++ /dev/null @@ -1,374 +0,0 @@ -package spark.streaming.util - -import spark.streaming._ -import spark._ -import spark.storage._ -import spark.util.AkkaUtils - -import scala.math._ -import scala.collection.mutable.{Queue, HashMap, ArrayBuffer, SynchronizedMap} - -import java.io._ -import java.nio._ -import java.nio.charset._ -import java.nio.channels._ -import java.util.concurrent.Executors - -import akka.actor._ -import akka.actor.Actor -import akka.dispatch._ -import akka.pattern.ask -import akka.util.duration._ - -class TestStreamReceiver4(actorSystem: ActorSystem, blockManager: BlockManager) -extends Thread with Logging { - - class DataHandler( - inputName: String, - longIntervalDuration: Time, - shortIntervalDuration: Time, - blockManager: BlockManager - ) - extends Logging { - - class Block(val id: String, val shortInterval: Interval, val buffer: ByteBuffer) { - var pushed = false - def longInterval = getLongInterval(shortInterval) - override def toString() = "Block " + id - } - - class Bucket(val longInterval: Interval) { - val blocks = new ArrayBuffer[Block]() - var filled = false - def += (block: Block) = blocks += block - def empty() = (blocks.size == 0) - def ready() = (filled && !blocks.exists(! _.pushed)) - def blockIds() = blocks.map(_.id).toArray - override def toString() = "Bucket [" + longInterval + ", " + blocks.size + " blocks]" - } - - initLogging() - - val syncOnLastShortInterval = true - - val shortIntervalDurationMillis = shortIntervalDuration.milliseconds - val longIntervalDurationMillis = longIntervalDuration.milliseconds - - val buffer = ByteBuffer.allocateDirect(100 * 1024 * 1024) - var currentShortInterval = Interval.currentInterval(shortIntervalDuration) - - val blocksForPushing = new Queue[Block]() - val buckets = new HashMap[Interval, Bucket]() with SynchronizedMap[Interval, Bucket] - - val bufferProcessingThread = new Thread() { override def run() { keepProcessingBuffers() } } - val blockPushingExecutor = Executors.newFixedThreadPool(5) - - - def start() { - buffer.clear() - if (buffer.remaining == 0) { - throw new Exception("Buffer initialization error") - } - bufferProcessingThread.start() - } - - def readDataToBuffer(func: ByteBuffer => Int): Int = { - buffer.synchronized { - if (buffer.remaining == 0) { - logInfo("Received first data for interval " + currentShortInterval) - } - func(buffer) - } - } - - def getLongInterval(shortInterval: Interval): Interval = { - val intervalBegin = shortInterval.beginTime.floor(longIntervalDuration) - Interval(intervalBegin, intervalBegin + longIntervalDuration) - } - - def processBuffer() { - - def readInt(buffer: ByteBuffer): Int = { - var offset = 0 - var result = 0 - while (offset < 32) { - val b = buffer.get() - result |= ((b & 0x7F) << offset) - if ((b & 0x80) == 0) { - return result - } - offset += 7 - } - throw new Exception("Malformed zigzag-encoded integer") - } - - val currentLongInterval = getLongInterval(currentShortInterval) - val startTime = System.currentTimeMillis - val newBuffer: ByteBuffer = buffer.synchronized { - buffer.flip() - if (buffer.remaining == 0) { - buffer.clear() - null - } else { - logDebug("Processing interval " + currentShortInterval + " with delay of " + (System.currentTimeMillis - startTime) + " ms") - val startTime1 = System.currentTimeMillis - var loop = true - var count = 0 - while(loop) { - buffer.mark() - try { - val len = readInt(buffer) - buffer.position(buffer.position + len) - count += 1 - } catch { - case e: Exception => { - buffer.reset() - loop = false - } - } - } - val bytesToCopy = buffer.position - val newBuf = ByteBuffer.allocate(bytesToCopy) - buffer.position(0) - newBuf.put(buffer.slice().limit(bytesToCopy).asInstanceOf[ByteBuffer]) - newBuf.flip() - buffer.position(bytesToCopy) - buffer.compact() - newBuf - } - } - - if (newBuffer != null) { - val bucket = buckets.getOrElseUpdate(currentLongInterval, new Bucket(currentLongInterval)) - bucket.synchronized { - val newBlockId = inputName + "-" + currentLongInterval.toFormattedString + "-" + currentShortInterval.toFormattedString - val newBlock = new Block(newBlockId, currentShortInterval, newBuffer) - if (syncOnLastShortInterval) { - bucket += newBlock - } - logDebug("Created " + newBlock + " with " + newBuffer.remaining + " bytes, creation delay is " + (System.currentTimeMillis - currentShortInterval.endTime.milliseconds) / 1000.0 + " s" ) - blockPushingExecutor.execute(new Runnable() { def run() { pushAndNotifyBlock(newBlock) } }) - } - } - - val newShortInterval = Interval.currentInterval(shortIntervalDuration) - val newLongInterval = getLongInterval(newShortInterval) - - if (newLongInterval != currentLongInterval) { - buckets.get(currentLongInterval) match { - case Some(bucket) => { - bucket.synchronized { - bucket.filled = true - if (bucket.ready) { - bucket.notifyAll() - } - } - } - case None => - } - buckets += ((newLongInterval, new Bucket(newLongInterval))) - } - - currentShortInterval = newShortInterval - } - - def pushBlock(block: Block) { - try{ - if (blockManager != null) { - val startTime = System.currentTimeMillis - logInfo(block + " put start delay is " + (startTime - block.shortInterval.endTime.milliseconds) + " ms") - /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY)*/ - /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_2)*/ - blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.MEMORY_ONLY_2) - /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.MEMORY_ONLY)*/ - /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_DESER)*/ - /*blockManager.putBytes(block.id.toString, block.buffer, StorageLevel.DISK_AND_MEMORY_DESER_2)*/ - val finishTime = System.currentTimeMillis - logInfo(block + " put delay is " + (finishTime - startTime) + " ms") - } else { - logWarning(block + " not put as block manager is null") - } - } catch { - case e: Exception => logError("Exception writing " + block + " to blockmanager" , e) - } - } - - def getBucket(longInterval: Interval): Option[Bucket] = { - buckets.get(longInterval) - } - - def clearBucket(longInterval: Interval) { - buckets.remove(longInterval) - } - - def keepProcessingBuffers() { - logInfo("Thread to process buffers started") - while(true) { - processBuffer() - val currentTimeMillis = System.currentTimeMillis - val sleepTimeMillis = (currentTimeMillis / shortIntervalDurationMillis + 1) * - shortIntervalDurationMillis - currentTimeMillis + 1 - Thread.sleep(sleepTimeMillis) - } - } - - def pushAndNotifyBlock(block: Block) { - pushBlock(block) - block.pushed = true - val bucket = if (syncOnLastShortInterval) { - buckets(block.longInterval) - } else { - var longInterval = block.longInterval - while(!buckets.contains(longInterval)) { - logWarning("Skipping bucket of " + longInterval + " for " + block) - longInterval = longInterval.next - } - val chosenBucket = buckets(longInterval) - logDebug("Choosing bucket of " + longInterval + " for " + block) - chosenBucket += block - chosenBucket - } - - bucket.synchronized { - if (bucket.ready) { - bucket.notifyAll() - } - } - - } - } - - - class ReceivingConnectionHandler(host: String, port: Int, dataHandler: DataHandler) - extends ConnectionHandler(host, port, false) { - - override def ready(key: SelectionKey) { - changeInterest(key, SelectionKey.OP_READ) - } - - override def read(key: SelectionKey) { - try { - val channel = key.channel.asInstanceOf[SocketChannel] - val bytesRead = dataHandler.readDataToBuffer(channel.read) - if (bytesRead < 0) { - close(key) - } - } catch { - case e: IOException => { - logError("Error reading", e) - close(key) - } - } - } - } - - initLogging() - - val masterHost = System.getProperty("spark.master.host", "localhost") - val masterPort = System.getProperty("spark.master.port", "7078").toInt - - val akkaPath = "akka://spark@%s:%s/user/".format(masterHost, masterPort) - val sparkstreamScheduler = actorSystem.actorFor(akkaPath + "/SparkStreamScheduler") - val testStreamCoordinator = actorSystem.actorFor(akkaPath + "/TestStreamCoordinator") - - logInfo("Getting stream details from master " + masterHost + ":" + masterPort) - - val streamDetails = askActor[GotStreamDetails](testStreamCoordinator, GetStreamDetails) match { - case Some(details) => details - case None => throw new Exception("Could not get stream details") - } - logInfo("Stream details received: " + streamDetails) - - val inputName = streamDetails.name - val intervalDurationMillis = streamDetails.duration - val intervalDuration = Milliseconds(intervalDurationMillis) - val shortIntervalDuration = Milliseconds(System.getProperty("spark.stream.shortinterval", "500").toInt) - - val dataHandler = new DataHandler(inputName, intervalDuration, shortIntervalDuration, blockManager) - val connectionHandler = new ReceivingConnectionHandler("localhost", 9999, dataHandler) - - val timeout = 100 millis - - // Send a message to an actor and return an option with its reply, or None if this times out - def askActor[T](actor: ActorRef, message: Any): Option[T] = { - try { - val future = actor.ask(message)(timeout) - return Some(Await.result(future, timeout).asInstanceOf[T]) - } catch { - case e: Exception => - logInfo("Error communicating with " + actor, e) - return None - } - } - - override def run() { - connectionHandler.start() - dataHandler.start() - - var interval = Interval.currentInterval(intervalDuration) - var dataStarted = false - - - while(true) { - waitFor(interval.endTime) - /*logInfo("Woken up at " + System.currentTimeMillis + " for " + interval)*/ - dataHandler.getBucket(interval) match { - case Some(bucket) => { - logDebug("Found " + bucket + " for " + interval) - bucket.synchronized { - if (!bucket.ready) { - logDebug("Waiting for " + bucket) - bucket.wait() - logDebug("Wait over for " + bucket) - } - if (dataStarted || !bucket.empty) { - logDebug("Notifying " + bucket) - notifyScheduler(interval, bucket.blockIds) - dataStarted = true - } - bucket.blocks.clear() - dataHandler.clearBucket(interval) - } - } - case None => { - logDebug("Found none for " + interval) - if (dataStarted) { - logDebug("Notifying none") - notifyScheduler(interval, Array[String]()) - } - } - } - interval = interval.next - } - } - - def waitFor(time: Time) { - val currentTimeMillis = System.currentTimeMillis - val targetTimeMillis = time.milliseconds - if (currentTimeMillis < targetTimeMillis) { - val sleepTime = (targetTimeMillis - currentTimeMillis) - Thread.sleep(sleepTime + 1) - } - } - - def notifyScheduler(interval: Interval, blockIds: Array[String]) { - try { - sparkstreamScheduler ! InputGenerated(inputName, interval, blockIds.toArray) - val time = interval.endTime - val delay = (System.currentTimeMillis - time.milliseconds) - logInfo("Notification delay for " + time + " is " + delay + " ms") - } catch { - case e: Exception => logError("Exception notifying scheduler at interval " + interval + ": " + e) - } - } -} - - -object TestStreamReceiver4 { - def main(args: Array[String]) { - val details = Array(("Sentences", 2000L)) - val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localHostName, 7078) - actorSystem.actorOf(Props(new TestStreamCoordinator(details)), name = "TestStreamCoordinator") - new TestStreamReceiver4(actorSystem, null).start() - } -} -- cgit v1.2.3 From e5a09367870be757a0abb3e2ad7a53e74110b033 Mon Sep 17 00:00:00 2001 From: Denny Date: Fri, 9 Nov 2012 12:23:46 -0800 Subject: Kafka Stream. --- .../src/main/scala/spark/streaming/DStream.scala | 31 ++-- .../main/scala/spark/streaming/DataHandler.scala | 83 ++++++++++ .../spark/streaming/NetworkInputDStream.scala | 21 ++- .../spark/streaming/NetworkInputTracker.scala | 9 +- .../scala/spark/streaming/RawInputDStream.scala | 2 +- .../scala/spark/streaming/SocketInputDStream.scala | 70 +-------- .../scala/spark/streaming/StreamingContext.scala | 4 +- .../spark/streaming/examples/KafkaWordCount.scala | 13 +- .../spark/streaming/input/KafkaInputDStream.scala | 173 +++++++++++++-------- 9 files changed, 245 insertions(+), 161 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/DataHandler.scala diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 922ff5088d..f891730317 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -17,6 +17,8 @@ import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import org.apache.hadoop.fs.Path import org.apache.hadoop.conf.Configuration +case class DStreamCheckpointData(rdds: HashMap[Time, Any]) + abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext) extends Serializable with Logging { @@ -59,7 +61,7 @@ extends Serializable with Logging { // Checkpoint details protected[streaming] val mustCheckpoint = false protected[streaming] var checkpointInterval: Time = null - protected[streaming] val checkpointData = new HashMap[Time, Any]() + protected[streaming] var checkpointData = DStreamCheckpointData(HashMap[Time, Any]()) // Reference to whole DStream graph protected[streaming] var graph: DStreamGraph = null @@ -280,6 +282,13 @@ extends Serializable with Logging { dependencies.foreach(_.forgetOldRDDs(time)) } + /* Adds metadata to the Stream while it is running. + * This methd should be overwritten by sublcasses of InputDStream. + */ + protected[streaming] def addMetadata(metadata: Any) { + logInfo("Dropping Metadata: " + metadata.toString) + } + /** * Refreshes the list of checkpointed RDDs that will be saved along with checkpoint of * this stream. This is an internal method that should not be called directly. This is @@ -288,22 +297,22 @@ extends Serializable with Logging { * this method to save custom checkpoint data. */ protected[streaming] def updateCheckpointData(currentTime: Time) { - val newCheckpointData = generatedRDDs.filter(_._2.getCheckpointData() != null) + val newRdds = generatedRDDs.filter(_._2.getCheckpointData() != null) .map(x => (x._1, x._2.getCheckpointData())) - val oldCheckpointData = checkpointData.clone() - if (newCheckpointData.size > 0) { - checkpointData.clear() - checkpointData ++= newCheckpointData + val oldRdds = checkpointData.rdds.clone() + if (newRdds.size > 0) { + checkpointData.rdds.clear() + checkpointData.rdds ++= newRdds } dependencies.foreach(_.updateCheckpointData(currentTime)) - newCheckpointData.foreach { + newRdds.foreach { case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") } } - if (newCheckpointData.size > 0) { - (oldCheckpointData -- newCheckpointData.keySet).foreach { + if (newRdds.size > 0) { + (oldRdds -- newRdds.keySet).foreach { case (time, data) => { val path = new Path(data.toString) val fs = path.getFileSystem(new Configuration()) @@ -322,8 +331,8 @@ extends Serializable with Logging { * override the updateCheckpointData() method would also need to override this method. */ protected[streaming] def restoreCheckpointData() { - logInfo("Restoring checkpoint data from " + checkpointData.size + " checkpointed RDDs") - checkpointData.foreach { + logInfo("Restoring checkpoint data from " + checkpointData.rdds.size + " checkpointed RDDs") + checkpointData.rdds.foreach { case(time, data) => { logInfo("Restoring checkpointed RDD for time " + time + " from file") generatedRDDs += ((time, ssc.sc.objectFile[T](data.toString))) diff --git a/streaming/src/main/scala/spark/streaming/DataHandler.scala b/streaming/src/main/scala/spark/streaming/DataHandler.scala new file mode 100644 index 0000000000..05f307a8d1 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/DataHandler.scala @@ -0,0 +1,83 @@ +package spark.streaming + +import java.util.concurrent.ArrayBlockingQueue +import scala.collection.mutable.ArrayBuffer +import spark.Logging +import spark.streaming.util.{RecurringTimer, SystemClock} +import spark.storage.StorageLevel + + +/** + * This is a helper object that manages the data received from the socket. It divides + * the object received into small batches of 100s of milliseconds, pushes them as + * blocks into the block manager and reports the block IDs to the network input + * tracker. It starts two threads, one to periodically start a new batch and prepare + * the previous batch of as a block, the other to push the blocks into the block + * manager. + */ + class DataHandler[T](receiver: NetworkReceiver[T], storageLevel: StorageLevel) + extends Serializable with Logging { + + case class Block(id: String, iterator: Iterator[T], metadata: Any = null) + + val clock = new SystemClock() + val blockInterval = 200L + val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) + val blockStorageLevel = storageLevel + val blocksForPushing = new ArrayBlockingQueue[Block](1000) + val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } + + var currentBuffer = new ArrayBuffer[T] + + def createBlock(blockId: String, iterator: Iterator[T]) : Block = { + new Block(blockId, iterator) + } + + def start() { + blockIntervalTimer.start() + blockPushingThread.start() + logInfo("Data handler started") + } + + def stop() { + blockIntervalTimer.stop() + blockPushingThread.interrupt() + logInfo("Data handler stopped") + } + + def += (obj: T) { + currentBuffer += obj + } + + def updateCurrentBuffer(time: Long) { + try { + val newBlockBuffer = currentBuffer + currentBuffer = new ArrayBuffer[T] + if (newBlockBuffer.size > 0) { + val blockId = "input-" + receiver.streamId + "- " + (time - blockInterval) + val newBlock = createBlock(blockId, newBlockBuffer.toIterator) + blocksForPushing.add(newBlock) + } + } catch { + case ie: InterruptedException => + logInfo("Block interval timer thread interrupted") + case e: Exception => + receiver.stop() + } + } + + def keepPushingBlocks() { + logInfo("Block pushing thread started") + try { + while(true) { + val block = blocksForPushing.take() + receiver.pushBlock(block.id, block.iterator, block.metadata, storageLevel) + } + } catch { + case ie: InterruptedException => + logInfo("Block pushing thread interrupted") + case e: Exception => + receiver.stop() + } + } + } \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala index f3f4c3ab13..d3f37b8b0e 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala @@ -4,9 +4,11 @@ import scala.collection.mutable.ArrayBuffer import spark.{Logging, SparkEnv, RDD} import spark.rdd.BlockRDD +import spark.streaming.util.{RecurringTimer, SystemClock} import spark.storage.StorageLevel import java.nio.ByteBuffer +import java.util.concurrent.ArrayBlockingQueue import akka.actor.{Props, Actor} import akka.pattern.ask @@ -41,10 +43,10 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming sealed trait NetworkReceiverMessage case class StopReceiver(msg: String) extends NetworkReceiverMessage -case class ReportBlock(blockId: String) extends NetworkReceiverMessage +case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage case class ReportError(msg: String) extends NetworkReceiverMessage -abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializable with Logging { +abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Serializable with Logging { initLogging() @@ -106,21 +108,23 @@ abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializ actor ! ReportError(e.toString) } + /** * This method pushes a block (as iterator of values) into the block manager. */ - protected def pushBlock(blockId: String, iterator: Iterator[T], level: StorageLevel) { + def pushBlock(blockId: String, iterator: Iterator[T], metadata: Any, level: StorageLevel) { val buffer = new ArrayBuffer[T] ++ iterator env.blockManager.put(blockId, buffer.asInstanceOf[ArrayBuffer[Any]], level) - actor ! ReportBlock(blockId) + + actor ! ReportBlock(blockId, metadata) } /** * This method pushes a block (as bytes) into the block manager. */ - protected def pushBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + def pushBlock(blockId: String, bytes: ByteBuffer, metadata: Any, level: StorageLevel) { env.blockManager.putBytes(blockId, bytes, level) - actor ! ReportBlock(blockId) + actor ! ReportBlock(blockId, metadata) } /** A helper actor that communicates with the NetworkInputTracker */ @@ -138,8 +142,8 @@ abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializ } override def receive() = { - case ReportBlock(blockId) => - tracker ! AddBlocks(streamId, Array(blockId)) + case ReportBlock(blockId, metadata) => + tracker ! AddBlocks(streamId, Array(blockId), metadata) case ReportError(msg) => tracker ! DeregisterReceiver(streamId, msg) case StopReceiver(msg) => @@ -147,5 +151,6 @@ abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializ tracker ! DeregisterReceiver(streamId, msg) } } + } diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index 07ef79415d..4d9346edd8 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -13,7 +13,7 @@ import akka.dispatch._ trait NetworkInputTrackerMessage case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage -case class AddBlocks(streamId: Int, blockIds: Seq[String]) extends NetworkInputTrackerMessage +case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage @@ -22,7 +22,7 @@ class NetworkInputTracker( @transient networkInputStreams: Array[NetworkInputDStream[_]]) extends Logging { - val networkInputStreamIds = networkInputStreams.map(_.id).toArray + val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*) val receiverExecutor = new ReceiverExecutor() val receiverInfo = new HashMap[Int, ActorRef] val receivedBlockIds = new HashMap[Int, Queue[String]] @@ -53,14 +53,14 @@ class NetworkInputTracker( private class NetworkInputTrackerActor extends Actor { def receive = { case RegisterReceiver(streamId, receiverActor) => { - if (!networkInputStreamIds.contains(streamId)) { + if (!networkInputStreamMap.contains(streamId)) { throw new Exception("Register received for unexpected id " + streamId) } receiverInfo += ((streamId, receiverActor)) logInfo("Registered receiver for network stream " + streamId) sender ! true } - case AddBlocks(streamId, blockIds) => { + case AddBlocks(streamId, blockIds, metadata) => { val tmp = receivedBlockIds.synchronized { if (!receivedBlockIds.contains(streamId)) { receivedBlockIds += ((streamId, new Queue[String])) @@ -70,6 +70,7 @@ class NetworkInputTracker( tmp.synchronized { tmp ++= blockIds } + networkInputStreamMap(streamId).addMetadata(metadata) } case DeregisterReceiver(streamId, msg) => { receiverInfo -= streamId diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala index e022b85fbe..90d8528d5b 100644 --- a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala @@ -48,7 +48,7 @@ class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: S val buffer = queue.take() val blockId = "input-" + streamId + "-" + nextBlockNumber nextBlockNumber += 1 - pushBlock(blockId, buffer, storageLevel) + pushBlock(blockId, buffer, null, storageLevel) } } } diff --git a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala index b566200273..ff99d50b76 100644 --- a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala @@ -32,7 +32,7 @@ class SocketReceiver[T: ClassManifest]( storageLevel: StorageLevel ) extends NetworkReceiver[T](streamId) { - lazy protected val dataHandler = new DataHandler(this) + lazy protected val dataHandler = new DataHandler(this, storageLevel) protected def onStart() { logInfo("Connecting to " + host + ":" + port) @@ -50,74 +50,6 @@ class SocketReceiver[T: ClassManifest]( dataHandler.stop() } - /** - * This is a helper object that manages the data received from the socket. It divides - * the object received into small batches of 100s of milliseconds, pushes them as - * blocks into the block manager and reports the block IDs to the network input - * tracker. It starts two threads, one to periodically start a new batch and prepare - * the previous batch of as a block, the other to push the blocks into the block - * manager. - */ - class DataHandler(receiver: NetworkReceiver[T]) extends Serializable { - case class Block(id: String, iterator: Iterator[T]) - - val clock = new SystemClock() - val blockInterval = 200L - val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) - val blockStorageLevel = storageLevel - val blocksForPushing = new ArrayBlockingQueue[Block](1000) - val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } - - var currentBuffer = new ArrayBuffer[T] - - def start() { - blockIntervalTimer.start() - blockPushingThread.start() - logInfo("Data handler started") - } - - def stop() { - blockIntervalTimer.stop() - blockPushingThread.interrupt() - logInfo("Data handler stopped") - } - - def += (obj: T) { - currentBuffer += obj - } - - def updateCurrentBuffer(time: Long) { - try { - val newBlockBuffer = currentBuffer - currentBuffer = new ArrayBuffer[T] - if (newBlockBuffer.size > 0) { - val blockId = "input-" + streamId + "- " + (time - blockInterval) - val newBlock = new Block(blockId, newBlockBuffer.toIterator) - blocksForPushing.add(newBlock) - } - } catch { - case ie: InterruptedException => - logInfo("Block interval timer thread interrupted") - case e: Exception => - receiver.stop() - } - } - - def keepPushingBlocks() { - logInfo("Block pushing thread started") - try { - while(true) { - val block = blocksForPushing.take() - pushBlock(block.id, block.iterator, storageLevel) - } - } catch { - case ie: InterruptedException => - logInfo("Block pushing thread interrupted") - case e: Exception => - receiver.stop() - } - } - } } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 05c83d6c08..770fd61498 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -106,9 +106,11 @@ class StreamingContext ( hostname: String, port: Int, groupId: String, + topics: Map[String, Int], + initialOffsets: Map[KafkaPartitionKey, Long] = Map[KafkaPartitionKey, Long](), storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 ): DStream[T] = { - val inputStream = new KafkaInputDStream[T](this, hostname, port, groupId, storageLevel) + val inputStream = new KafkaInputDStream[T](this, hostname, port, groupId, topics, initialOffsets, storageLevel) graph.addInputStream(inputStream) inputStream } diff --git a/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala index 3f637150d1..655f9627b3 100644 --- a/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala +++ b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala @@ -1,6 +1,6 @@ package spark.streaming.examples -import spark.streaming.{Seconds, StreamingContext, KafkaInputDStream} +import spark.streaming._ import spark.streaming.StreamingContext._ import spark.storage.StorageLevel @@ -17,11 +17,20 @@ object KafkaWordCount { // Create a NetworkInputDStream on target ip:port and count the // words in input stream of \n delimited test (eg. generated by 'nc') - val lines = ssc.kafkaStream[String](args(1), args(2).toInt, "test_group") + ssc.checkpoint("checkpoint", Time(1000 * 5)) + val lines = ssc.kafkaStream[String](args(1), args(2).toInt, "test_group", Map("test" -> 1), + Map(KafkaPartitionKey(0, "test", "test_group", 0) -> 2382)) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) wordCounts.print() ssc.start() + // Wait for 12 seconds + Thread.sleep(12000) + ssc.stop() + + val newSsc = new StreamingContext("checkpoint") + newSsc.start() + } } diff --git a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala index 427f398237..814f2706d6 100644 --- a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala @@ -1,121 +1,164 @@ package spark.streaming +import java.lang.reflect.Method import java.nio.ByteBuffer import java.util.Properties -import java.util.concurrent.{ArrayBlockingQueue, Executors} +import java.util.concurrent.{ArrayBlockingQueue, ConcurrentHashMap, Executors} import kafka.api.{FetchRequest} -import kafka.consumer.{Consumer, ConsumerConfig, KafkaStream} -import kafka.javaapi.consumer.SimpleConsumer -import kafka.javaapi.message.ByteBufferMessageSet +import kafka.consumer._ +import kafka.cluster.Partition import kafka.message.{Message, MessageSet, MessageAndMetadata} -import kafka.utils.Utils +import kafka.serializer.StringDecoder +import kafka.utils.{Pool, Utils, ZKGroupTopicDirs} +import kafka.utils.ZkUtils._ +import scala.collection.mutable.HashMap import scala.collection.JavaConversions._ import spark._ import spark.RDD import spark.storage.StorageLevel +case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int) +case class KafkaInputDStreamMetadata(timestamp: Long, data: Map[KafkaPartitionKey, Long]) +case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any], + savedOffsets: HashMap[Long, Map[KafkaPartitionKey, Long]]) extends DStreamCheckpointData(kafkaRdds) + /** - * An input stream that pulls messages form a Kafka Broker. + * Input stream that pulls messages form a Kafka Broker. */ class KafkaInputDStream[T: ClassManifest]( @transient ssc_ : StreamingContext, host: String, port: Int, groupId: String, - storageLevel: StorageLevel, - timeout: Int = 10000, - bufferSize: Int = 1024000 + topics: Map[String, Int], + initialOffsets: Map[KafkaPartitionKey, Long], + storageLevel: StorageLevel ) extends NetworkInputDStream[T](ssc_ ) with Logging { + var savedOffsets = HashMap[Long, Map[KafkaPartitionKey, Long]]() + + override protected[streaming] def addMetadata(metadata: Any) { + metadata match { + case x : KafkaInputDStreamMetadata => + savedOffsets(x.timestamp) = x.data + logInfo("Saved Offsets: " + savedOffsets) + case _ => logInfo("Received unknown metadata: " + metadata.toString) + } + } + + override protected[streaming] def updateCheckpointData(currentTime: Time) { + super.updateCheckpointData(currentTime) + logInfo("Updating KafkaDStream checkpoint data: " + savedOffsets.toString) + checkpointData = KafkaDStreamCheckpointData(checkpointData.rdds, savedOffsets) + } + + override protected[streaming] def restoreCheckpointData() { + super.restoreCheckpointData() + logInfo("Restoring KafkaDStream checkpoint data.") + checkpointData match { + case x : KafkaDStreamCheckpointData => + savedOffsets = x.savedOffsets + logInfo("Restored KafkaDStream offsets: " + savedOffsets.toString) + } + } + def createReceiver(): NetworkReceiver[T] = { - new KafkaReceiver(id, host, port, storageLevel, groupId, timeout).asInstanceOf[NetworkReceiver[T]] + new KafkaReceiver(id, host, port, groupId, topics, initialOffsets, storageLevel) + .asInstanceOf[NetworkReceiver[T]] } } -class KafkaReceiver(streamId: Int, host: String, port: Int, storageLevel: StorageLevel, groupId: String, timeout: Int) - extends NetworkReceiver[Any](streamId) { +class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, + topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], + storageLevel: StorageLevel) extends NetworkReceiver[Any](streamId) { + + // Timeout for establishing a connection to Zookeper in ms. + val ZK_TIMEOUT = 10000 - //var executorPool : = null - var blockPushingThread : Thread = null + // Handles pushing data into the BlockManager + lazy protected val dataHandler = new KafkaDataHandler(this, storageLevel) + // Keeps track of the current offsets. Maps from (topic, partitionID) -> Offset + lazy val offsets = HashMap[KafkaPartitionKey, Long]() + // Connection to Kafka + var consumerConnector : ZookeeperConsumerConnector = null def onStop() { - blockPushingThread.interrupt() + dataHandler.stop() } def onStart() { - val executorPool = Executors.newFixedThreadPool(2) + // Starting the DataHandler that buffers blocks and pushes them into them BlockManager + dataHandler.start() - logInfo("Starting Kafka Consumer with groupId " + groupId) + // In case we are using multiple Threads to handle Kafka Messages + val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) val zooKeeperEndPoint = host + ":" + port + logInfo("Starting Kafka Consumer Stream in group " + groupId) + logInfo("Initial offsets: " + initialOffsets.toString) logInfo("Connecting to " + zooKeeperEndPoint) - - // Specify some consumer properties + // Specify some Consumer properties val props = new Properties() props.put("zk.connect", zooKeeperEndPoint) - props.put("zk.connectiontimeout.ms", timeout.toString) + props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString) props.put("groupid", groupId) // Create the connection to the cluster val consumerConfig = new ConsumerConfig(props) - val consumerConnector = Consumer.create(consumerConfig) - logInfo("Connected to " + zooKeeperEndPoint) - logInfo("") - logInfo("") + consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector] - // Specify which topics we are listening to - val topicCountMap = Map("test" -> 2) - val topicMessageStreams = consumerConnector.createMessageStreams(topicCountMap) - val streams = topicMessageStreams.get("test") + // Reset the Kafka offsets in case we are recovering from a failure + resetOffsets(initialOffsets) - // Queue that holds the blocks - val queue = new ArrayBlockingQueue[ByteBuffer](2) + logInfo("Connected to " + zooKeeperEndPoint) - streams.getOrElse(Nil).foreach { stream => - executorPool.submit(new MessageHandler(stream, queue)) + // Create Threads for each Topic/Message Stream we are listening + val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder()) + + topicMessageStreams.values.foreach { streams => + streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) } } - blockPushingThread = new DaemonThread { - override def run() { - logInfo("Starting BlockPushingThread.") - var nextBlockNumber = 0 - while (true) { - val buffer = queue.take() - val blockId = "input-" + streamId + "-" + nextBlockNumber - nextBlockNumber += 1 - pushBlock(blockId, buffer, storageLevel) - } - } + } + + // Overwrites the offets in Zookeper. + private def resetOffsets(offsets: Map[KafkaPartitionKey, Long]) { + offsets.foreach { case(key, offset) => + val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic) + val partitionName = key.brokerId + "-" + key.partId + updatePersistentPath(consumerConnector.zkClient, + topicDirs.consumerOffsetDir + "/" + partitionName, offset.toString) } - blockPushingThread.start() - - // while (true) { - // // Create a fetch request for topic “test”, partition 0, current offset, and fetch size of 1MB - // val fetchRequest = new FetchRequest("test", 0, offset, 1000000) - - // // get the message set from the consumer and print them out - // val messages = consumer.fetch(fetchRequest) - // for(msg <- messages.iterator) { - // logInfo("consumed: " + Utils.toString(msg.message.payload, "UTF-8")) - // // advance the offset after consuming each message - // offset = msg.offset - // queue.put(msg.message.payload) - // } - // } } - class MessageHandler(stream: KafkaStream[Message], queue: ArrayBlockingQueue[ByteBuffer]) extends Runnable { + // Responsible for handling Kafka Messages + class MessageHandler(stream: KafkaStream[String]) extends Runnable { def run() { logInfo("Starting MessageHandler.") - while(true) { - stream.foreach { msgAndMetadata => - logInfo("Consumed: " + Utils.toString(msgAndMetadata.message.payload, "UTF-8")) - queue.put(msgAndMetadata.message.payload) - } - } + stream.takeWhile { msgAndMetadata => + dataHandler += msgAndMetadata.message + + // Updating the offet. The key is (topic, partitionID). + val key = KafkaPartitionKey(msgAndMetadata.topicInfo.brokerId, msgAndMetadata.topic, + groupId, msgAndMetadata.topicInfo.partition.partId) + val offset = msgAndMetadata.topicInfo.getConsumeOffset + offsets.put(key, offset) + logInfo((key, offset).toString) + + // Keep on handling messages + true + } } } + class KafkaDataHandler(receiver: KafkaReceiver, storageLevel: StorageLevel) + extends DataHandler[Any](receiver, storageLevel) { + + override def createBlock(blockId: String, iterator: Iterator[Any]) : Block = { + new Block(blockId, iterator, KafkaInputDStreamMetadata(System.currentTimeMillis, offsets.toMap)) + } + + } } -- cgit v1.2.3 From 355c8e4b17cc3e67b1e18cc24e74d88416b5779b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 9 Nov 2012 16:28:45 -0800 Subject: Fixed deadlock in BlockManager. --- .../main/scala/spark/storage/BlockManager.scala | 111 ++++++++++----------- .../src/main/scala/spark/storage/MemoryStore.scala | 79 +++++++++------ 2 files changed, 101 insertions(+), 89 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index bd9155ef29..8c7b1417be 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -50,16 +50,6 @@ private[spark] case class BlockException(blockId: String, message: String, ex: Exception = null) extends Exception(message) - -private[spark] class BlockLocker(numLockers: Int) { - private val hashLocker = Array.fill(numLockers)(new Object()) - - def getLock(blockId: String): Object = { - return hashLocker(math.abs(blockId.hashCode % numLockers)) - } -} - - private[spark] class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, maxMemory: Long) extends Logging { @@ -87,10 +77,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } - private val NUM_LOCKS = 337 - private val locker = new BlockLocker(NUM_LOCKS) - - private val blockInfo = new ConcurrentHashMap[String, BlockInfo]() + private val blockInfo = new ConcurrentHashMap[String, BlockInfo](1000) private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) private[storage] val diskStore: BlockStore = @@ -110,7 +97,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m val maxBytesInFlight = System.getProperty("spark.reducer.maxMbInFlight", "48").toLong * 1024 * 1024 + // Whether to compress broadcast variables that are stored val compressBroadcast = System.getProperty("spark.broadcast.compress", "true").toBoolean + // Whether to compress shuffle output that are stored val compressShuffle = System.getProperty("spark.shuffle.compress", "true").toBoolean // Whether to compress RDD partitions that are stored serialized val compressRdds = System.getProperty("spark.rdd.compress", "false").toBoolean @@ -150,28 +139,27 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m * For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk. */ def reportBlockStatus(blockId: String) { - locker.getLock(blockId).synchronized { - val curLevel = blockInfo.get(blockId) match { - case null => - StorageLevel.NONE - case info => + val (curLevel, inMemSize, onDiskSize) = blockInfo.get(blockId) match { + case null => + (StorageLevel.NONE, 0L, 0L) + case info => + info.synchronized { info.level match { case null => - StorageLevel.NONE + (StorageLevel.NONE, 0L, 0L) case level => val inMem = level.useMemory && memoryStore.contains(blockId) val onDisk = level.useDisk && diskStore.contains(blockId) - new StorageLevel(onDisk, inMem, level.deserialized, level.replication) + ( + new StorageLevel(onDisk, inMem, level.deserialized, level.replication), + if (inMem) memoryStore.getSize(blockId) else 0L, + if (onDisk) diskStore.getSize(blockId) else 0L + ) } - } - master.mustHeartBeat(HeartBeat( - blockManagerId, - blockId, - curLevel, - if (curLevel.useMemory) memoryStore.getSize(blockId) else 0L, - if (curLevel.useDisk) diskStore.getSize(blockId) else 0L)) - logDebug("Told master about block " + blockId) + } } + master.mustHeartBeat(HeartBeat(blockManagerId, blockId, curLevel, inMemSize, onDiskSize)) + logDebug("Told master about block " + blockId) } /** @@ -213,9 +201,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } - locker.getLock(blockId).synchronized { - val info = blockInfo.get(blockId) - if (info != null) { + val info = blockInfo.get(blockId) + if (info != null) { + info.synchronized { info.waitForReady() // In case the block is still being put() by another thread val level = info.level logDebug("Level for block " + blockId + " is " + level) @@ -273,9 +261,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } } - } else { - logDebug("Block " + blockId + " not registered locally") } + } else { + logDebug("Block " + blockId + " not registered locally") } return None } @@ -298,9 +286,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } - locker.getLock(blockId).synchronized { - val info = blockInfo.get(blockId) - if (info != null) { + val info = blockInfo.get(blockId) + if (info != null) { + info.synchronized { info.waitForReady() // In case the block is still being put() by another thread val level = info.level logDebug("Level for block " + blockId + " is " + level) @@ -338,10 +326,11 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m throw new Exception("Block " + blockId + " not found on disk, though it should be") } } - } else { - logDebug("Block " + blockId + " not registered locally") } + } else { + logDebug("Block " + blockId + " not registered locally") } + return None } @@ -583,7 +572,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // Size of the block in bytes (to return to caller) var size = 0L - locker.getLock(blockId).synchronized { + myInfo.synchronized { logDebug("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") @@ -681,7 +670,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m null } - locker.getLock(blockId).synchronized { + myInfo.synchronized { logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") @@ -779,26 +768,30 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m */ def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) { logInfo("Dropping block " + blockId + " from memory") - locker.getLock(blockId).synchronized { - val info = blockInfo.get(blockId) - val level = info.level - if (level.useDisk && !diskStore.contains(blockId)) { - logInfo("Writing block " + blockId + " to disk") - data match { - case Left(elements) => - diskStore.putValues(blockId, elements, level, false) - case Right(bytes) => - diskStore.putBytes(blockId, bytes, level) + val info = blockInfo.get(blockId) + if (info != null) { + info.synchronized { + val level = info.level + if (level.useDisk && !diskStore.contains(blockId)) { + logInfo("Writing block " + blockId + " to disk") + data match { + case Left(elements) => + diskStore.putValues(blockId, elements, level, false) + case Right(bytes) => + diskStore.putBytes(blockId, bytes, level) + } + } + memoryStore.remove(blockId) + if (info.tellMaster) { + reportBlockStatus(blockId) + } + if (!level.useDisk) { + // The block is completely gone from this node; forget it so we can put() it again later. + blockInfo.remove(blockId) } } - memoryStore.remove(blockId) - if (info.tellMaster) { - reportBlockStatus(blockId) - } - if (!level.useDisk) { - // The block is completely gone from this node; forget it so we can put() it again later. - blockInfo.remove(blockId) - } + } else { + // The block has already been dropped } } diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala index 773970446a..09769d1f7d 100644 --- a/core/src/main/scala/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/spark/storage/MemoryStore.scala @@ -17,13 +17,16 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true) private var currentMemory = 0L + // Object used to ensure that only one thread is putting blocks and if necessary, dropping + // blocks from the memory store. + private val putLock = new Object() logInfo("MemoryStore started with capacity %s.".format(Utils.memoryBytesToString(maxMemory))) def freeMemory: Long = maxMemory - currentMemory override def getSize(blockId: String): Long = { - synchronized { + entries.synchronized { entries.get(blockId).size } } @@ -38,8 +41,6 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) tryToPut(blockId, elements, sizeEstimate, true) } else { val entry = new Entry(bytes, bytes.limit, false) - ensureFreeSpace(blockId, bytes.limit) - synchronized { entries.put(blockId, entry) } tryToPut(blockId, bytes, bytes.limit, false) } } @@ -63,7 +64,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } override def getBytes(blockId: String): Option[ByteBuffer] = { - val entry = synchronized { + val entry = entries.synchronized { entries.get(blockId) } if (entry == null) { @@ -76,7 +77,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } override def getValues(blockId: String): Option[Iterator[Any]] = { - val entry = synchronized { + val entry = entries.synchronized { entries.get(blockId) } if (entry == null) { @@ -90,7 +91,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } override def remove(blockId: String) { - synchronized { + entries.synchronized { val entry = entries.get(blockId) if (entry != null) { entries.remove(blockId) @@ -104,7 +105,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } override def clear() { - synchronized { + entries.synchronized { entries.clear() } logInfo("MemoryStore cleared") @@ -125,12 +126,22 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) * Try to put in a set of values, if we can free up enough space. The value should either be * an ArrayBuffer if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated) * size must also be passed by the caller. + * + * Locks on the object putLock to ensure that all the put requests and its associated block + * dropping is done by only on thread at a time. Otherwise while one thread is dropping + * blocks to free memory for one block, another thread may use up the freed space for + * another block. */ private def tryToPut(blockId: String, value: Any, size: Long, deserialized: Boolean): Boolean = { - synchronized { + // TODO: Its possible to optimize the locking by locking entries only when selecting blocks + // to be dropped. Once the to-be-dropped blocks have been selected, and lock on entries has been + // released, it must be ensured that those to-be-dropped blocks are not double counted for + // freeing up more space for another block that needs to be put. Only then the actually dropping + // of blocks (and writing to disk if necessary) can proceed in parallel. + putLock.synchronized { if (ensureFreeSpace(blockId, size)) { val entry = new Entry(value, size, deserialized) - entries.put(blockId, entry) + entries.synchronized { entries.put(blockId, entry) } currentMemory += size if (deserialized) { logInfo("Block %s stored as values to memory (estimated size %s, free %s)".format( @@ -160,8 +171,8 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) * block from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that * don't fit into memory that we want to avoid). * - * Assumes that a lock on the MemoryStore is held by the caller. (Otherwise, the freed space - * might fill up before the caller puts in their new value.) + * Assumes that a lock is held by the caller to ensure only one thread is dropping blocks. + * Otherwise, the freed space may fill up before the caller puts in their new value. */ private def ensureFreeSpace(blockIdToAdd: String, space: Long): Boolean = { logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format( @@ -172,36 +183,44 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) return false } - // TODO: This should relinquish the lock on the MemoryStore while flushing out old blocks - // in order to allow parallelism in writing to disk if (maxMemory - currentMemory < space) { val rddToAdd = getRddId(blockIdToAdd) val selectedBlocks = new ArrayBuffer[String]() var selectedMemory = 0L - val iterator = entries.entrySet().iterator() - while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) { - val pair = iterator.next() - val blockId = pair.getKey - if (rddToAdd != null && rddToAdd == getRddId(blockId)) { - logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " + - "block from the same RDD") - return false + // This is synchronized to ensure that the set of entries is not changed + // (because of getValue or getBytes) while traversing the iterator, as that + // can lead to exceptions. + entries.synchronized { + val iterator = entries.entrySet().iterator() + while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) { + val pair = iterator.next() + val blockId = pair.getKey + if (rddToAdd != null && rddToAdd == getRddId(blockId)) { + logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " + + "block from the same RDD") + return false + } + selectedBlocks += blockId + selectedMemory += pair.getValue.size } - selectedBlocks += blockId - selectedMemory += pair.getValue.size } if (maxMemory - (currentMemory - selectedMemory) >= space) { logInfo(selectedBlocks.size + " blocks selected for dropping") for (blockId <- selectedBlocks) { - val entry = entries.get(blockId) - val data = if (entry.deserialized) { - Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) - } else { - Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) + val entry = entries.synchronized { entries.get(blockId) } + // This should never be null as only one thread should be dropping + // blocks and removing entries. However the check is still here for + // future safety. + if (entry != null) { + val data = if (entry.deserialized) { + Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) + } else { + Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) + } + blockManager.dropFromMemory(blockId, data) } - blockManager.dropFromMemory(blockId, data) } return true } else { @@ -212,7 +231,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } override def contains(blockId: String): Boolean = { - synchronized { entries.containsKey(blockId) } + entries.synchronized { entries.containsKey(blockId) } } } -- cgit v1.2.3 From 04e9e9d93c512f856116bc2c99c35dfb48b4adee Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 11 Nov 2012 08:54:21 -0800 Subject: Refactored BlockManagerMaster (not BlockManagerMasterActor) to simplify the code and fix live lock problem in unlimited attempts to contact the master. Also added testcases in the BlockManagerSuite to test BlockManagerMaster methods getPeers and getLocations. --- .../main/scala/spark/storage/BlockManager.scala | 14 +- .../scala/spark/storage/BlockManagerMaster.scala | 281 +++++++-------------- .../scala/spark/storage/BlockManagerSuite.scala | 30 ++- 3 files changed, 127 insertions(+), 198 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 8c7b1417be..70d6d8369d 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -120,8 +120,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m * BlockManagerWorker actor. */ private def initialize() { - master.mustRegisterBlockManager( - RegisterBlockManager(blockManagerId, maxMemory)) + master.registerBlockManager(blockManagerId, maxMemory) BlockManagerWorker.startBlockManagerWorker(this) } @@ -158,7 +157,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } } - master.mustHeartBeat(HeartBeat(blockManagerId, blockId, curLevel, inMemSize, onDiskSize)) + master.updateBlockInfo(blockManagerId, blockId, curLevel, inMemSize, onDiskSize) logDebug("Told master about block " + blockId) } @@ -167,7 +166,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m */ def getLocations(blockId: String): Seq[String] = { val startTimeMs = System.currentTimeMillis - var managers = master.mustGetLocations(GetLocations(blockId)) + var managers = master.getLocations(blockId) val locations = managers.map(_.ip) logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs)) return locations @@ -178,8 +177,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m */ def getLocations(blockIds: Array[String]): Array[Seq[String]] = { val startTimeMs = System.currentTimeMillis - val locations = master.mustGetLocationsMultipleBlockIds( - GetLocationsMultipleBlockIds(blockIds)).map(_.map(_.ip).toSeq).toArray + val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray logDebug("Get multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) return locations } @@ -343,7 +341,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } logDebug("Getting remote block " + blockId) // Get locations of block - val locations = master.mustGetLocations(GetLocations(blockId)) + val locations = master.getLocations(blockId) // Get block from remote locations for (loc <- locations) { @@ -721,7 +719,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m val tLevel: StorageLevel = new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) if (cachedPeers == null) { - cachedPeers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) + cachedPeers = master.getPeers(blockManagerId, level.replication - 1) } for (peer: BlockManagerId <- cachedPeers) { val start = System.nanoTime diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index b3345623b3..4d5ee8318c 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -26,7 +26,7 @@ case class RegisterBlockManager( extends ToBlockManagerMaster private[spark] -class HeartBeat( +class UpdateBlockInfo( var blockManagerId: BlockManagerId, var blockId: String, var storageLevel: StorageLevel, @@ -57,17 +57,17 @@ class HeartBeat( } private[spark] -object HeartBeat { +object UpdateBlockInfo { def apply(blockManagerId: BlockManagerId, blockId: String, storageLevel: StorageLevel, memSize: Long, - diskSize: Long): HeartBeat = { - new HeartBeat(blockManagerId, blockId, storageLevel, memSize, diskSize) + diskSize: Long): UpdateBlockInfo = { + new UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize) } // For pattern-matching - def unapply(h: HeartBeat): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { + def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize)) } } @@ -182,8 +182,8 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor case RegisterBlockManager(blockManagerId, maxMemSize) => register(blockManagerId, maxMemSize) - case HeartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) => - heartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) + case UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => + updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) case GetLocations(blockId) => getLocations(blockId) @@ -233,7 +233,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor sender ! true } - private def heartBeat( + private def updateBlockInfo( blockManagerId: BlockManagerId, blockId: String, storageLevel: StorageLevel, @@ -245,7 +245,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor if (blockId == null) { blockManagerInfo(blockManagerId).updateLastSeenMs() - logDebug("Got in heartBeat 1" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs)) + logDebug("Got in updateBlockInfo 1" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs)) sender ! true } @@ -350,211 +350,124 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor private[spark] class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: Boolean) extends Logging { - val AKKA_ACTOR_NAME: String = "BlockMasterManager" - val REQUEST_RETRY_INTERVAL_MS = 100 - val DEFAULT_MASTER_IP: String = System.getProperty("spark.master.host", "localhost") - val DEFAULT_MASTER_PORT: Int = System.getProperty("spark.master.port", "7077").toInt - val DEFAULT_MANAGER_IP: String = Utils.localHostName() - val DEFAULT_MANAGER_PORT: String = "10902" - + val actorName = "BlockMasterManager" val timeout = 10.seconds - var masterActor: ActorRef = null + val maxAttempts = 5 - if (isMaster) { - masterActor = actorSystem.actorOf( - Props(new BlockManagerMasterActor(isLocal)), name = AKKA_ACTOR_NAME) + var masterActor = if (isMaster) { + val actor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)), name = actorName) logInfo("Registered BlockManagerMaster Actor") + actor } else { - val url = "akka://spark@%s:%s/user/%s".format( - DEFAULT_MASTER_IP, DEFAULT_MASTER_PORT, AKKA_ACTOR_NAME) + val host = System.getProperty("spark.master.host", "localhost") + val port = System.getProperty("spark.master.port", "7077").toInt + val url = "akka://spark@%s:%s/user/%s".format(host, port, actorName) + val actor = actorSystem.actorFor(url) logInfo("Connecting to BlockManagerMaster: " + url) - masterActor = actorSystem.actorFor(url) + actor } - def stop() { - if (masterActor != null) { - communicate(StopBlockManagerMaster) - masterActor = null - logInfo("BlockManagerMaster stopped") + /** + * Send a message to the master actor and get its result within a default timeout, or + * throw a SparkException if this fails. + */ + private def ask[T](message: Any): T = { + // TODO: Consider removing multiple attempts + if (masterActor == null) { + throw new SparkException("Error sending message to BlockManager as masterActor is null " + + "[message = " + message + "]") } - } - - // Send a message to the master actor and get its result within a default timeout, or - // throw a SparkException if this fails. - def askMaster(message: Any): Any = { - try { - val future = masterActor.ask(message)(timeout) - return Await.result(future, timeout) - } catch { - case e: Exception => - throw new SparkException("Error communicating with BlockManagerMaster", e) + var attempts = 0 + var lastException: Exception = null + while (attempts < maxAttempts) { + attempts += 1 + try { + val future = masterActor.ask(message)(timeout) + val result = Await.result(future, timeout) + if (result == null) { + throw new Exception("BlockManagerMaster returned null") + } + return result.asInstanceOf[T] + } catch { + case ie: InterruptedException => + throw ie + case e: Exception => + lastException = e + logWarning( + "Error sending message to BlockManagerMaster in " + attempts + " attempts", e) + } + Thread.sleep(100) } + throw new SparkException( + "Error sending message to BlockManagerMaster [message = " + message + "]", lastException) } - // Send a one-way message to the master actor, to which we expect it to reply with true. - def communicate(message: Any) { - if (askMaster(message) != true) { - throw new SparkException("Error reply received from BlockManagerMaster") + /** + * Send a one-way message to the master actor, to which we expect it to reply with true + */ + private def tell(message: Any) { + if (!ask[Boolean](message)) { + throw new SparkException("Telling master a message returned false") } } - def notifyADeadHost(host: String) { - communicate(RemoveHost(host + ":" + DEFAULT_MANAGER_PORT)) - logInfo("Removed " + host + " successfully in notifyADeadHost") - } - - def mustRegisterBlockManager(msg: RegisterBlockManager) { + /** + * Register the BlockManager's id with the master + */ + def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long) { logInfo("Trying to register BlockManager") - while (! syncRegisterBlockManager(msg)) { - logWarning("Failed to register " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - } - logInfo("Done registering BlockManager") - } - - def syncRegisterBlockManager(msg: RegisterBlockManager): Boolean = { - //val masterActor = RemoteActor.select(node, name) - val startTimeMs = System.currentTimeMillis() - val tmp = " msg " + msg + " " - logDebug("Got in syncRegisterBlockManager 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - try { - communicate(msg) - logInfo("BlockManager registered successfully @ syncRegisterBlockManager") - logDebug("Got in syncRegisterBlockManager 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return true - } catch { - case e: Exception => - logError("Failed in syncRegisterBlockManager", e) - return false - } + tell(RegisterBlockManager(blockManagerId, maxMemSize)) + logInfo("Registered BlockManager") } - def mustHeartBeat(msg: HeartBeat) { - while (! syncHeartBeat(msg)) { - logWarning("Failed to send heartbeat" + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - } - } - - def syncHeartBeat(msg: HeartBeat): Boolean = { - val startTimeMs = System.currentTimeMillis() - val tmp = " msg " + msg + " " - logDebug("Got in syncHeartBeat " + tmp + " 0 " + Utils.getUsedTimeMs(startTimeMs)) - - try { - communicate(msg) - logDebug("Heartbeat sent successfully") - logDebug("Got in syncHeartBeat 1 " + tmp + " 1 " + Utils.getUsedTimeMs(startTimeMs)) - return true - } catch { - case e: Exception => - logError("Failed in syncHeartBeat", e) - return false - } + def updateBlockInfo( + blockManagerId: BlockManagerId, + blockId: String, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long + ) { + tell(UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)) + logInfo("Updated info of block " + blockId) } - def mustGetLocations(msg: GetLocations): Seq[BlockManagerId] = { - var res = syncGetLocations(msg) - while (res == null) { - logInfo("Failed to get locations " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - res = syncGetLocations(msg) - } - return res + /** Get locations of the blockId from the master */ + def getLocations(blockId: String): Seq[BlockManagerId] = { + ask[Seq[BlockManagerId]](GetLocations(blockId)) } - def syncGetLocations(msg: GetLocations): Seq[BlockManagerId] = { - val startTimeMs = System.currentTimeMillis() - val tmp = " msg " + msg + " " - logDebug("Got in syncGetLocations 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - try { - val answer = askMaster(msg).asInstanceOf[ArrayBuffer[BlockManagerId]] - if (answer != null) { - logDebug("GetLocations successful") - logDebug("Got in syncGetLocations 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return answer - } else { - logError("Master replied null in response to GetLocations") - return null - } - } catch { - case e: Exception => - logError("GetLocations failed", e) - return null - } + /** Get locations of multiple blockIds from the master */ + def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { + ask[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } - def mustGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds): - Seq[Seq[BlockManagerId]] = { - var res: Seq[Seq[BlockManagerId]] = syncGetLocationsMultipleBlockIds(msg) - while (res == null) { - logWarning("Failed to GetLocationsMultipleBlockIds " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - res = syncGetLocationsMultipleBlockIds(msg) + /** Get ids of other nodes in the cluster from the master */ + def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = { + val result = ask[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers)) + if (result.length != numPeers) { + throw new SparkException( + "Error getting peers, only got " + result.size + " instead of " + numPeers) } - return res + result } - def syncGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds): - Seq[Seq[BlockManagerId]] = { - val startTimeMs = System.currentTimeMillis - val tmp = " msg " + msg + " " - logDebug("Got in syncGetLocationsMultipleBlockIds 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - try { - val answer = askMaster(msg).asInstanceOf[Seq[Seq[BlockManagerId]]] - if (answer != null) { - logDebug("GetLocationsMultipleBlockIds successful") - logDebug("Got in syncGetLocationsMultipleBlockIds 1 " + tmp + - Utils.getUsedTimeMs(startTimeMs)) - return answer - } else { - logError("Master replied null in response to GetLocationsMultipleBlockIds") - return null - } - } catch { - case e: Exception => - logError("GetLocationsMultipleBlockIds failed", e) - return null - } + /** Notify the master of a dead node */ + def notifyADeadHost(host: String) { + tell(RemoveHost(host + ":10902")) + logInfo("Told BlockManagerMaster to remove dead host " + host) } - def mustGetPeers(msg: GetPeers): Seq[BlockManagerId] = { - var res = syncGetPeers(msg) - while ((res == null) || (res.length != msg.size)) { - logInfo("Failed to get peers " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - res = syncGetPeers(msg) - } - - return res + /** Get the memory status form the master */ + def getMemoryStatus(): Map[BlockManagerId, (Long, Long)] = { + ask[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } - def syncGetPeers(msg: GetPeers): Seq[BlockManagerId] = { - val startTimeMs = System.currentTimeMillis - val tmp = " msg " + msg + " " - logDebug("Got in syncGetPeers 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - try { - val answer = askMaster(msg).asInstanceOf[Seq[BlockManagerId]] - if (answer != null) { - logDebug("GetPeers successful") - logDebug("Got in syncGetPeers 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return answer - } else { - logError("Master replied null in response to GetPeers") - return null - } - } catch { - case e: Exception => - logError("GetPeers failed", e) - return null + /** Stop the master actor, called only on the Spark master node */ + def stop() { + if (masterActor != null) { + tell(StopBlockManagerMaster) + masterActor = null + logInfo("BlockManagerMaster stopped") } } - - def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { - askMaster(GetMemoryStatus).asInstanceOf[Map[BlockManagerId, (Long, Long)]] - } } diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index b9c19e61cd..0e78228134 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -20,9 +20,11 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT var oldOops: String = null // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test + System.setProperty("spark.kryoserializer.buffer.mb", "1") val serializer = new KryoSerializer before { + actorSystem = ActorSystem("test") master = new BlockManagerMaster(actorSystem, true, true) @@ -55,7 +57,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } } - test("manager-master interaction") { + test("master + 1 manager interaction") { store = new BlockManager(master, serializer, 2000) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) @@ -72,17 +74,33 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(store.getSingle("a3") != None, "a3 was not in store") // Checking whether master knows about the blocks or not - assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1") - assert(master.mustGetLocations(GetLocations("a2")).size > 0, "master was not told about a2") - assert(master.mustGetLocations(GetLocations("a3")).size === 0, "master was told about a3") + assert(master.getLocations("a1").size === 1, "master was not told about a1") + assert(master.getLocations("a2").size === 1, "master was not told about a2") + assert(master.getLocations("a3").size === 0, "master was told about a3") // Drop a1 and a2 from memory; this should be reported back to the master store.dropFromMemory("a1", null) store.dropFromMemory("a2", null) assert(store.getSingle("a1") === None, "a1 not removed from store") assert(store.getSingle("a2") === None, "a2 not removed from store") - assert(master.mustGetLocations(GetLocations("a1")).size === 0, "master did not remove a1") - assert(master.mustGetLocations(GetLocations("a2")).size === 0, "master did not remove a2") + assert(master.getLocations("a1").size === 0, "master did not remove a1") + assert(master.getLocations("a2").size === 0, "master did not remove a2") + } + + test("master + 2 managers interaction") { + store = new BlockManager(master, serializer, 2000) + val otherStore = new BlockManager(master, new KryoSerializer, 2000) + + val peers = master.getPeers(store.blockManagerId, 1) + assert(peers.size === 1, "master did not return the other manager as a peer") + assert(peers.head === otherStore.blockManagerId, "peer returned by master is not the other manager") + + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_2) + otherStore.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_2) + assert(master.getLocations("a1").size === 2, "master did not report 2 locations for a1") + assert(master.getLocations("a2").size === 2, "master did not report 2 locations for a2") } test("in-memory LRU storage") { -- cgit v1.2.3 From d006109e9504b3221de3a15f9bfee96dafa8b593 Mon Sep 17 00:00:00 2001 From: Denny Date: Sun, 11 Nov 2012 11:06:49 -0800 Subject: Kafka Stream comments. --- .../src/main/scala/spark/streaming/DStream.scala | 7 +- .../scala/spark/streaming/StreamingContext.scala | 12 ++++ .../spark/streaming/examples/KafkaWordCount.scala | 44 +++++++------ .../spark/streaming/input/KafkaInputDStream.scala | 77 +++++++++++++++------- .../scala/spark/streaming/CheckpointSuite.scala | 12 ++-- 5 files changed, 99 insertions(+), 53 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 3219919a24..b8324d11a3 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -17,6 +17,7 @@ import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import org.apache.hadoop.fs.Path import org.apache.hadoop.conf.Configuration + case class DStreamCheckpointData(rdds: HashMap[Time, Any]) abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext) @@ -61,7 +62,7 @@ extends Serializable with Logging { // Checkpoint details protected[streaming] val mustCheckpoint = false protected[streaming] var checkpointInterval: Time = null - protected[streaming] var checkpointData = DStreamCheckpointData(HashMap[Time, Any]()) + protected[streaming] var checkpointData = new DStreamCheckpointData(HashMap[Time, Any]()) // Reference to whole DStream graph protected[streaming] var graph: DStreamGraph = null @@ -286,7 +287,9 @@ extends Serializable with Logging { * This methd should be overwritten by sublcasses of InputDStream. */ protected[streaming] def addMetadata(metadata: Any) { - logInfo("Dropping Metadata: " + metadata.toString) + if (metadata != null) { + logInfo("Dropping Metadata: " + metadata.toString) + } } /** diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index d68d2632e7..e87d0cb7c8 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -102,6 +102,18 @@ final class StreamingContext ( private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() + /** + * Create an input stream that pulls messages form a Kafka Broker. + * + * @param host Zookeper hostname. + * @param port Zookeper port. + * @param groupId The group id for this consumer. + * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * in its own thread. + * @param initialOffsets Optional initial offsets for each of the partitions to consume. + * By default the value is pulled from zookeper. + * @param storageLevel RDD storage level. Defaults to memory-only. + */ def kafkaStream[T: ClassManifest]( hostname: String, port: Int, diff --git a/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala index 655f9627b3..1e92cbb210 100644 --- a/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala +++ b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala @@ -3,34 +3,38 @@ package spark.streaming.examples import spark.streaming._ import spark.streaming.StreamingContext._ import spark.storage.StorageLevel +import WordCount2_ExtraFunctions._ object KafkaWordCount { def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: WordCountNetwork ") + + if (args.length < 4) { + System.err.println("Usage: KafkaWordCount ") System.exit(1) } - // Create the context and set the batch size - val ssc = new StreamingContext(args(0), "WordCountNetwork") - ssc.setBatchDuration(Seconds(2)) + val ssc = args(3) match { + // Restore the stream from a checkpoint + case "true" => + new StreamingContext("work/checkpoint") + case _ => + val tmp = new StreamingContext(args(0), "KafkaWordCount") - // Create a NetworkInputDStream on target ip:port and count the - // words in input stream of \n delimited test (eg. generated by 'nc') - ssc.checkpoint("checkpoint", Time(1000 * 5)) - val lines = ssc.kafkaStream[String](args(1), args(2).toInt, "test_group", Map("test" -> 1), - Map(KafkaPartitionKey(0, "test", "test_group", 0) -> 2382)) - val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() + tmp.setBatchDuration(Seconds(2)) + tmp.checkpoint("work/checkpoint", Seconds(10)) + + val lines = tmp.kafkaStream[String](args(1), args(2).toInt, "test_group", Map("test" -> 1), + Map(KafkaPartitionKey(0,"test","test_group",0) -> 0l)) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2) + + wordCounts.persist().checkpoint(Seconds(10)) + wordCounts.print() + + tmp + } ssc.start() - // Wait for 12 seconds - Thread.sleep(12000) - ssc.stop() - - val newSsc = new StreamingContext("checkpoint") - newSsc.start() - } } + diff --git a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala index 814f2706d6..ad8e86a094 100644 --- a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala @@ -1,15 +1,11 @@ package spark.streaming -import java.lang.reflect.Method -import java.nio.ByteBuffer import java.util.Properties -import java.util.concurrent.{ArrayBlockingQueue, ConcurrentHashMap, Executors} -import kafka.api.{FetchRequest} +import java.util.concurrent.Executors import kafka.consumer._ -import kafka.cluster.Partition import kafka.message.{Message, MessageSet, MessageAndMetadata} import kafka.serializer.StringDecoder -import kafka.utils.{Pool, Utils, ZKGroupTopicDirs} +import kafka.utils.{Utils, ZKGroupTopicDirs} import kafka.utils.ZkUtils._ import scala.collection.mutable.HashMap import scala.collection.JavaConversions._ @@ -17,14 +13,25 @@ import spark._ import spark.RDD import spark.storage.StorageLevel - +// Key for a specific Kafka Partition: (broker, topic, group, part) case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int) +// Metadata for a Kafka Stream that it sent to the Master case class KafkaInputDStreamMetadata(timestamp: Long, data: Map[KafkaPartitionKey, Long]) +// Checkpoint data specific to a KafkaInputDstream case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any], - savedOffsets: HashMap[Long, Map[KafkaPartitionKey, Long]]) extends DStreamCheckpointData(kafkaRdds) + savedOffsets: Map[KafkaPartitionKey, Long]) extends DStreamCheckpointData(kafkaRdds) /** * Input stream that pulls messages form a Kafka Broker. + * + * @param host Zookeper hostname. + * @param port Zookeper port. + * @param groupId The group id for this consumer. + * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * in its own thread. + * @param initialOffsets Optional initial offsets for each of the partitions to consume. + * By default the value is pulled from zookeper. + * @param storageLevel RDD storage level. */ class KafkaInputDStream[T: ClassManifest]( @transient ssc_ : StreamingContext, @@ -36,21 +43,31 @@ class KafkaInputDStream[T: ClassManifest]( storageLevel: StorageLevel ) extends NetworkInputDStream[T](ssc_ ) with Logging { + // Metadata that keeps track of which messages have already been consumed. var savedOffsets = HashMap[Long, Map[KafkaPartitionKey, Long]]() + // In case of a failure, the offets for a particular timestamp will be restored. + @transient var restoredOffsets : Map[KafkaPartitionKey, Long] = null override protected[streaming] def addMetadata(metadata: Any) { metadata match { - case x : KafkaInputDStreamMetadata => + case x : KafkaInputDStreamMetadata => savedOffsets(x.timestamp) = x.data - logInfo("Saved Offsets: " + savedOffsets) + // TOOD: Remove logging + logInfo("New saved Offsets: " + savedOffsets) case _ => logInfo("Received unknown metadata: " + metadata.toString) } } override protected[streaming] def updateCheckpointData(currentTime: Time) { super.updateCheckpointData(currentTime) - logInfo("Updating KafkaDStream checkpoint data: " + savedOffsets.toString) - checkpointData = KafkaDStreamCheckpointData(checkpointData.rdds, savedOffsets) + if(savedOffsets.size > 0) { + // Find the offets that were stored before the checkpoint was initiated + val key = savedOffsets.keys.toList.sortWith(_ < _).filter(_ < currentTime.millis).last + val latestOffsets = savedOffsets(key) + logInfo("Updating KafkaDStream checkpoint data: " + latestOffsets.toString) + checkpointData = KafkaDStreamCheckpointData(checkpointData.rdds, latestOffsets) + savedOffsets.clear() + } } override protected[streaming] def restoreCheckpointData() { @@ -58,14 +75,21 @@ class KafkaInputDStream[T: ClassManifest]( logInfo("Restoring KafkaDStream checkpoint data.") checkpointData match { case x : KafkaDStreamCheckpointData => - savedOffsets = x.savedOffsets - logInfo("Restored KafkaDStream offsets: " + savedOffsets.toString) + restoredOffsets = x.savedOffsets + logInfo("Restored KafkaDStream offsets: " + savedOffsets) } } def createReceiver(): NetworkReceiver[T] = { - new KafkaReceiver(id, host, port, groupId, topics, initialOffsets, storageLevel) - .asInstanceOf[NetworkReceiver[T]] + // We have restored from a checkpoint, use the restored offsets + if (restoredOffsets != null) { + new KafkaReceiver(id, host, port, groupId, topics, restoredOffsets, storageLevel) + .asInstanceOf[NetworkReceiver[T]] + } else { + new KafkaReceiver(id, host, port, groupId, topics, initialOffsets, storageLevel) + .asInstanceOf[NetworkReceiver[T]] + } + } } @@ -96,27 +120,28 @@ class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) val zooKeeperEndPoint = host + ":" + port - logInfo("Starting Kafka Consumer Stream in group " + groupId) + logInfo("Starting Kafka Consumer Stream with group: " + groupId) logInfo("Initial offsets: " + initialOffsets.toString) - logInfo("Connecting to " + zooKeeperEndPoint) - // Specify some Consumer properties + + // Zookeper connection properties val props = new Properties() props.put("zk.connect", zooKeeperEndPoint) props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString) props.put("groupid", groupId) // Create the connection to the cluster + logInfo("Connecting to Zookeper: " + zooKeeperEndPoint) val consumerConfig = new ConsumerConfig(props) consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector] + logInfo("Connected to " + zooKeeperEndPoint) // Reset the Kafka offsets in case we are recovering from a failure resetOffsets(initialOffsets) - - logInfo("Connected to " + zooKeeperEndPoint) // Create Threads for each Topic/Message Stream we are listening val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder()) + // Start the messages handler for each partition topicMessageStreams.values.foreach { streams => streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) } } @@ -133,19 +158,20 @@ class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, } } - // Responsible for handling Kafka Messages - class MessageHandler(stream: KafkaStream[String]) extends Runnable { + // Handles Kafka Messages + private class MessageHandler(stream: KafkaStream[String]) extends Runnable { def run() { logInfo("Starting MessageHandler.") stream.takeWhile { msgAndMetadata => dataHandler += msgAndMetadata.message - // Updating the offet. The key is (topic, partitionID). + // Updating the offet. The key is (broker, topic, group, partition). val key = KafkaPartitionKey(msgAndMetadata.topicInfo.brokerId, msgAndMetadata.topic, groupId, msgAndMetadata.topicInfo.partition.partId) val offset = msgAndMetadata.topicInfo.getConsumeOffset offsets.put(key, offset) - logInfo((key, offset).toString) + // TODO: Remove Logging + logInfo("Handled message: " + (key, offset).toString) // Keep on handling messages true @@ -157,6 +183,7 @@ class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, extends DataHandler[Any](receiver, storageLevel) { override def createBlock(blockId: String, iterator: Iterator[Any]) : Block = { + // Creates a new Block with Kafka-specific Metadata new Block(blockId, iterator, KafkaInputDStreamMetadata(System.currentTimeMillis, offsets.toMap)) } diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index 038827ddb0..0450120061 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -59,9 +59,9 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // then check whether some RDD has been checkpointed or not ssc.start() runStreamsWithRealDelay(ssc, firstNumBatches) - logInfo("Checkpoint data of state stream = \n[" + stateStream.checkpointData.mkString(",\n") + "]") - assert(!stateStream.checkpointData.isEmpty, "No checkpointed RDDs in state stream before first failure") - stateStream.checkpointData.foreach { + logInfo("Checkpoint data of state stream = \n[" + stateStream.checkpointData.rdds.mkString(",\n") + "]") + assert(!stateStream.checkpointData.rdds.isEmpty, "No checkpointed RDDs in state stream before first failure") + stateStream.checkpointData.rdds.foreach { case (time, data) => { val file = new File(data.toString) assert(file.exists(), "Checkpoint file '" + file +"' for time " + time + " for state stream before first failure does not exist") @@ -70,7 +70,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // Run till a further time such that previous checkpoint files in the stream would be deleted // and check whether the earlier checkpoint files are deleted - val checkpointFiles = stateStream.checkpointData.map(x => new File(x._2.toString)) + val checkpointFiles = stateStream.checkpointData.rdds.map(x => new File(x._2.toString)) runStreamsWithRealDelay(ssc, secondNumBatches) checkpointFiles.foreach(file => assert(!file.exists, "Checkpoint file '" + file + "' was not deleted")) ssc.stop() @@ -87,8 +87,8 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // is present in the checkpoint data or not ssc.start() runStreamsWithRealDelay(ssc, 1) - assert(!stateStream.checkpointData.isEmpty, "No checkpointed RDDs in state stream before second failure") - stateStream.checkpointData.foreach { + assert(!stateStream.checkpointData.rdds.isEmpty, "No checkpointed RDDs in state stream before second failure") + stateStream.checkpointData.rdds.foreach { case (time, data) => { val file = new File(data.toString) assert(file.exists(), -- cgit v1.2.3 From deb2c4df72f65f2bd90cc97a9abcd59a12eecabc Mon Sep 17 00:00:00 2001 From: Denny Date: Sun, 11 Nov 2012 11:11:49 -0800 Subject: Add comment. --- streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala index ad8e86a094..cc74855983 100644 --- a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala @@ -66,6 +66,8 @@ class KafkaInputDStream[T: ClassManifest]( val latestOffsets = savedOffsets(key) logInfo("Updating KafkaDStream checkpoint data: " + latestOffsets.toString) checkpointData = KafkaDStreamCheckpointData(checkpointData.rdds, latestOffsets) + // TODO: This may throw out offsets that are created after the checkpoint, + // but it's unlikely we'll need them. savedOffsets.clear() } } -- cgit v1.2.3 From 0fd4c93f1c349f052f633fea64f975d53976bd9c Mon Sep 17 00:00:00 2001 From: Denny Date: Sun, 11 Nov 2012 11:15:31 -0800 Subject: Updated comment. --- streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala index cc74855983..318537532c 100644 --- a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala @@ -104,7 +104,7 @@ class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, // Handles pushing data into the BlockManager lazy protected val dataHandler = new KafkaDataHandler(this, storageLevel) - // Keeps track of the current offsets. Maps from (topic, partitionID) -> Offset + // Keeps track of the current offsets. Maps from (broker, topic, group, part) -> Offset lazy val offsets = HashMap[KafkaPartitionKey, Long]() // Connection to Kafka var consumerConnector : ZookeeperConsumerConnector = null -- cgit v1.2.3 From 46222dc56db4a521bd613bd3fac5b91868bb339e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 11 Nov 2012 13:20:09 -0800 Subject: Fixed bug in FileInputDStream that allowed it to miss new files. Added tests in the InputStreamsSuite to test checkpointing of file and network streams. --- .../scala/spark/streaming/FileInputDStream.scala | 34 +++++- .../scala/spark/streaming/InputStreamsSuite.scala | 136 ++++++++++++++++++--- 2 files changed, 148 insertions(+), 22 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala index 9d7361097b..88856364d2 100644 --- a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala @@ -6,7 +6,8 @@ import spark.rdd.UnionRDD import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import java.io.{ObjectInputStream, IOException} + +import scala.collection.mutable.HashSet class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( @@ -19,7 +20,8 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K @transient private var path_ : Path = null @transient private var fs_ : FileSystem = null - var lastModTime: Long = 0 + var lastModTime = 0L + val lastModTimeFiles = new HashSet[String]() def path(): Path = { if (path_ == null) path_ = new Path(directory) @@ -40,22 +42,37 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K } override def stop() { } - + + /** + * Finds the files that were modified since the last time this method was called and makes + * a union RDD out of them. Note that this maintains the list of files that were processed + * in the latest modification time in the previous call to this method. This is because the + * modification time returned by the FileStatus API seems to return times only at the + * granularity of seconds. Hence, new files may have the same modification time as the + * latest modification time in the previous call to this method and the list of files + * maintained is used to filter the one that have been processed. + */ override def compute(validTime: Time): Option[RDD[(K, V)]] = { + // Create the filter for selecting new files val newFilter = new PathFilter() { var latestModTime = 0L - + val latestModTimeFiles = new HashSet[String]() + def accept(path: Path): Boolean = { if (!filter.accept(path)) { return false } else { val modTime = fs.getFileStatus(path).getModificationTime() - if (modTime <= lastModTime) { + if (modTime < lastModTime){ + return false + } else if (modTime == lastModTime && lastModTimeFiles.contains(path.toString)) { return false } if (modTime > latestModTime) { latestModTime = modTime + latestModTimeFiles.clear() } + latestModTimeFiles += path.toString return true } } @@ -64,7 +81,12 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K val newFiles = fs.listStatus(path, newFilter) logInfo("New files: " + newFiles.map(_.getPath).mkString(", ")) if (newFiles.length > 0) { - lastModTime = newFilter.latestModTime + // Update the modification time and the files processed for that modification time + if (lastModTime != newFilter.latestModTime) { + lastModTime = newFilter.latestModTime + lastModTimeFiles.clear() + } + lastModTimeFiles ++= newFilter.latestModTimeFiles } val newRDD = new UnionRDD(ssc.sc, newFiles.map( file => ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString))) diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index 8f892baab1..0957748603 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -9,12 +9,19 @@ import spark.storage.StorageLevel import spark.Logging import scala.util.Random import org.apache.commons.io.FileUtils +import org.scalatest.BeforeAndAfter -class InputStreamsSuite extends TestSuiteBase { +class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + override def checkpointDir = "checkpoint" + + after { + FileUtils.deleteDirectory(new File(checkpointDir)) + } + test("network input stream") { // Start the server val serverPort = 9999 @@ -30,7 +37,7 @@ class InputStreamsSuite extends TestSuiteBase { ssc.registerOutputStream(outputStream) ssc.start() - // Feed data to the server to send to the Spark Streaming network receiver + // Feed data to the server to send to the network receiver val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq(1, 2, 3, 4, 5) val expectedOutput = input.map(_.toString) @@ -52,7 +59,7 @@ class InputStreamsSuite extends TestSuiteBase { logInfo("Stopping context") ssc.stop() - // Verify whether data received by Spark Streaming was as expected + // Verify whether data received was as expected logInfo("--------------------------------") logInfo("output.size = " + outputBuffer.size) logInfo("output") @@ -69,6 +76,49 @@ class InputStreamsSuite extends TestSuiteBase { } } + test("network input stream with checkpoint") { + // Start the server + val serverPort = 9999 + val server = new TestServer(9999) + server.start() + + // Set up the streaming context and input streams + var ssc = new StreamingContext(master, framework) + ssc.setBatchDuration(batchDuration) + ssc.checkpoint(checkpointDir, checkpointInterval) + val networkStream = ssc.networkTextStream("localhost", serverPort, StorageLevel.MEMORY_AND_DISK) + var outputStream = new TestOutputStream(networkStream, new ArrayBuffer[Seq[String]]) + ssc.registerOutputStream(outputStream) + ssc.start() + + // Feed data to the server to send to the network receiver + var clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + for (i <- Seq(1, 2, 3)) { + server.send(i.toString + "\n") + Thread.sleep(100) + clock.addToTime(batchDuration.milliseconds) + } + Thread.sleep(500) + assert(outputStream.output.size > 0) + ssc.stop() + + // Restart stream computation from checkpoint and feed more data to see whether + // they are being received and processed + logInfo("*********** RESTARTING ************") + ssc = new StreamingContext(checkpointDir) + ssc.start() + clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + for (i <- Seq(4, 5, 6)) { + server.send(i.toString + "\n") + Thread.sleep(100) + clock.addToTime(batchDuration.milliseconds) + } + Thread.sleep(500) + outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[String]] + assert(outputStream.output.size > 0) + ssc.stop() + } + test("file input stream") { // Create a temporary directory val dir = { @@ -76,7 +126,7 @@ class InputStreamsSuite extends TestSuiteBase { temp.delete() temp.mkdirs() temp.deleteOnExit() - println("Created temp dir " + temp) + logInfo("Created temp dir " + temp) temp } @@ -84,7 +134,9 @@ class InputStreamsSuite extends TestSuiteBase { val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) val filestream = ssc.textFileStream(dir.toString) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String ]] + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + def output = outputBuffer.flatMap(x => x) + val outputStream = new TestOutputStream(filestream, outputBuffer) ssc.registerOutputStream(outputStream) ssc.start() @@ -96,36 +148,88 @@ class InputStreamsSuite extends TestSuiteBase { Thread.sleep(1000) for (i <- 0 until input.size) { FileUtils.writeStringToFile(new File(dir, i.toString), input(i).toString + "\n") - Thread.sleep(500) + Thread.sleep(100) clock.addToTime(batchDuration.milliseconds) - Thread.sleep(500) + Thread.sleep(100) } val startTime = System.currentTimeMillis() - while (outputBuffer.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { - println("output.size = " + outputBuffer.size + ", expectedOutput.size = " + expectedOutput.size) + while (output.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + //println("output.size = " + output.size + ", expectedOutput.size = " + expectedOutput.size) Thread.sleep(100) } Thread.sleep(1000) val timeTaken = System.currentTimeMillis() - startTime assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") - println("Stopping context") + logInfo("Stopping context") ssc.stop() // Verify whether data received by Spark Streaming was as expected logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) + logInfo("output.size = " + output.size) logInfo("output") - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + output.foreach(x => logInfo("[" + x.mkString(",") + "]")) logInfo("expected output.size = " + expectedOutput.size) logInfo("expected output") expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) logInfo("--------------------------------") - assert(outputBuffer.size === expectedOutput.size) - for (i <- 0 until outputBuffer.size) { - assert(outputBuffer(i).size === 1) - assert(outputBuffer(i).head === expectedOutput(i)) + assert(output.size === expectedOutput.size) + for (i <- 0 until output.size) { + assert(output(i).size === 1) + assert(output(i).head.toString === expectedOutput(i)) + } + } + + test("file input stream with checkpoint") { + // Create a temporary directory + val dir = { + var temp = File.createTempFile(".temp.", Random.nextInt().toString) + temp.delete() + temp.mkdirs() + temp.deleteOnExit() + println("Created temp dir " + temp) + temp } + + // Set up the streaming context and input streams + var ssc = new StreamingContext(master, framework) + ssc.setBatchDuration(batchDuration) + ssc.checkpoint(checkpointDir, checkpointInterval) + val filestream = ssc.textFileStream(dir.toString) + var outputStream = new TestOutputStream(filestream, new ArrayBuffer[Seq[String]]) + ssc.registerOutputStream(outputStream) + ssc.start() + + // Create files and advance manual clock to process them + var clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + Thread.sleep(1000) + for (i <- Seq(1, 2, 3)) { + FileUtils.writeStringToFile(new File(dir, i.toString), i.toString + "\n") + Thread.sleep(100) + clock.addToTime(batchDuration.milliseconds) + } + Thread.sleep(500) + logInfo("Output = " + outputStream.output.mkString(",")) + assert(outputStream.output.size > 0) + ssc.stop() + + // Restart stream computation from checkpoint and create more files to see whether + // they are being processed + logInfo("*********** RESTARTING ************") + ssc = new StreamingContext(checkpointDir) + ssc.start() + clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + Thread.sleep(500) + for (i <- Seq(4, 5, 6)) { + FileUtils.writeStringToFile(new File(dir, i.toString), i.toString + "\n") + Thread.sleep(100) + clock.addToTime(batchDuration.milliseconds) + } + Thread.sleep(500) + outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[String]] + logInfo("Output = " + outputStream.output.mkString(",")) + assert(outputStream.output.size > 0) + ssc.stop() } } -- cgit v1.2.3 From ae61ebaee64fad117155d65bcdfc8520bda0e6b4 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 12 Nov 2012 21:45:16 +0000 Subject: Fixed bugs in RawNetworkInputDStream and in its examples. Made the ReducedWindowedDStream persist RDDs to MEMOERY_SER_ONLY by default. Removed unncessary examples. Added streaming-env.sh.template to add recommended setting for streaming. --- conf/streaming-env.sh.template | 22 ++++ run | 4 + startTrigger.sh | 3 - .../spark/streaming/NetworkInputTracker.scala | 1 + .../scala/spark/streaming/RawInputDStream.scala | 2 +- .../spark/streaming/ReducedWindowedDStream.scala | 15 ++- .../scala/spark/streaming/StreamingContext.scala | 4 +- .../scala/spark/streaming/examples/Grep2.scala | 64 ------------ .../scala/spark/streaming/examples/GrepRaw.scala | 11 +- .../streaming/examples/TopKWordCountRaw.scala | 102 ++++++------------ .../spark/streaming/examples/WordCount2.scala | 114 --------------------- .../spark/streaming/examples/WordCountRaw.scala | 57 +++++------ .../scala/spark/streaming/examples/WordMax2.scala | 75 -------------- .../scala/spark/streaming/util/RawTextHelper.scala | 98 ++++++++++++++++++ 14 files changed, 193 insertions(+), 379 deletions(-) create mode 100755 conf/streaming-env.sh.template delete mode 100755 startTrigger.sh delete mode 100644 streaming/src/main/scala/spark/streaming/examples/Grep2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCount2.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordMax2.scala create mode 100644 streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala diff --git a/conf/streaming-env.sh.template b/conf/streaming-env.sh.template new file mode 100755 index 0000000000..6b4094c515 --- /dev/null +++ b/conf/streaming-env.sh.template @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +# This file contains a few additional setting that are useful for +# running streaming jobs in Spark. Copy this file as streaming-env.sh . +# Note that this shell script will be read after spark-env.sh, so settings +# in this file may override similar settings (if present) in spark-env.sh . + + +# Using concurrent GC is strongly recommended as it can significantly +# reduce GC related pauses. + +SPARK_JAVA_OPTS+=" -XX:+UseConcMarkSweepGC" + +# Using of Kryo serialization can improve serialization performance +# and therefore the throughput of the Spark Streaming programs. However, +# using Kryo serialization with custom classes may required you to +# register the classes with Kryo. Refer to the Spark documentation +# for more details. + +# SPARK_JAVA_OPTS+=" -Dspark.serializer=spark.KryoSerializer" + +export SPARK_JAVA_OPTS diff --git a/run b/run index a363599cf0..d91430ad2e 100755 --- a/run +++ b/run @@ -13,6 +13,10 @@ if [ -e $FWDIR/conf/spark-env.sh ] ; then . $FWDIR/conf/spark-env.sh fi +if [ -e $FWDIR/conf/streaming-env.sh ] ; then + . $FWDIR/conf/streaming-env.sh +fi + if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then if [ `command -v scala` ]; then RUNNER="scala" diff --git a/startTrigger.sh b/startTrigger.sh deleted file mode 100755 index 373dbda93e..0000000000 --- a/startTrigger.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash - -./run spark.streaming.SentenceGenerator localhost 7078 sentences.txt 1 diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index 07ef79415d..d0fef70f7e 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -47,6 +47,7 @@ class NetworkInputTracker( val result = queue.synchronized { queue.dequeueAll(x => true) } + logInfo("Stream " + receiverId + " received " + result.size + " blocks") result.toArray } diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala index e022b85fbe..03726bfba6 100644 --- a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala @@ -69,7 +69,7 @@ class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: S } def onStop() { - blockPushingThread.interrupt() + if (blockPushingThread != null) blockPushingThread.interrupt() } /** Read a buffer fully from a given Channel */ diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala index 6df82c0df3..b07d51fa6b 100644 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -31,10 +31,14 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" ) - super.persist(StorageLevel.MEMORY_ONLY) - + // Reduce each batch of data using reduceByKey which will be further reduced by window + // by ReducedWindowedDStream val reducedStream = parent.reduceByKey(reduceFunc, partitioner) + // Persist RDDs to memory by default as these RDDs are going to be reused. + super.persist(StorageLevel.MEMORY_ONLY_SER) + reducedStream.persist(StorageLevel.MEMORY_ONLY_SER) + def windowTime: Time = _windowTime override def dependencies = List(reducedStream) @@ -57,13 +61,6 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( this } - protected[streaming] override def setRememberDuration(time: Time) { - if (rememberDuration == null || rememberDuration < time) { - rememberDuration = time - dependencies.foreach(_.setRememberDuration(rememberDuration + windowTime)) - } - } - override def compute(validTime: Time): Option[RDD[(K, V)]] = { val reduceF = reduceFunc val invReduceF = invReduceFunc diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index eb83aaee7a..ab6d6e8dea 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -124,7 +124,7 @@ final class StreamingContext ( def rawNetworkStream[T: ClassManifest]( hostname: String, port: Int, - storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 + storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 ): DStream[T] = { val inputStream = new RawInputDStream[T](this, hostname, port, storageLevel) graph.addInputStream(inputStream) @@ -132,7 +132,7 @@ final class StreamingContext ( } /** - * This function creates a input stream that monitors a Hadoop-compatible + * This function creates a input stream that monitors a Hadoop-compatible filesystem * for new files and executes the necessary processing on them. */ def fileStream[ diff --git a/streaming/src/main/scala/spark/streaming/examples/Grep2.scala b/streaming/src/main/scala/spark/streaming/examples/Grep2.scala deleted file mode 100644 index b1faa65c17..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/Grep2.scala +++ /dev/null @@ -1,64 +0,0 @@ -package spark.streaming.examples - -import spark.SparkContext -import SparkContext._ -import spark.streaming._ -import StreamingContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.Queue -import scala.collection.JavaConversions.mapAsScalaMap - -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} - - -object Grep2 { - - def warmup(sc: SparkContext) { - (0 until 10).foreach {i => - sc.parallelize(1 to 20000000, 1000) - .map(x => (x % 337, x % 1331)) - .reduceByKey(_ + _) - .count() - } - } - - def main (args: Array[String]) { - - if (args.length != 6) { - println ("Usage: Grep2 ") - System.exit(1) - } - - val Array(master, file, mapTasks, reduceTasks, batchMillis, chkptMillis) = args - - val batchDuration = Milliseconds(batchMillis.toLong) - - val ssc = new StreamingContext(master, "Grep2") - ssc.setBatchDuration(batchDuration) - - //warmup(ssc.sc) - - val data = ssc.sc.textFile(file, mapTasks.toInt).persist( - new StorageLevel(false, true, false, 3)) // Memory only, serialized, 3 replicas - println("Data count: " + data.count()) - println("Data count: " + data.count()) - println("Data count: " + data.count()) - - val sentences = new ConstantInputDStream(ssc, data) - ssc.registerInputStream(sentences) - - sentences.filter(_.contains("Culpepper")).count().foreachRDD(r => - println("Grep count: " + r.collect().mkString)) - - ssc.start() - - while(true) { Thread.sleep(1000) } - } -} - - diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala index b1e1a613fe..ffbea6e55d 100644 --- a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala @@ -2,8 +2,10 @@ package spark.streaming.examples import spark.util.IntParam import spark.storage.StorageLevel + import spark.streaming._ import spark.streaming.StreamingContext._ +import spark.streaming.util.RawTextHelper._ object GrepRaw { def main(args: Array[String]) { @@ -17,16 +19,13 @@ object GrepRaw { // Create the context and set the batch size val ssc = new StreamingContext(master, "GrepRaw") ssc.setBatchDuration(Milliseconds(batchMillis)) + warmUp(ssc.sc) - // Make sure some tasks have started on each node - ssc.sc.parallelize(1 to 1000, 1000).count() - ssc.sc.parallelize(1 to 1000, 1000).count() - ssc.sc.parallelize(1 to 1000, 1000).count() val rawStreams = (1 to numStreams).map(_ => - ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray + ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray val union = new UnionDStream(rawStreams) - union.filter(_.contains("Culpepper")).count().foreachRDD(r => + union.filter(_.contains("Alice")).count().foreachRDD(r => println("Grep count: " + r.collect().mkString)) ssc.start() } diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala index 750cb7445f..0411bde1a7 100644 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala @@ -1,94 +1,50 @@ package spark.streaming.examples -import spark.util.IntParam -import spark.SparkContext -import spark.SparkContext._ import spark.storage.StorageLevel +import spark.util.IntParam + import spark.streaming._ import spark.streaming.StreamingContext._ +import spark.streaming.util.RawTextHelper._ -import WordCount2_ExtraFunctions._ +import java.util.UUID object TopKWordCountRaw { - def moreWarmup(sc: SparkContext) { - (0 until 40).foreach {i => - sc.parallelize(1 to 20000000, 1000) - .map(_ % 1331).map(_.toString) - .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) - .collect() - } - } - + def main(args: Array[String]) { - if (args.length != 7) { - System.err.println("Usage: TopKWordCountRaw ") + if (args.length != 4) { + System.err.println("Usage: WordCountRaw <# streams> ") System.exit(1) } - val Array(master, IntParam(streams), host, IntParam(port), IntParam(batchMs), - IntParam(chkptMs), IntParam(reduces)) = args - - // Create the context and set the batch size - val ssc = new StreamingContext(master, "TopKWordCountRaw") - ssc.setBatchDuration(Milliseconds(batchMs)) - - // Make sure some tasks have started on each node - moreWarmup(ssc.sc) - - val rawStreams = (1 to streams).map(_ => - ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray - val union = new UnionDStream(rawStreams) - - val windowedCounts = union.mapPartitions(splitAndCountPartitions) - .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) - windowedCounts.persist().checkpoint(Milliseconds(chkptMs)) - //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMs)) - - def topK(data: Iterator[(String, Long)], k: Int): Iterator[(String, Long)] = { - val taken = new Array[(String, Long)](k) - - var i = 0 - var len = 0 - var done = false - var value: (String, Long) = null - var swap: (String, Long) = null - var count = 0 - - while(data.hasNext) { - value = data.next - count += 1 - println("count = " + count) - if (len == 0) { - taken(0) = value - len = 1 - } else if (len < k || value._2 > taken(len - 1)._2) { - if (len < k) { - len += 1 - } - taken(len - 1) = value - i = len - 1 - while(i > 0 && taken(i - 1)._2 < taken(i)._2) { - swap = taken(i) - taken(i) = taken(i-1) - taken(i - 1) = swap - i -= 1 - } - } - } - println("Took " + len + " out of " + count + " items") - return taken.toIterator - } + val Array(master, IntParam(numStreams), IntParam(port), checkpointDir) = args + val k = 10 - val k = 50 + // Create the context, set the batch size and checkpoint directory. + // Checkpoint directory is necessary for achieving fault-tolerance, by saving counts + // periodically to HDFS + val ssc = new StreamingContext(master, "TopKWordCountRaw") + ssc.setBatchDuration(Seconds(1)) + ssc.checkpoint(checkpointDir + "/" + UUID.randomUUID.toString, Seconds(1)) + + // Warm up the JVMs on master and slave for JIT compilation to kick in + /*warmUp(ssc.sc)*/ + + // Set up the raw network streams that will connect to localhost:port to raw test + // senders on the slaves and generate top K words of last 30 seconds + val lines = (1 to numStreams).map(_ => { + ssc.rawNetworkStream[String]("localhost", port, StorageLevel.MEMORY_ONLY_SER_2) + }) + val union = new UnionDStream(lines.toArray) + val counts = union.mapPartitions(splitAndCountPartitions) + val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) partialTopKWindowedCounts.foreachRDD(rdd => { val collectedCounts = rdd.collect - println("Collected " + collectedCounts.size + " items") - topK(collectedCounts.toIterator, k).foreach(println) + println("Collected " + collectedCounts.size + " words from partial top words") + println("Top " + k + " words are " + topK(collectedCounts.toIterator, k).mkString(",")) }) -// windowedCounts.foreachRDD(r => println("Element count: " + r.count())) - ssc.start() } } diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala deleted file mode 100644 index 865026033e..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ /dev/null @@ -1,114 +0,0 @@ -package spark.streaming.examples - -import spark.SparkContext -import SparkContext._ -import spark.streaming._ -import StreamingContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.Queue -import scala.collection.JavaConversions.mapAsScalaMap - -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} - - -object WordCount2_ExtraFunctions { - - def add(v1: Long, v2: Long) = (v1 + v2) - - def subtract(v1: Long, v2: Long) = (v1 - v2) - - def max(v1: Long, v2: Long) = math.max(v1, v2) - - def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, Long)] = { - //val map = new java.util.HashMap[String, Long] - val map = new OLMap[String] - var i = 0 - var j = 0 - while (iter.hasNext) { - val s = iter.next() - i = 0 - while (i < s.length) { - j = i - while (j < s.length && s.charAt(j) != ' ') { - j += 1 - } - if (j > i) { - val w = s.substring(i, j) - val c = map.getLong(w) - map.put(w, c + 1) -/* - if (c == null) { - map.put(w, 1) - } else { - map.put(w, c + 1) - } -*/ - } - i = j - while (i < s.length && s.charAt(i) == ' ') { - i += 1 - } - } - } - map.toIterator.map{case (k, v) => (k, v)} - } -} - -object WordCount2 { - - def warmup(sc: SparkContext) { - (0 until 3).foreach {i => - sc.parallelize(1 to 20000000, 500) - .map(x => (x % 337, x % 1331)) - .reduceByKey(_ + _, 100) - .count() - } - } - - def main (args: Array[String]) { - - if (args.length != 6) { - println ("Usage: WordCount2 ") - System.exit(1) - } - - val Array(master, file, mapTasks, reduceTasks, batchMillis, chkptMillis) = args - - val batchDuration = Milliseconds(batchMillis.toLong) - - val ssc = new StreamingContext(master, "WordCount2") - ssc.setBatchDuration(batchDuration) - - //warmup(ssc.sc) - - val data = ssc.sc.textFile(file, mapTasks.toInt).persist( - new StorageLevel(false, true, false, 3)) // Memory only, serialized, 3 replicas - println("Data count: " + data.map(x => if (x == "") 1 else x.split(" ").size / x.split(" ").size).count()) - println("Data count: " + data.count()) - println("Data count: " + data.count()) - - val sentences = new ConstantInputDStream(ssc, data) - ssc.registerInputStream(sentences) - - import WordCount2_ExtraFunctions._ - - val windowedCounts = sentences - .mapPartitions(splitAndCountPartitions) - .reduceByKeyAndWindow(add _, subtract _, Seconds(30), batchDuration, reduceTasks.toInt) - - windowedCounts.persist().checkpoint(Milliseconds(chkptMillis.toLong)) - //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong)) - windowedCounts.foreachRDD(r => println("Element count: " + r.count())) - - ssc.start() - - while(true) { Thread.sleep(1000) } - } -} - - diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala index d1ea9a9cd5..571428c0fe 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -1,50 +1,43 @@ package spark.streaming.examples -import spark.util.IntParam -import spark.SparkContext -import spark.SparkContext._ import spark.storage.StorageLevel +import spark.util.IntParam + import spark.streaming._ import spark.streaming.StreamingContext._ +import spark.streaming.util.RawTextHelper._ -import WordCount2_ExtraFunctions._ +import java.util.UUID object WordCountRaw { - def moreWarmup(sc: SparkContext) { - (0 until 40).foreach {i => - sc.parallelize(1 to 20000000, 1000) - .map(_ % 1331).map(_.toString) - .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) - .collect() - } - } def main(args: Array[String]) { - if (args.length != 7) { - System.err.println("Usage: WordCountRaw ") + if (args.length != 4) { + System.err.println("Usage: WordCountRaw <# streams> ") System.exit(1) } - val Array(master, IntParam(streams), host, IntParam(port), IntParam(batchMs), - IntParam(chkptMs), IntParam(reduces)) = args + val Array(master, IntParam(numStreams), IntParam(port), checkpointDir) = args - // Create the context and set the batch size + // Create the context, set the batch size and checkpoint directory. + // Checkpoint directory is necessary for achieving fault-tolerance, by saving counts + // periodically to HDFS val ssc = new StreamingContext(master, "WordCountRaw") - ssc.setBatchDuration(Milliseconds(batchMs)) - - // Make sure some tasks have started on each node - moreWarmup(ssc.sc) - - val rawStreams = (1 to streams).map(_ => - ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray - val union = new UnionDStream(rawStreams) - - val windowedCounts = union.mapPartitions(splitAndCountPartitions) - .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) - windowedCounts.persist().checkpoint(chkptMs) - //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMs)) - - windowedCounts.foreachRDD(r => println("Element count: " + r.count())) + ssc.setBatchDuration(Seconds(1)) + ssc.checkpoint(checkpointDir + "/" + UUID.randomUUID.toString, Seconds(1)) + + // Warm up the JVMs on master and slave for JIT compilation to kick in + warmUp(ssc.sc) + + // Set up the raw network streams that will connect to localhost:port to raw test + // senders on the slaves and generate count of words of last 30 seconds + val lines = (1 to numStreams).map(_ => { + ssc.rawNetworkStream[String]("localhost", port, StorageLevel.MEMORY_ONLY_SER_2) + }) + val union = new UnionDStream(lines.toArray) + val counts = union.mapPartitions(splitAndCountPartitions) + val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) + windowedCounts.foreachRDD(r => println("# unique words = " + r.count())) ssc.start() } diff --git a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala deleted file mode 100644 index 6a9c8a9a69..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala +++ /dev/null @@ -1,75 +0,0 @@ -package spark.streaming.examples - -import spark.SparkContext -import SparkContext._ -import spark.streaming._ -import StreamingContext._ - -import spark.storage.StorageLevel - -import scala.util.Sorting -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.Queue -import scala.collection.JavaConversions.mapAsScalaMap - -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} - - -object WordMax2 { - - def warmup(sc: SparkContext) { - (0 until 10).foreach {i => - sc.parallelize(1 to 20000000, 1000) - .map(x => (x % 337, x % 1331)) - .reduceByKey(_ + _) - .count() - } - } - - def main (args: Array[String]) { - - if (args.length != 6) { - println ("Usage: WordMax2 ") - System.exit(1) - } - - val Array(master, file, mapTasks, reduceTasks, batchMillis, chkptMillis) = args - - val batchDuration = Milliseconds(batchMillis.toLong) - - val ssc = new StreamingContext(master, "WordMax2") - ssc.setBatchDuration(batchDuration) - - //warmup(ssc.sc) - - val data = ssc.sc.textFile(file, mapTasks.toInt).persist( - new StorageLevel(false, true, false, 3)) // Memory only, serialized, 3 replicas - println("Data count: " + data.count()) - println("Data count: " + data.count()) - println("Data count: " + data.count()) - - val sentences = new ConstantInputDStream(ssc, data) - ssc.registerInputStream(sentences) - - import WordCount2_ExtraFunctions._ - - val windowedCounts = sentences - .mapPartitions(splitAndCountPartitions) - .reduceByKey(add _, reduceTasks.toInt) - .persist() - .checkpoint(Milliseconds(chkptMillis.toLong)) - //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong)) - .reduceByKeyAndWindow(max _, Seconds(10), batchDuration, reduceTasks.toInt) - .persist() - .checkpoint(Milliseconds(chkptMillis.toLong)) - //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong)) - windowedCounts.foreachRDD(r => println("Element count: " + r.count())) - - ssc.start() - - while(true) { Thread.sleep(1000) } - } -} - - diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala new file mode 100644 index 0000000000..f31ae39a16 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala @@ -0,0 +1,98 @@ +package spark.streaming.util + +import spark.SparkContext +import spark.SparkContext._ +import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} +import scala.collection.JavaConversions.mapAsScalaMap + +object RawTextHelper { + + /** + * Splits lines and counts the words in them using specialized object-to-long hashmap + * (to avoid boxing-unboxing overhead of Long in java/scala HashMap) + */ + def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, Long)] = { + val map = new OLMap[String] + var i = 0 + var j = 0 + while (iter.hasNext) { + val s = iter.next() + i = 0 + while (i < s.length) { + j = i + while (j < s.length && s.charAt(j) != ' ') { + j += 1 + } + if (j > i) { + val w = s.substring(i, j) + val c = map.getLong(w) + map.put(w, c + 1) + } + i = j + while (i < s.length && s.charAt(i) == ' ') { + i += 1 + } + } + } + map.toIterator.map{case (k, v) => (k, v)} + } + + /** + * Gets the top k words in terms of word counts. Assumes that each word exists only once + * in the `data` iterator (that is, the counts have been reduced). + */ + def topK(data: Iterator[(String, Long)], k: Int): Iterator[(String, Long)] = { + val taken = new Array[(String, Long)](k) + + var i = 0 + var len = 0 + var done = false + var value: (String, Long) = null + var swap: (String, Long) = null + var count = 0 + + while(data.hasNext) { + value = data.next + if (value != null) { + count += 1 + if (len == 0) { + taken(0) = value + len = 1 + } else if (len < k || value._2 > taken(len - 1)._2) { + if (len < k) { + len += 1 + } + taken(len - 1) = value + i = len - 1 + while(i > 0 && taken(i - 1)._2 < taken(i)._2) { + swap = taken(i) + taken(i) = taken(i-1) + taken(i - 1) = swap + i -= 1 + } + } + } + } + return taken.toIterator + } + + /** + * Warms up the SparkContext in master and slave by running tasks to force JIT kick in + * before real workload starts. + */ + def warmUp(sc: SparkContext) { + for(i <- 0 to 4) { + sc.parallelize(1 to 200000, 1000) + .map(_ % 1331).map(_.toString) + .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) + .count() + } + } + + def add(v1: Long, v2: Long) = (v1 + v2) + + def subtract(v1: Long, v2: Long) = (v1 - v2) + + def max(v1: Long, v2: Long) = math.max(v1, v2) +} + -- cgit v1.2.3 From b9bfd1456f09f4db281fb9d108a339c59a2e2dda Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 12 Nov 2012 21:51:42 +0000 Subject: Changed default level on calling DStream.persist() to be MEMORY_ONLY_SER. Also changed the persist level of StateDStream to be MEMORY_ONLY_SER. --- streaming/src/main/scala/spark/streaming/DStream.scala | 2 +- streaming/src/main/scala/spark/streaming/StateDStream.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 73096edec5..abf132e45e 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -79,7 +79,7 @@ extends Serializable with Logging { this } - def persist(): DStream[T] = persist(StorageLevel.MEMORY_ONLY) + def persist(): DStream[T] = persist(StorageLevel.MEMORY_ONLY_SER) // Turn on the default caching level for this RDD def cache(): DStream[T] = persist() diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala index 0211df1343..cb261808f5 100644 --- a/streaming/src/main/scala/spark/streaming/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala @@ -23,7 +23,7 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife rememberPartitioner: Boolean ) extends DStream[(K, S)](parent.ssc) { - super.persist(StorageLevel.MEMORY_ONLY) + super.persist(StorageLevel.MEMORY_ONLY_SER) override def dependencies = List(parent) -- cgit v1.2.3 From 564dd8c3f415746a68f05bde6ea2a0e7a7760b4c Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 12 Nov 2012 14:22:05 -0800 Subject: Speeded up CheckpointSuite --- .../scala/spark/streaming/CheckpointSuite.scala | 26 ++++++++++++---------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index 038827ddb0..0ad57e38b9 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -15,12 +15,16 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { } after { + + if (ssc != null) ssc.stop() FileUtils.deleteDirectory(new File(checkpointDir)) } + var ssc: StreamingContext = null + override def framework = "CheckpointSuite" - override def batchDuration = Milliseconds(500) + override def batchDuration = Milliseconds(200) override def checkpointDir = "checkpoint" @@ -30,12 +34,12 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { test("basic stream+rdd recovery") { - assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second") + assert(batchDuration === Milliseconds(200), "batchDuration for this test must be 1 second") assert(checkpointInterval === batchDuration, "checkpointInterval for this test much be same as batchDuration") System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") - val stateStreamCheckpointInterval = Seconds(2) + val stateStreamCheckpointInterval = Seconds(1) // this ensure checkpointing occurs at least once val firstNumBatches = (stateStreamCheckpointInterval.millis / batchDuration.millis) * 2 @@ -110,6 +114,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { runStreamsWithRealDelay(ssc, 4) ssc.stop() System.clearProperty("spark.streaming.manualClock.jump") + ssc = null } test("map and reduceByKey") { @@ -131,9 +136,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { .reduceByKeyAndWindow(_ + _, _ - _, batchDuration * w, batchDuration) .checkpoint(Seconds(2)) } - for (i <- Seq(2, 3, 4)) { - testCheckpointedOperation(input, operation, output, i) - } + testCheckpointedOperation(input, operation, output, 3) } test("updateStateByKey") { @@ -148,9 +151,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { .checkpoint(Seconds(2)) .map(t => (t._1, t._2.self)) } - for (i <- Seq(2, 3, 4)) { - testCheckpointedOperation(input, operation, output, i) - } + testCheckpointedOperation(input, operation, output, 3) } @@ -171,7 +172,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // Do half the computation (half the number of batches), create checkpoint file and quit - val ssc = setupStreams[U, V](input, operation) + ssc = setupStreams[U, V](input, operation) val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs) verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) Thread.sleep(1000) @@ -182,9 +183,10 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { " Restarting stream computation " + "\n-------------------------------------------\n" ) - val sscNew = new StreamingContext(checkpointDir) - val outputNew = runStreams[V](sscNew, nextNumBatches, nextNumExpectedOutputs) + ssc = new StreamingContext(checkpointDir) + val outputNew = runStreams[V](ssc, nextNumBatches, nextNumExpectedOutputs) verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) + ssc = null } /** -- cgit v1.2.3 From 8a25d530edfa3abcdbe2effcd6bfbe484ac40acb Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 13 Nov 2012 02:16:28 -0800 Subject: Optimized checkpoint writing by reusing FileSystem object. Fixed bug in updating of checkpoint data in DStream where the checkpointed RDDs, upon recovery, were not recognized as checkpointed RDDs and therefore deleted from HDFS. Made InputStreamsSuite more robust to timing delays. --- core/src/main/scala/spark/RDD.scala | 6 +- .../main/scala/spark/streaming/Checkpoint.scala | 73 ++++++++++-------- .../src/main/scala/spark/streaming/DStream.scala | 8 +- .../src/main/scala/spark/streaming/Scheduler.scala | 28 +++++-- .../scala/spark/streaming/StreamingContext.scala | 10 +-- streaming/src/test/resources/log4j.properties | 2 +- .../scala/spark/streaming/CheckpointSuite.scala | 25 +++--- .../scala/spark/streaming/InputStreamsSuite.scala | 88 ++++++++++++---------- 8 files changed, 129 insertions(+), 111 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 63048d5df0..6af8c377b5 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -189,11 +189,7 @@ abstract class RDD[T: ClassManifest]( def getCheckpointData(): Any = { synchronized { - if (isCheckpointed) { - checkpointFile - } else { - null - } + checkpointFile } } diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index a70fb8f73a..770f7b0cc0 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -5,7 +5,7 @@ import spark.{Logging, Utils} import org.apache.hadoop.fs.{FileUtil, Path} import org.apache.hadoop.conf.Configuration -import java.io.{InputStream, ObjectStreamClass, ObjectInputStream, ObjectOutputStream} +import java.io._ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) @@ -18,8 +18,6 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val checkpointDir = ssc.checkpointDir val checkpointInterval = ssc.checkpointInterval - validate() - def validate() { assert(master != null, "Checkpoint.master is null") assert(framework != null, "Checkpoint.framework is null") @@ -27,35 +25,50 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) assert(checkpointTime != null, "Checkpoint.checkpointTime is null") logInfo("Checkpoint for time " + checkpointTime + " validated") } +} - def save(path: String) { - val file = new Path(path, "graph") - val conf = new Configuration() - val fs = file.getFileSystem(conf) - logDebug("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'") - if (fs.exists(file)) { - val bkFile = new Path(file.getParent, file.getName + ".bk") - FileUtil.copy(fs, file, fs, bkFile, true, true, conf) - logDebug("Moved existing checkpoint file to " + bkFile) +/** + * Convenience class to speed up the writing of graph checkpoint to file + */ +class CheckpointWriter(checkpointDir: String) extends Logging { + val file = new Path(checkpointDir, "graph") + val conf = new Configuration() + var fs = file.getFileSystem(conf) + val maxAttempts = 3 + + def write(checkpoint: Checkpoint) { + // TODO: maybe do this in a different thread from the main stream execution thread + var attempts = 0 + while (attempts < maxAttempts) { + attempts += 1 + try { + logDebug("Saving checkpoint for time " + checkpoint.checkpointTime + " to file '" + file + "'") + if (fs.exists(file)) { + val bkFile = new Path(file.getParent, file.getName + ".bk") + FileUtil.copy(fs, file, fs, bkFile, true, true, conf) + logDebug("Moved existing checkpoint file to " + bkFile) + } + val fos = fs.create(file) + val oos = new ObjectOutputStream(fos) + oos.writeObject(checkpoint) + oos.close() + logInfo("Checkpoint for time " + checkpoint.checkpointTime + " saved to file '" + file + "'") + fos.close() + return + } catch { + case ioe: IOException => + logWarning("Error writing checkpoint to file in " + attempts + " attempts", ioe) + } } - val fos = fs.create(file) - val oos = new ObjectOutputStream(fos) - oos.writeObject(this) - oos.close() - fs.close() - logInfo("Checkpoint of streaming context for time " + checkpointTime + " saved successfully to file '" + file + "'") - } - - def toBytes(): Array[Byte] = { - val bytes = Utils.serialize(this) - bytes + logError("Could not write checkpoint for time " + checkpoint.checkpointTime + " to file '" + file + "'") } } -object Checkpoint extends Logging { - def load(path: String): Checkpoint = { +object CheckpointReader extends Logging { + + def read(path: String): Checkpoint = { val fs = new Path(path).getFileSystem(new Configuration()) val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"), new Path(path), new Path(path + ".bk")) @@ -82,17 +95,11 @@ object Checkpoint extends Logging { logError("Error loading checkpoint from file '" + file + "'", e) } } else { - logWarning("Could not load checkpoint from file '" + file + "' as it does not exist") + logWarning("Could not read checkpoint from file '" + file + "' as it does not exist") } }) - throw new Exception("Could not load checkpoint from path '" + path + "'") - } - - def fromBytes(bytes: Array[Byte]): Checkpoint = { - val cp = Utils.deserialize[Checkpoint](bytes) - cp.validate() - cp + throw new Exception("Could not read checkpoint from path '" + path + "'") } } diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index abf132e45e..7e6f73dd7d 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -289,6 +289,7 @@ extends Serializable with Logging { */ protected[streaming] def updateCheckpointData(currentTime: Time) { logInfo("Updating checkpoint data for time " + currentTime) + // Get the checkpointed RDDs from the generated RDDs val newCheckpointData = generatedRDDs.filter(_._2.getCheckpointData() != null) .map(x => (x._1, x._2.getCheckpointData())) @@ -334,8 +335,11 @@ extends Serializable with Logging { logInfo("Restoring checkpoint data from " + checkpointData.size + " checkpointed RDDs") checkpointData.foreach { case(time, data) => { - logInfo("Restoring checkpointed RDD for time " + time + " from file") - generatedRDDs += ((time, ssc.sc.objectFile[T](data.toString))) + logInfo("Restoring checkpointed RDD for time " + time + " from file '" + data.toString + "'") + val rdd = ssc.sc.objectFile[T](data.toString) + // Set the checkpoint file name to identify this RDD as a checkpointed RDD by updateCheckpointData() + rdd.checkpointFile = data.toString + generatedRDDs += ((time, rdd)) } } dependencies.foreach(_.restoreCheckpointData()) diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index de0fb1f3ad..e2dca91179 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -16,8 +16,16 @@ extends Logging { initLogging() val graph = ssc.graph + val concurrentJobs = System.getProperty("spark.stream.concurrentJobs", "1").toInt val jobManager = new JobManager(ssc, concurrentJobs) + + val checkpointWriter = if (ssc.checkpointInterval != null && ssc.checkpointDir != null) { + new CheckpointWriter(ssc.checkpointDir) + } else { + null + } + val clockClass = System.getProperty("spark.streaming.clock", "spark.streaming.util.SystemClock") val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock] val timer = new RecurringTimer(clock, ssc.graph.batchDuration, generateRDDs(_)) @@ -52,19 +60,23 @@ extends Logging { logInfo("Scheduler stopped") } - def generateRDDs(time: Time) { + private def generateRDDs(time: Time) { SparkEnv.set(ssc.env) logInfo("\n-----------------------------------------------------\n") - graph.generateRDDs(time).foreach(submitJob) - logInfo("Generated RDDs for time " + time) + graph.generateRDDs(time).foreach(jobManager.runJob) graph.forgetOldRDDs(time) - if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) { - ssc.doCheckpoint(time) - } + doCheckpoint(time) + logInfo("Generated RDDs for time " + time) } - def submitJob(job: Job) { - jobManager.runJob(job) + private def doCheckpoint(time: Time) { + if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) { + val startTime = System.currentTimeMillis() + ssc.graph.updateCheckpointData(time) + checkpointWriter.write(new Checkpoint(ssc, time)) + val stopTime = System.currentTimeMillis() + logInfo("Checkpointing the graph took " + (stopTime - startTime) + " ms") + } } } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index ab6d6e8dea..ef6a05a392 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -28,7 +28,7 @@ final class StreamingContext ( def this(master: String, frameworkName: String, sparkHome: String = null, jars: Seq[String] = Nil) = this(new SparkContext(master, frameworkName, sparkHome, jars), null) - def this(path: String) = this(null, Checkpoint.load(path)) + def this(path: String) = this(null, CheckpointReader.read(path)) def this(cp_ : Checkpoint) = this(null, cp_) @@ -225,14 +225,6 @@ final class StreamingContext ( case e: Exception => logWarning("Error while stopping", e) } } - - def doCheckpoint(currentTime: Time) { - val startTime = System.currentTimeMillis() - graph.updateCheckpointData(currentTime) - new Checkpoint(this, currentTime).save(checkpointDir) - val stopTime = System.currentTimeMillis() - logInfo("Checkpointing the graph took " + (stopTime - startTime) + " ms") - } } diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index 02fe16866e..33774b463d 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -1,5 +1,5 @@ # Set everything to be logged to the console -log4j.rootCategory=WARN, console +log4j.rootCategory=INFO, console log4j.appender.console=org.apache.log4j.ConsoleAppender log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index 0ad57e38b9..b3afedf39f 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -24,7 +24,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { override def framework = "CheckpointSuite" - override def batchDuration = Milliseconds(200) + override def batchDuration = Milliseconds(500) override def checkpointDir = "checkpoint" @@ -34,7 +34,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { test("basic stream+rdd recovery") { - assert(batchDuration === Milliseconds(200), "batchDuration for this test must be 1 second") + assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second") assert(checkpointInterval === batchDuration, "checkpointInterval for this test much be same as batchDuration") System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") @@ -134,9 +134,9 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { val operation = (st: DStream[String]) => { st.map(x => (x, 1)) .reduceByKeyAndWindow(_ + _, _ - _, batchDuration * w, batchDuration) - .checkpoint(Seconds(2)) + .checkpoint(batchDuration * 2) } - testCheckpointedOperation(input, operation, output, 3) + testCheckpointedOperation(input, operation, output, 7) } test("updateStateByKey") { @@ -148,14 +148,18 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { } st.map(x => (x, 1)) .updateStateByKey[RichInt](updateFunc) - .checkpoint(Seconds(2)) + .checkpoint(batchDuration * 2) .map(t => (t._1, t._2.self)) } - testCheckpointedOperation(input, operation, output, 3) + testCheckpointedOperation(input, operation, output, 7) } - - + /** + * Tests a streaming operation under checkpointing, by restart the operation + * from checkpoint file and verifying whether the final output is correct. + * The output is assumed to have come from a reliable queue which an replay + * data as required. + */ def testCheckpointedOperation[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], operation: DStream[U] => DStream[V], @@ -170,8 +174,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { val initialNumExpectedOutputs = initialNumBatches val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs - // Do half the computation (half the number of batches), create checkpoint file and quit - + // Do the computation for initial number of batches, create checkpoint file and quit ssc = setupStreams[U, V](input, operation) val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs) verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) @@ -193,8 +196,6 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { * Advances the manual clock on the streaming scheduler by given number of batches. * It also wait for the expected amount of time for each batch. */ - - def runStreamsWithRealDelay(ssc: StreamingContext, numBatches: Long) { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] logInfo("Manual clock before advancing = " + clock.time) diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index 0957748603..3e99440226 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -16,24 +16,36 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + val testPort = 9999 + var testServer: TestServer = null + var testDir: File = null + override def checkpointDir = "checkpoint" after { FileUtils.deleteDirectory(new File(checkpointDir)) + if (testServer != null) { + testServer.stop() + testServer = null + } + if (testDir != null && testDir.exists()) { + FileUtils.deleteDirectory(testDir) + testDir = null + } } test("network input stream") { // Start the server - val serverPort = 9999 - val server = new TestServer(9999) - server.start() + testServer = new TestServer(testPort) + testServer.start() // Set up the streaming context and input streams val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) - val networkStream = ssc.networkTextStream("localhost", serverPort, StorageLevel.MEMORY_AND_DISK) + val networkStream = ssc.networkTextStream("localhost", testPort, StorageLevel.MEMORY_AND_DISK) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String ]] val outputStream = new TestOutputStream(networkStream, outputBuffer) + def output = outputBuffer.flatMap(x => x) ssc.registerOutputStream(outputStream) ssc.start() @@ -41,21 +53,15 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq(1, 2, 3, 4, 5) val expectedOutput = input.map(_.toString) + Thread.sleep(1000) for (i <- 0 until input.size) { - server.send(input(i).toString + "\n") + testServer.send(input(i).toString + "\n") Thread.sleep(500) clock.addToTime(batchDuration.milliseconds) } - val startTime = System.currentTimeMillis() - while (outputBuffer.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { - logInfo("output.size = " + outputBuffer.size + ", expectedOutput.size = " + expectedOutput.size) - Thread.sleep(100) - } Thread.sleep(1000) - val timeTaken = System.currentTimeMillis() - startTime - assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") logInfo("Stopping server") - server.stop() + testServer.stop() logInfo("Stopping context") ssc.stop() @@ -69,24 +75,24 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) logInfo("--------------------------------") - assert(outputBuffer.size === expectedOutput.size) - for (i <- 0 until outputBuffer.size) { - assert(outputBuffer(i).size === 1) - assert(outputBuffer(i).head === expectedOutput(i)) + // Verify whether all the elements received are as expected + // (whether the elements were received one in each interval is not verified) + assert(output.size === expectedOutput.size) + for (i <- 0 until output.size) { + assert(output(i) === expectedOutput(i)) } } test("network input stream with checkpoint") { // Start the server - val serverPort = 9999 - val server = new TestServer(9999) - server.start() + testServer = new TestServer(testPort) + testServer.start() // Set up the streaming context and input streams var ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) ssc.checkpoint(checkpointDir, checkpointInterval) - val networkStream = ssc.networkTextStream("localhost", serverPort, StorageLevel.MEMORY_AND_DISK) + val networkStream = ssc.networkTextStream("localhost", testPort, StorageLevel.MEMORY_AND_DISK) var outputStream = new TestOutputStream(networkStream, new ArrayBuffer[Seq[String]]) ssc.registerOutputStream(outputStream) ssc.start() @@ -94,7 +100,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Feed data to the server to send to the network receiver var clock = ssc.scheduler.clock.asInstanceOf[ManualClock] for (i <- Seq(1, 2, 3)) { - server.send(i.toString + "\n") + testServer.send(i.toString + "\n") Thread.sleep(100) clock.addToTime(batchDuration.milliseconds) } @@ -109,7 +115,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { ssc.start() clock = ssc.scheduler.clock.asInstanceOf[ManualClock] for (i <- Seq(4, 5, 6)) { - server.send(i.toString + "\n") + testServer.send(i.toString + "\n") Thread.sleep(100) clock.addToTime(batchDuration.milliseconds) } @@ -120,12 +126,12 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } test("file input stream") { + // Create a temporary directory - val dir = { + testDir = { var temp = File.createTempFile(".temp.", Random.nextInt().toString) temp.delete() temp.mkdirs() - temp.deleteOnExit() logInfo("Created temp dir " + temp) temp } @@ -133,10 +139,9 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Set up the streaming context and input streams val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) - val filestream = ssc.textFileStream(dir.toString) + val filestream = ssc.textFileStream(testDir.toString) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] def output = outputBuffer.flatMap(x => x) - val outputStream = new TestOutputStream(filestream, outputBuffer) ssc.registerOutputStream(outputStream) ssc.start() @@ -147,16 +152,16 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val expectedOutput = input.map(_.toString) Thread.sleep(1000) for (i <- 0 until input.size) { - FileUtils.writeStringToFile(new File(dir, i.toString), input(i).toString + "\n") - Thread.sleep(100) + FileUtils.writeStringToFile(new File(testDir, i.toString), input(i).toString + "\n") + Thread.sleep(500) clock.addToTime(batchDuration.milliseconds) - Thread.sleep(100) + //Thread.sleep(100) } val startTime = System.currentTimeMillis() - while (output.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { - //println("output.size = " + output.size + ", expectedOutput.size = " + expectedOutput.size) + /*while (output.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + logInfo("output.size = " + output.size + ", expectedOutput.size = " + expectedOutput.size) Thread.sleep(100) - } + }*/ Thread.sleep(1000) val timeTaken = System.currentTimeMillis() - startTime assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") @@ -165,14 +170,16 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Verify whether data received by Spark Streaming was as expected logInfo("--------------------------------") - logInfo("output.size = " + output.size) + logInfo("output.size = " + outputBuffer.size) logInfo("output") - output.foreach(x => logInfo("[" + x.mkString(",") + "]")) + outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) logInfo("expected output.size = " + expectedOutput.size) logInfo("expected output") expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) logInfo("--------------------------------") + // Verify whether all the elements received are as expected + // (whether the elements were received one in each interval is not verified) assert(output.size === expectedOutput.size) for (i <- 0 until output.size) { assert(output(i).size === 1) @@ -182,12 +189,11 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { test("file input stream with checkpoint") { // Create a temporary directory - val dir = { + testDir = { var temp = File.createTempFile(".temp.", Random.nextInt().toString) temp.delete() temp.mkdirs() - temp.deleteOnExit() - println("Created temp dir " + temp) + logInfo("Created temp dir " + temp) temp } @@ -195,7 +201,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { var ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) ssc.checkpoint(checkpointDir, checkpointInterval) - val filestream = ssc.textFileStream(dir.toString) + val filestream = ssc.textFileStream(testDir.toString) var outputStream = new TestOutputStream(filestream, new ArrayBuffer[Seq[String]]) ssc.registerOutputStream(outputStream) ssc.start() @@ -204,7 +210,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { var clock = ssc.scheduler.clock.asInstanceOf[ManualClock] Thread.sleep(1000) for (i <- Seq(1, 2, 3)) { - FileUtils.writeStringToFile(new File(dir, i.toString), i.toString + "\n") + FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") Thread.sleep(100) clock.addToTime(batchDuration.milliseconds) } @@ -221,7 +227,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { clock = ssc.scheduler.clock.asInstanceOf[ManualClock] Thread.sleep(500) for (i <- Seq(4, 5, 6)) { - FileUtils.writeStringToFile(new File(dir, i.toString), i.toString + "\n") + FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") Thread.sleep(100) clock.addToTime(batchDuration.milliseconds) } -- cgit v1.2.3 From c3ccd14cf8d7c5a867992758b74922890408541e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 13 Nov 2012 02:43:03 -0800 Subject: Replaced StateRDD in StateDStream with MapPartitionsRDD. --- .../src/main/scala/spark/streaming/StateDStream.scala | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala index cb261808f5..b7e4c1c30c 100644 --- a/streaming/src/main/scala/spark/streaming/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala @@ -7,20 +7,11 @@ import spark.rdd.MapPartitionsRDD import spark.SparkContext._ import spark.storage.StorageLevel - -class StateRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], - f: Iterator[T] => Iterator[U], - rememberPartitioner: Boolean - ) extends MapPartitionsRDD[U, T](prev, f) { - override val partitioner = if (rememberPartitioner) prev.partitioner else None -} - class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( parent: DStream[(K, V)], updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], partitioner: Partitioner, - rememberPartitioner: Boolean + preservePartitioning: Boolean ) extends DStream[(K, S)](parent.ssc) { super.persist(StorageLevel.MEMORY_ONLY_SER) @@ -53,7 +44,7 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife updateFuncLocal(i) } val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) - val stateRDD = new StateRDD(cogroupedRDD, finalFunc, rememberPartitioner) + val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning) //logDebug("Generating state RDD for time " + validTime) return Some(stateRDD) } @@ -78,7 +69,7 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife } val groupedRDD = parentRDD.groupByKey(partitioner) - val sessionRDD = new StateRDD(groupedRDD, finalFunc, rememberPartitioner) + val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning) //logDebug("Generating state RDD for time " + validTime + " (first)") return Some(sessionRDD) } -- cgit v1.2.3 From 26fec8f0b850e7eb0b6cfe63770f2e68cd50441b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 13 Nov 2012 11:05:57 -0800 Subject: Fixed bug in MappedValuesRDD, and set default graph checkpoint interval to be batch duration. --- .../src/main/scala/spark/streaming/DStream.scala | 2 +- .../spark/streaming/ReducedWindowedDStream.scala | 2 +- .../scala/spark/streaming/StreamingContext.scala | 23 +++++++++++++++++++--- 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 7e6f73dd7d..76cdf8c464 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -230,7 +230,7 @@ extends Serializable with Logging { } if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { newRDD.checkpoint() - logInfo("Marking RDD for time " + time + " for checkpointing at time " + time) + logInfo("Marking RDD " + newRDD + " for time " + time + " for checkpointing at time " + time) } generatedRDDs.put(time, newRDD) Some(newRDD) diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala index b07d51fa6b..8b484e6acf 100644 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -57,7 +57,7 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( override def checkpoint(interval: Time): DStream[(K, V)] = { super.checkpoint(interval) - reducedStream.checkpoint(interval) + //reducedStream.checkpoint(interval) this } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index ef6a05a392..7a9a71f303 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -85,7 +85,7 @@ final class StreamingContext ( graph.setRememberDuration(duration) } - def checkpoint(dir: String, interval: Time) { + def checkpoint(dir: String, interval: Time = null) { if (dir != null) { sc.setCheckpointDir(StreamingContext.getSparkCheckpointDir(dir)) checkpointDir = dir @@ -186,12 +186,29 @@ final class StreamingContext ( graph.addOutputStream(outputStream) } + def validate() { + assert(graph != null, "Graph is null") + graph.validate() + + assert( + checkpointDir == null || checkpointInterval != null, + "Checkpoint directory has been set, but the graph checkpointing interval has " + + "not been set. Please use StreamingContext.checkpoint() to set the interval." + ) + + + } + + /** * This function starts the execution of the streams. */ def start() { - assert(graph != null, "Graph is null") - graph.validate() + if (checkpointDir != null && checkpointInterval == null && graph != null) { + checkpointInterval = graph.batchDuration + } + + validate() val networkInputStreams = graph.getInputStreams().filter(s => s match { case n: NetworkInputDStream[_] => true -- cgit v1.2.3 From b6f7ba813e93916dad9dbb0f06819362a5fb7cf7 Mon Sep 17 00:00:00 2001 From: Denny Date: Tue, 13 Nov 2012 13:15:32 -0800 Subject: change import for example function --- streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala index 1e92cbb210..c85ac8e984 100644 --- a/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala +++ b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala @@ -3,7 +3,7 @@ package spark.streaming.examples import spark.streaming._ import spark.streaming.StreamingContext._ import spark.storage.StorageLevel -import WordCount2_ExtraFunctions._ +import spark.streaming.util.RawTextHelper._ object KafkaWordCount { def main(args: Array[String]) { -- cgit v1.2.3 From d39ac5fbc1b374b420f0a72125b35ea047b8cabb Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 13 Nov 2012 21:11:55 -0800 Subject: Streaming programming guide. STREAMING-2 #resolve --- docs/_layouts/global.html | 1 + docs/index.md | 1 + docs/streaming-programming-guide.md | 163 ++++++++++++++++++++++++++++++++++++ 3 files changed, 165 insertions(+) create mode 100644 docs/streaming-programming-guide.md diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 7244ab6fc9..d656b3e3de 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -47,6 +47,7 @@
  • Quick Start
  • Scala
  • Java
  • +
  • Spark Streaming (Alpha)
  • diff --git a/docs/index.md b/docs/index.md index ed9953a590..560811ade8 100644 --- a/docs/index.md +++ b/docs/index.md @@ -59,6 +59,7 @@ of `project/SparkBuild.scala`, then rebuilding Spark (`sbt/sbt clean compile`). * [Quick Start](quick-start.html): a quick introduction to the Spark API; start here! * [Spark Programming Guide](scala-programming-guide.html): an overview of Spark concepts, and details on the Scala API * [Java Programming Guide](java-programming-guide.html): using Spark from Java +* [Streaming Guide](streaming-programming-guide.html): an API preview of Spark Streaming **Deployment guides:** diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md new file mode 100644 index 0000000000..90916545bc --- /dev/null +++ b/docs/streaming-programming-guide.md @@ -0,0 +1,163 @@ +--- +layout: global +title: Streaming (Alpha) Programming Guide +--- +# Initializing Spark Streaming +The first thing a Spark Streaming program must do is create a `StreamingContext` object, which tells Spark how to access a cluster. A `StreamingContext` can be created from an existing `SparkContext`, or directly: + +{% highlight scala %} +new StreamingContext(master, jobName, [sparkHome], [jars]) +new StreamingContext(sparkContext) +{% endhighlight %} + +Once a context is instantiated, the batch interval must be set: + +{% highlight scala %} +context.setBatchDuration(Milliseconds(2000)) +{% endhighlight %} + + +# DStreams - Discretized Streams +The primary abstraction in Spark Streaming is a DStream. A DStream represents distributed collection which is computed periodically according to a specified batch interval. DStream's can be chained together to create complex chains of transformation on streaming data. DStreams can be created by operating on existing DStreams or from an input source. To creating DStreams from an input source, use the StreamingContext: + +{% highlight scala %} +context.neworkStream(host, port) // A stream that reads from a socket +context.flumeStream(hosts, ports) // A stream populated by a Flume flow +{% endhighlight %} + +# DStream Operators +Once an input stream has been created, you can transform it using _stream operators_. Most of these operators return new DStreams which you can further transform. Eventually, you'll need to call an _output operator_, which forces evaluation of the stream by writing data out to an external source. + +## Transformations + +DStreams support many of the transformations available on normal Spark RDD's: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    TransformationMeaning
    map(func) Return a new stream formed by passing each element of the source through a function func.
    filter(func) Return a new stream formed by selecting those elements of the source on which func returns true.
    flatMap(func) Similar to map, but each input item can be mapped to 0 or more output items (so func should return a Seq rather than a single item).
    mapPartitions(func) Similar to map, but runs separately on each partition (block) of the DStream, so func must be of type + Iterator[T] => Iterator[U] when running on an DStream of type T.
    union(otherStream) Return a new stream that contains the union of the elements in the source stream and the argument.
    groupByKey([numTasks]) When called on a stream of (K, V) pairs, returns a stream of (K, Seq[V]) pairs.
    +Note: By default, this uses only 8 parallel tasks to do the grouping. You can pass an optional numTasks argument to set a different number of tasks. +
    reduceByKey(func, [numTasks]) When called on a stream of (K, V) pairs, returns a stream of (K, V) pairs where the values for each key are aggregated using the given reduce function. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument.
    join(otherStream, [numTasks]) When called on streams of type (K, V) and (K, W), returns a stream of (K, (V, W)) pairs with all pairs of elements for each key.
    cogroup(otherStream, [numTasks]) When called on streams of type (K, V) and (K, W), returns a stream of (K, Seq[V], Seq[W]) tuples. This operation is also called groupWith.
    + +DStreams also support the following additional transformations: + + + + + + +
    reduce(func) Create a new single-element stream by aggregating the elements of the stream using a function func (which takes two arguments and returns one). The function should be associative so that it can be computed correctly in parallel.
    + + +## Windowed Transformations +Spark streaming features windowed computations, which allow you to report statistics over a sliding window of data. All window functions take a windowTime, which represents the width of the window and a slideTime, which represents the frequency during which the window is calculated. + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    TransformationMeaning
    window(windowTime, slideTime) Return a new stream which is computed based on windowed batches of the source stream. windowTime is the width of the window and slideTime is the frequency during which the window is calculated. Both times must be multiples of the batch interval. +
    countByWindow(windowTime, slideTime) Return a sliding count of elements in the stream. windowTime and slideTime are exactly as defined in window(). +
    reduceByWindow(func, windowTime, slideTime) Return a new single-element stream, created by aggregating elements in the stream over a sliding interval using func. The function should be associative so that it can be computed correctly in parallel. windowTime and slideTime are exactly as defined in window(). +
    groupByKeyAndWindow(windowTime, slideTime, [numTasks]) + When called on a stream of (K, V) pairs, returns a stream of (K, Seq[V]) pairs over a sliding window.
    +Note: By default, this uses only 8 parallel tasks to do the grouping. You can pass an optional numTasks argument to set a different number of tasks. windowTime and slideTime are exactly as defined in window(). +
    reduceByKeyAndWindow(func, [numTasks]) When called on a stream of (K, V) pairs, returns a stream of (K, V) pairs where the values for each key are aggregated using the given reduce function over batches within a sliding window. Like in groupByKeyAndWindow, the number of reduce tasks is configurable through an optional second argument. + windowTime and slideTime are exactly as defined in window(). +
    countByKeyAndWindow([numTasks]) When called on a stream of (K, V) pairs, returns a stream of (K, Int) pairs where the values for each key are the count within a sliding window. Like in countByKeyAndWindow, the number of reduce tasks is configurable through an optional second argument. + windowTime and slideTime are exactly as defined in window(). +
    + + +## Output Operators +When an output operator is called, it triggers the computation of a stream. Currently the following output operators are defined: + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    OperatorMeaning
    foreachRDD(func) The fundamental output operator. Applies a function, func, to each RDD generated from the stream. This function should have side effects, such as printing output, saving the RDD to external files, or writing it over the network to an external system.
    print() Prints the contents of this DStream on the driver. At each interval, this will take at most ten elements from the DStream's RDD and print them.
    saveAsObjectFile(prefix, [suffix]) Save this DStream's contents as a SequenceFile of serialized objects. The file name at each batch interval is calculated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]". +
    saveAsTextFile(prefix, suffix) Save this DStream's contents as a text files. The file name at each batch interval is calculated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
    saveAsHadoopFiles(prefix, suffix) Save this DStream's contents as a Hadoop file. The file name at each batch interval is calculated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
    + -- cgit v1.2.3 From 720cb0f46736105d200128f13081489281dbe118 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 14 Nov 2012 18:08:02 -0800 Subject: A "streaming page view" example. --- .../examples/clickstream/PageViewGenerator.scala | 85 ++++++++++++++++++++++ .../examples/clickstream/PageViewStream.scala | 85 ++++++++++++++++++++++ 2 files changed, 170 insertions(+) create mode 100644 streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala diff --git a/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala b/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala new file mode 100644 index 0000000000..4c6e08bc74 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala @@ -0,0 +1,85 @@ +package spark.streaming.examples.clickstream + +import java.net.{InetAddress,ServerSocket,Socket,SocketException} +import java.io.{InputStreamReader, BufferedReader, PrintWriter} +import util.Random + +/** Represents a page view on a website with associated dimension data.*/ +class PageView(val url : String, val status : Int, val zipCode : Int, val userID : Int) { + override def toString() : String = { + "%s\t%s\t%s\t%s\n".format(url, status, zipCode, userID) + } +} +object PageView { + def fromString(in : String) : PageView = { + val parts = in.split("\t") + new PageView(parts(0), parts(1).toInt, parts(2).toInt, parts(3).toInt) + } +} + +/** Generates streaming events to simulate page views on a website. + * + * This should be used in tandem with PageViewStream.scala. Example: + * $ ./run spark.streaming.examples.clickstream.PageViewGenerator 44444 10 + * $ ./run spark.streaming.examples.clickstream.PageViewStream errorRatePerZipCode localhost 44444 + * */ +object PageViewGenerator { + val pages = Map("http://foo.com/" -> .7, + "http://foo.com/news" -> 0.2, + "http://foo.com/contact" -> .1) + val httpStatus = Map(200 -> .95, + 404 -> .05) + val userZipCode = Map(94709 -> .5, + 94117 -> .5) + val userID = Map((1 to 100).map(_ -> .01):_*) + + + def pickFromDistribution[T](inputMap : Map[T, Double]) : T = { + val rand = new Random().nextDouble() + var total = 0.0 + for ((item, prob) <- inputMap) { + total = total + prob + if (total > rand) { + return item + } + } + return inputMap.take(1).head._1 // Shouldn't get here if probabilities add up to 1.0 + } + + def getNextClickEvent() : String = { + val id = pickFromDistribution(userID) + val page = pickFromDistribution(pages) + val status = pickFromDistribution(httpStatus) + val zipCode = pickFromDistribution(userZipCode) + new PageView(page, status, zipCode, id).toString() + } + + def main(args : Array[String]) { + if (args.length != 2) { + System.err.println("Usage: PageViewGenerator ") + System.exit(1) + } + val port = args(0).toInt + val viewsPerSecond = args(1).toFloat + val sleepDelayMs = (1000.0 / viewsPerSecond).toInt + val listener = new ServerSocket(port) + println("Listening on port: " + port) + + while (true) { + val socket = listener.accept() + new Thread() { + override def run = { + println("Got client connected from: " + socket.getInetAddress) + val out = new PrintWriter(socket.getOutputStream(), true) + + while (true) { + Thread.sleep(sleepDelayMs) + out.write(getNextClickEvent()) + out.flush() + } + socket.close() + } + }.start() + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala b/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala new file mode 100644 index 0000000000..1a51fb66cd --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala @@ -0,0 +1,85 @@ +package spark.streaming.examples.clickstream + +import spark.streaming.{Seconds, StreamingContext} +import spark.streaming.StreamingContext._ +import spark.SparkContext._ + +/** Analyses a streaming dataset of web page views. This class demonstrates several types of + * operators available in Spark streaming. + * + * This should be used in tandem with PageViewStream.scala. Example: + * $ ./run spark.streaming.examples.clickstream.PageViewGenerator 44444 10 + * $ ./run spark.streaming.examples.clickstream.PageViewStream errorRatePerZipCode localhost 44444 + * */ +object PageViewStream { + def main(args: Array[String]) { + if (args.length != 3) { + System.err.println("Usage: PageViewStream ") + System.err.println(" must be one of pageCounts, slidingPageCounts," + + " errorRatePerZipCode, activeUserCount, popularUsersSeen") + System.exit(1) + } + val metric = args(0) + val host = args(1) + val port = args(2).toInt + + // Create the context and set the batch size + val ssc = new StreamingContext("local[2]", "PageViewStream") + ssc.setBatchDuration(Seconds(1)) + + // Create a NetworkInputDStream on target host:port and convert each line to a PageView + val pageViews = ssc.networkTextStream(host, port) + .flatMap(_.split("\n")) + .map(PageView.fromString(_)) + + // Return a count of views per URL seen in each batch + val pageCounts = pageViews.map(view => ((view.url, 1))).countByKey() + + // Return a sliding window of page views per URL in the last ten seconds + val slidingPageCounts = pageViews.map(view => ((view.url, 1))) + .window(Seconds(10), Seconds(2)) + .countByKey() + + + // Return the rate of error pages (a non 200 status) in each zip code over the last 30 seconds + val statusesPerZipCode = pageViews.window(Seconds(30), Seconds(2)) + .map(view => ((view.zipCode, view.status))) + .groupByKey() + val errorRatePerZipCode = statusesPerZipCode.map{ + case(zip, statuses) => + val normalCount = statuses.filter(_ == 200).size + val errorCount = statuses.size - normalCount + val errorRatio = errorCount.toFloat / statuses.size + if (errorRatio > 0.05) {"%s: **%s**".format(zip, errorRatio)} + else {"%s: %s".format(zip, errorRatio)} + } + + // Return the number unique users in last 15 seconds + val activeUserCount = pageViews.window(Seconds(15), Seconds(2)) + .map(view => (view.userID, 1)) + .groupByKey() + .count() + .map("Unique active users: " + _) + + // An external dataset we want to join to this stream + val userList = ssc.sc.parallelize( + Map(1 -> "Patrick Wendell", 2->"Reynold Xin", 3->"Matei Zaharia").toSeq) + + metric match { + case "pageCounts" => pageCounts.print() + case "slidingPageCounts" => slidingPageCounts.print() + case "errorRatePerZipCode" => errorRatePerZipCode.print() + case "activeUserCount" => activeUserCount.print() + case "popularUsersSeen" => + // Look for users in our existing dataset and print it out if we have a match + pageViews.map(view => (view.userID, 1)) + .foreachRDD((rdd, time) => rdd.join(userList) + .map(_._2._2) + .take(10) + .foreach(u => println("Saw user %s at time %s".format(u, time)))) + case _ => println("Invalid metric entered: " + metric) + } + + ssc.start() + } +} -- cgit v1.2.3 From 10c1abcb6ac42b248818fa585a9ad49c2fa4851a Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 17 Nov 2012 17:27:00 -0800 Subject: Fixed checkpointing bug in CoGroupedRDD. CoGroupSplits kept around the RDD splits of its parent RDDs, thus checkpointing its parents did not release the references to the parent splits. --- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 18 ++++++++++---- core/src/test/scala/spark/CheckpointSuite.scala | 28 ++++++++++++++++++++++ .../src/main/scala/spark/streaming/DStream.scala | 4 ++-- .../main/scala/spark/streaming/DStreamGraph.scala | 2 +- 4 files changed, 45 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index a313ebcbe8..94ef1b56e8 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -12,9 +12,20 @@ import spark.RDD import spark.ShuffleDependency import spark.SparkEnv import spark.Split +import java.io.{ObjectOutputStream, IOException} private[spark] sealed trait CoGroupSplitDep extends Serializable -private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep +private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], splitIndex: Int, var split: Split = null) + extends CoGroupSplitDep { + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + rdd.synchronized { + // Update the reference to parent split at the time of task serialization + split = rdd.splits(splitIndex) + oos.defaultWriteObject() + } + } +} private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep private[spark] @@ -55,7 +66,6 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) @transient var splits_ : Array[Split] = { - val firstRdd = rdds.head val array = new Array[Split](part.numPartitions) for (i <- 0 until array.size) { array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) => @@ -63,7 +73,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) case s: ShuffleDependency[_, _] => new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep case _ => - new NarrowCoGroupSplitDep(r, r.splits(i)): CoGroupSplitDep + new NarrowCoGroupSplitDep(r, i): CoGroupSplitDep } }.toList) } @@ -82,7 +92,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any])) } for ((dep, depNum) <- split.deps.zipWithIndex) dep match { - case NarrowCoGroupSplitDep(rdd, itsSplit) => { + case NarrowCoGroupSplitDep(rdd, itsSplitIndex, itsSplit) => { // Read them from the parent for ((k, v) <- rdd.iterator(itsSplit)) { getSeq(k.asInstanceOf[K])(depNum) += v diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 57dc43ddac..8622ce92aa 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -6,6 +6,7 @@ import rdd.{BlockRDD, CoalescedRDD, MapPartitionsWithSplitRDD} import spark.SparkContext._ import storage.StorageLevel import java.util.concurrent.Semaphore +import collection.mutable.ArrayBuffer class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { initLogging() @@ -92,6 +93,33 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val rdd2 = sc.makeRDD(5 to 6, 4).map(x => (x % 2, 1)) testCheckpointing(rdd1 => rdd1.map(x => (x % 2, 1)).cogroup(rdd2)) testCheckpointing(rdd1 => rdd1.map(x => (x % 2, x)).join(rdd2)) + + // Special test to make sure that the CoGroupSplit of CoGroupedRDD do not + // hold on to the splits of its parent RDDs, as the splits of parent RDDs + // may change while checkpointing. Rather the splits of parent RDDs must + // be fetched at the time of serialization to ensure the latest splits to + // be sent along with the task. + + val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _) + + val ones = sc.parallelize(1 to 100, 1).map(x => (x,1)) + val reduced = ones.reduceByKey(_ + _) + val seqOfCogrouped = new ArrayBuffer[RDD[(Int, Int)]]() + seqOfCogrouped += reduced.cogroup(ones).mapValues[Int](add) + for(i <- 1 to 10) { + seqOfCogrouped += seqOfCogrouped.last.cogroup(ones).mapValues(add) + } + val finalCogrouped = seqOfCogrouped.last + val intermediateCogrouped = seqOfCogrouped(5) + + val bytesBeforeCheckpoint = Utils.serialize(finalCogrouped.splits) + intermediateCogrouped.checkpoint() + finalCogrouped.count() + sleep(intermediateCogrouped) + val bytesAfterCheckpoint = Utils.serialize(finalCogrouped.splits) + println("Before = " + bytesBeforeCheckpoint.size + ", after = " + bytesAfterCheckpoint.size) + assert(bytesAfterCheckpoint.size < bytesBeforeCheckpoint.size, + "CoGroupedSplits still holds on to the splits of its parent RDDs") } /** diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 76cdf8c464..13770aa8fd 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -226,11 +226,11 @@ extends Serializable with Logging { case Some(newRDD) => if (storageLevel != StorageLevel.NONE) { newRDD.persist(storageLevel) - logInfo("Persisting RDD for time " + time + " to " + storageLevel + " at time " + time) + logInfo("Persisting RDD " + newRDD.id + " for time " + time + " to " + storageLevel + " at time " + time) } if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { newRDD.checkpoint() - logInfo("Marking RDD " + newRDD + " for time " + time + " for checkpointing at time " + time) + logInfo("Marking RDD " + newRDD.id + " for time " + time + " for checkpointing at time " + time) } generatedRDDs.put(time, newRDD) Some(newRDD) diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index 246522838a..bd8c033eab 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -105,7 +105,7 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { private[streaming] def validate() { this.synchronized { assert(batchDuration != null, "Batch duration has not been set") - assert(batchDuration > Milliseconds(100), "Batch duration of " + batchDuration + " is very low") + //assert(batchDuration >= Milliseconds(100), "Batch duration of " + batchDuration + " is very low") assert(getOutputStreams().size > 0, "No output streams registered, so nothing to execute") } } -- cgit v1.2.3 From 6757ed6a40121ee97a15506af8717bb8d97cf1ec Mon Sep 17 00:00:00 2001 From: Denny Date: Mon, 19 Nov 2012 09:42:35 -0800 Subject: Comment out code for fault-tolerance. --- .../spark/streaming/input/KafkaInputDStream.scala | 35 +++++++++++----------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala index 318537532c..3685d6c666 100644 --- a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala @@ -15,8 +15,10 @@ import spark.storage.StorageLevel // Key for a specific Kafka Partition: (broker, topic, group, part) case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int) +// NOT USED - Originally intended for fault-tolerance // Metadata for a Kafka Stream that it sent to the Master case class KafkaInputDStreamMetadata(timestamp: Long, data: Map[KafkaPartitionKey, Long]) +// NOT USED - Originally intended for fault-tolerance // Checkpoint data specific to a KafkaInputDstream case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any], savedOffsets: Map[KafkaPartitionKey, Long]) extends DStreamCheckpointData(kafkaRdds) @@ -45,9 +47,13 @@ class KafkaInputDStream[T: ClassManifest]( // Metadata that keeps track of which messages have already been consumed. var savedOffsets = HashMap[Long, Map[KafkaPartitionKey, Long]]() + + /* NOT USED - Originally intended for fault-tolerance + // In case of a failure, the offets for a particular timestamp will be restored. @transient var restoredOffsets : Map[KafkaPartitionKey, Long] = null + override protected[streaming] def addMetadata(metadata: Any) { metadata match { case x : KafkaInputDStreamMetadata => @@ -80,18 +86,11 @@ class KafkaInputDStream[T: ClassManifest]( restoredOffsets = x.savedOffsets logInfo("Restored KafkaDStream offsets: " + savedOffsets) } - } + } */ def createReceiver(): NetworkReceiver[T] = { - // We have restored from a checkpoint, use the restored offsets - if (restoredOffsets != null) { - new KafkaReceiver(id, host, port, groupId, topics, restoredOffsets, storageLevel) - .asInstanceOf[NetworkReceiver[T]] - } else { - new KafkaReceiver(id, host, port, groupId, topics, initialOffsets, storageLevel) + new KafkaReceiver(id, host, port, groupId, topics, initialOffsets, storageLevel) .asInstanceOf[NetworkReceiver[T]] - } - } } @@ -103,7 +102,7 @@ class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, val ZK_TIMEOUT = 10000 // Handles pushing data into the BlockManager - lazy protected val dataHandler = new KafkaDataHandler(this, storageLevel) + lazy protected val dataHandler = new DataHandler(this, storageLevel) // Keeps track of the current offsets. Maps from (broker, topic, group, part) -> Offset lazy val offsets = HashMap[KafkaPartitionKey, Long]() // Connection to Kafka @@ -181,13 +180,15 @@ class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, } } - class KafkaDataHandler(receiver: KafkaReceiver, storageLevel: StorageLevel) - extends DataHandler[Any](receiver, storageLevel) { + // NOT USED - Originally intended for fault-tolerance + // class KafkaDataHandler(receiver: KafkaReceiver, storageLevel: StorageLevel) + // extends DataHandler[Any](receiver, storageLevel) { - override def createBlock(blockId: String, iterator: Iterator[Any]) : Block = { - // Creates a new Block with Kafka-specific Metadata - new Block(blockId, iterator, KafkaInputDStreamMetadata(System.currentTimeMillis, offsets.toMap)) - } + // override def createBlock(blockId: String, iterator: Iterator[Any]) : Block = { + // // Creates a new Block with Kafka-specific Metadata + // new Block(blockId, iterator, KafkaInputDStreamMetadata(System.currentTimeMillis, offsets.toMap)) + // } - } + // } + } -- cgit v1.2.3 From 5e2b0a3bf60dead1ac7946c9984b067c926c2904 Mon Sep 17 00:00:00 2001 From: Denny Date: Mon, 19 Nov 2012 10:17:58 -0800 Subject: Added Kafka Wordcount producer --- .../spark/streaming/examples/KafkaWordCount.scala | 72 +++++++++++++++------- .../spark/streaming/input/KafkaInputDStream.scala | 5 +- 2 files changed, 52 insertions(+), 25 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala index c85ac8e984..12e3f49fe9 100644 --- a/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala +++ b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala @@ -1,5 +1,9 @@ package spark.streaming.examples +import java.util.Properties +import kafka.message.Message +import kafka.producer.SyncProducerConfig +import kafka.producer._ import spark.streaming._ import spark.streaming.StreamingContext._ import spark.storage.StorageLevel @@ -8,33 +12,57 @@ import spark.streaming.util.RawTextHelper._ object KafkaWordCount { def main(args: Array[String]) { - if (args.length < 4) { - System.err.println("Usage: KafkaWordCount ") + if (args.length < 6) { + System.err.println("Usage: KafkaWordCount ") System.exit(1) } - val ssc = args(3) match { - // Restore the stream from a checkpoint - case "true" => - new StreamingContext("work/checkpoint") - case _ => - val tmp = new StreamingContext(args(0), "KafkaWordCount") - - tmp.setBatchDuration(Seconds(2)) - tmp.checkpoint("work/checkpoint", Seconds(10)) - - val lines = tmp.kafkaStream[String](args(1), args(2).toInt, "test_group", Map("test" -> 1), - Map(KafkaPartitionKey(0,"test","test_group",0) -> 0l)) - val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2) - - wordCounts.persist().checkpoint(Seconds(10)) - wordCounts.print() - - tmp - } + val Array(master, hostname, port, group, topics, numThreads) = args + + val ssc = new StreamingContext(master, "KafkaWordCount") + ssc.checkpoint("checkpoint") + ssc.setBatchDuration(Seconds(2)) + + val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap + val lines = ssc.kafkaStream[String](hostname, port.toInt, group, topicpMap) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2) + wordCounts.print() + ssc.start() + } +} + +// Produces some random words between 1 and 100. +object KafkaWordCountProducer { + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: KafkaWordCountProducer ") + System.exit(1) + } + + val Array(hostname, port, topic, messagesPerSec, wordsPerMessage) = args + // Zookeper connection properties + val props = new Properties() + props.put("zk.connect", hostname + ":" + port) + props.put("serializer.class", "kafka.serializer.StringEncoder") + + val config = new ProducerConfig(props) + val producer = new Producer[String, String](config) + + // Send some messages + while(true) { + val messages = (1 to messagesPerSec.toInt).map { messageNum => + (1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString).mkString(" ") + }.toArray + println(messages.mkString(",")) + val data = new ProducerData[String, String](topic, messages) + producer.send(data) + Thread.sleep(100) + } } + } diff --git a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala index 3685d6c666..7c642d4802 100644 --- a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala @@ -171,8 +171,7 @@ class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, groupId, msgAndMetadata.topicInfo.partition.partId) val offset = msgAndMetadata.topicInfo.getConsumeOffset offsets.put(key, offset) - // TODO: Remove Logging - logInfo("Handled message: " + (key, offset).toString) + // logInfo("Handled message: " + (key, offset).toString) // Keep on handling messages true @@ -190,5 +189,5 @@ class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, // } // } - + } -- cgit v1.2.3 From c97ebf64377e853ab7c616a103869a4417f25954 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 19 Nov 2012 23:22:07 +0000 Subject: Fixed bug in the number of splits in RDD after checkpointing. Modified reduceByKeyAndWindow (naive) computation from window+reduceByKey to reduceByKey+window+reduceByKey. --- conf/streaming-env.sh.template | 2 +- core/src/main/scala/spark/RDD.scala | 3 ++- streaming/src/main/scala/spark/streaming/DStream.scala | 3 ++- streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala | 6 +++++- streaming/src/main/scala/spark/streaming/Scheduler.scala | 2 +- streaming/src/main/scala/spark/streaming/WindowedDStream.scala | 3 +++ 6 files changed, 14 insertions(+), 5 deletions(-) diff --git a/conf/streaming-env.sh.template b/conf/streaming-env.sh.template index 6b4094c515..1ea9ba5541 100755 --- a/conf/streaming-env.sh.template +++ b/conf/streaming-env.sh.template @@ -11,7 +11,7 @@ SPARK_JAVA_OPTS+=" -XX:+UseConcMarkSweepGC" -# Using of Kryo serialization can improve serialization performance +# Using Kryo serialization can improve serialization performance # and therefore the throughput of the Spark Streaming programs. However, # using Kryo serialization with custom classes may required you to # register the classes with Kryo. Refer to the Spark documentation diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 6af8c377b5..8af6c9bd6a 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -222,12 +222,13 @@ abstract class RDD[T: ClassManifest]( rdd.checkpointFile = new Path(context.checkpointDir, "rdd-" + id).toString rdd.saveAsObjectFile(checkpointFile) rdd.synchronized { - rdd.checkpointRDD = context.objectFile[T](checkpointFile) + rdd.checkpointRDD = context.objectFile[T](checkpointFile, rdd.splits.size) rdd.checkpointRDDSplits = rdd.checkpointRDD.splits rdd.changeDependencies(rdd.checkpointRDD) rdd.shouldCheckpoint = false rdd.isCheckpointInProgress = false rdd.isCheckpointed = true + println("Done checkpointing RDD " + rdd.id + ", " + rdd) } } } diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 13770aa8fd..26d5ce9198 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -321,7 +321,8 @@ extends Serializable with Logging { } } } - logInfo("Updated checkpoint data for time " + currentTime) + logInfo("Updated checkpoint data for time " + currentTime + ", " + checkpointData.size + " checkpoints, " + + "[" + checkpointData.mkString(",") + "]") } /** diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index e09d27d34f..720e63bba0 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -4,6 +4,7 @@ import spark.streaming.StreamingContext._ import spark.{Manifests, RDD, Partitioner, HashPartitioner} import spark.SparkContext._ +import spark.storage.StorageLevel import scala.collection.mutable.ArrayBuffer @@ -115,7 +116,10 @@ extends Serializable { slideTime: Time, partitioner: Partitioner ): DStream[(K, V)] = { - self.window(windowTime, slideTime).reduceByKey(ssc.sc.clean(reduceFunc), partitioner) + val cleanedReduceFunc = ssc.sc.clean(reduceFunc) + self.reduceByKey(cleanedReduceFunc, partitioner) + .window(windowTime, slideTime) + .reduceByKey(cleanedReduceFunc, partitioner) } // This method is the efficient sliding window reduce operation, diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index e2dca91179..014021be61 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -17,7 +17,7 @@ extends Logging { val graph = ssc.graph - val concurrentJobs = System.getProperty("spark.stream.concurrentJobs", "1").toInt + val concurrentJobs = System.getProperty("spark.streaming.concurrentJobs", "1").toInt val jobManager = new JobManager(ssc, concurrentJobs) val checkpointWriter = if (ssc.checkpointInterval != null && ssc.checkpointDir != null) { diff --git a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala index ce89a3f99b..e4d2a634f5 100644 --- a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala @@ -2,6 +2,7 @@ package spark.streaming import spark.RDD import spark.rdd.UnionRDD +import spark.storage.StorageLevel class WindowedDStream[T: ClassManifest]( @@ -18,6 +19,8 @@ class WindowedDStream[T: ClassManifest]( throw new Exception("The slide duration of WindowedDStream (" + _slideTime + ") " + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + parent.persist(StorageLevel.MEMORY_ONLY_SER) + def windowTime: Time = _windowTime override def dependencies = List(parent) -- cgit v1.2.3 From fd11d23bb3a817dabd414bceddebc35ad731f626 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 19 Nov 2012 19:04:39 -0800 Subject: Modified StreamingContext API to make constructor accept the batch size (since it is always needed, Patrick's suggestion). Added description to DStream and StreamingContext. --- .../src/main/scala/spark/streaming/DStream.scala | 26 ++++++++++-- .../main/scala/spark/streaming/DStreamGraph.scala | 4 +- .../scala/spark/streaming/StreamingContext.scala | 49 +++++++++++++++------- .../scala/spark/streaming/examples/CountRaw.scala | 32 -------------- .../spark/streaming/examples/FileStream.scala | 7 ++-- .../examples/FileStreamWithCheckpoint.scala | 5 +-- .../scala/spark/streaming/examples/GrepRaw.scala | 7 ++-- .../spark/streaming/examples/QueueStream.scala | 10 ++--- .../streaming/examples/TopKWordCountRaw.scala | 7 ++-- .../spark/streaming/examples/WordCountHdfs.scala | 5 +-- .../streaming/examples/WordCountNetwork.scala | 6 +-- .../spark/streaming/examples/WordCountRaw.scala | 7 ++-- .../examples/clickstream/PageViewStream.scala | 5 +-- .../spark/streaming/BasicOperationsSuite.scala | 2 +- .../scala/spark/streaming/InputStreamsSuite.scala | 12 ++---- .../test/scala/spark/streaming/TestSuiteBase.scala | 6 +-- 16 files changed, 92 insertions(+), 98 deletions(-) delete mode 100644 streaming/src/main/scala/spark/streaming/examples/CountRaw.scala diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 26d5ce9198..8efda2074d 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -17,6 +17,26 @@ import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import org.apache.hadoop.fs.Path import org.apache.hadoop.conf.Configuration +/** + * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous + * sequence of RDDs (of the same type) representing a continuous stream of data (see [[spark.RDD]] + * for more details on RDDs). DStreams can either be created from live data (such as, data from + * HDFS. Kafka or Flume) or it can be generated by transformation existing DStreams using operations + * such as `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming program is running, each + * DStream periodically generates a RDD, either from live data or by transforming the RDD generated + * by a parent DStream. + * + * This class contains the basic operations available on all DStreams, such as `map`, `filter` and + * `window`. In addition, [[spark.streaming.PairDStreamFunctions]] contains operations available + * only on DStreams of key-value pairs, such as `groupByKeyAndWindow` and `join`. These operations + * are automatically available on any DStream of the right type (e.g., DStream[(Int, Int)] through + * implicit conversions when `spark.streaming.StreamingContext._` is imported. + * + * DStreams internally is characterized by a few basic properties: + * - A list of other DStreams that the DStream depends on + * - A time interval at which the DStream generates an RDD + * - A function that is used to generate an RDD after each time interval + */ abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext) extends Serializable with Logging { @@ -28,7 +48,7 @@ extends Serializable with Logging { * ---------------------------------------------- */ - // Time by which the window slides in this DStream + // Time interval at which the DStream generates an RDD def slideTime: Time // List of parent DStreams on which this DStream depends on @@ -186,12 +206,12 @@ extends Serializable with Logging { dependencies.foreach(_.setGraph(graph)) } - protected[streaming] def setRememberDuration(duration: Time) { + protected[streaming] def remember(duration: Time) { if (duration != null && duration > rememberDuration) { rememberDuration = duration logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this) } - dependencies.foreach(_.setRememberDuration(parentRememberDuration)) + dependencies.foreach(_.remember(parentRememberDuration)) } /** This method checks whether the 'time' is valid wrt slideTime for generating RDD */ diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index bd8c033eab..d0a9ade61d 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -22,7 +22,7 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { } zeroTime = time outputStreams.foreach(_.initialize(zeroTime)) - outputStreams.foreach(_.setRememberDuration(rememberDuration)) + outputStreams.foreach(_.remember(rememberDuration)) outputStreams.foreach(_.validate) inputStreams.par.foreach(_.start()) } @@ -50,7 +50,7 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { batchDuration = duration } - private[streaming] def setRememberDuration(duration: Time) { + private[streaming] def remember(duration: Time) { this.synchronized { if (rememberDuration != null) { throw new Exception("Batch duration already set as " + batchDuration + diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 7a9a71f303..4a41f2f516 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -18,19 +18,39 @@ import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.hadoop.fs.Path import java.util.UUID -final class StreamingContext ( +/** + * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic + * information (such as, cluster URL and job name) to internally create a SparkContext, it provides + * methods used to create DStream from various input sources. + */ +class StreamingContext private ( sc_ : SparkContext, - cp_ : Checkpoint + cp_ : Checkpoint, + batchDur_ : Time ) extends Logging { - def this(sparkContext: SparkContext) = this(sparkContext, null) - - def this(master: String, frameworkName: String, sparkHome: String = null, jars: Seq[String] = Nil) = - this(new SparkContext(master, frameworkName, sparkHome, jars), null) + /** + * Creates a StreamingContext using an existing SparkContext. + * @param sparkContext Existing SparkContext + * @param batchDuration The time interval at which streaming data will be divided into batches + */ + def this(sparkContext: SparkContext, batchDuration: Time) = this(sparkContext, null, batchDuration) - def this(path: String) = this(null, CheckpointReader.read(path)) + /** + * Creates a StreamingContext by providing the details necessary for creating a new SparkContext. + * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). + * @param frameworkName A name for your job, to display on the cluster web UI + * @param batchDuration The time interval at which streaming data will be divided into batches + */ + def this(master: String, frameworkName: String, batchDuration: Time) = + this(new SparkContext(master, frameworkName), null, batchDuration) - def this(cp_ : Checkpoint) = this(null, cp_) + /** + * Recreates the StreamingContext from a checkpoint file. + * @param path Path either to the directory that was specified as the checkpoint directory, or + * to the checkpoint file 'graph' or 'graph.bk'. + */ + def this(path: String) = this(null, CheckpointReader.read(path), null) initLogging() @@ -57,7 +77,10 @@ final class StreamingContext ( cp_.graph.restoreCheckpointData() cp_.graph } else { - new DStreamGraph() + assert(batchDur_ != null, "Batch duration for streaming context cannot be null") + val newGraph = new DStreamGraph() + newGraph.setBatchDuration(batchDur_) + newGraph } } @@ -77,12 +100,8 @@ final class StreamingContext ( private[streaming] var receiverJobThread: Thread = null private[streaming] var scheduler: Scheduler = null - def setBatchDuration(duration: Time) { - graph.setBatchDuration(duration) - } - - def setRememberDuration(duration: Time) { - graph.setRememberDuration(duration) + def remember(duration: Time) { + graph.remember(duration) } def checkpoint(dir: String, interval: Time = null) { diff --git a/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala deleted file mode 100644 index d2fdabd659..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/CountRaw.scala +++ /dev/null @@ -1,32 +0,0 @@ -package spark.streaming.examples - -import spark.util.IntParam -import spark.storage.StorageLevel -import spark.streaming._ -import spark.streaming.StreamingContext._ - -object CountRaw { - def main(args: Array[String]) { - if (args.length != 5) { - System.err.println("Usage: CountRaw ") - System.exit(1) - } - - val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args - - // Create the context and set the batch size - val ssc = new StreamingContext(master, "CountRaw") - ssc.setBatchDuration(Milliseconds(batchMillis)) - - // Make sure some tasks have started on each node - ssc.sc.parallelize(1 to 1000, 1000).count() - ssc.sc.parallelize(1 to 1000, 1000).count() - ssc.sc.parallelize(1 to 1000, 1000).count() - - val rawStreams = (1 to numStreams).map(_ => - ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray - val union = new UnionDStream(rawStreams) - union.map(_.length + 2).reduce(_ + _).foreachRDD(r => println("Byte count: " + r.collect().mkString)) - ssc.start() - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStream.scala b/streaming/src/main/scala/spark/streaming/examples/FileStream.scala index d68611abd6..81938d30d4 100644 --- a/streaming/src/main/scala/spark/streaming/examples/FileStream.scala +++ b/streaming/src/main/scala/spark/streaming/examples/FileStream.scala @@ -14,10 +14,9 @@ object FileStream { System.exit(1) } - // Create the context and set the batch size - val ssc = new StreamingContext(args(0), "FileStream") - ssc.setBatchDuration(Seconds(2)) - + // Create the context + val ssc = new StreamingContext(args(0), "FileStream", Seconds(1)) + // Create the new directory val directory = new Path(args(1)) val fs = directory.getFileSystem(new Configuration()) diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala index 21a83c0fde..b7bc15a1d5 100644 --- a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala +++ b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala @@ -32,9 +32,8 @@ object FileStreamWithCheckpoint { if (!fs.exists(directory)) fs.mkdirs(directory) // Create new streaming context - val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint") - ssc_.setBatchDuration(Seconds(1)) - ssc_.checkpoint(checkpointDir, Seconds(1)) + val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint", Seconds(1)) + ssc_.checkpoint(checkpointDir) // Setup the streaming computation val inputStream = ssc_.textFileStream(directory.toString) diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala index ffbea6e55d..6cb2b4c042 100644 --- a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala @@ -16,9 +16,10 @@ object GrepRaw { val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args - // Create the context and set the batch size - val ssc = new StreamingContext(master, "GrepRaw") - ssc.setBatchDuration(Milliseconds(batchMillis)) + // Create the context + val ssc = new StreamingContext(master, "GrepRaw", Milliseconds(batchMillis)) + + // Warm up the JVMs on master and slave for JIT compilation to kick in warmUp(ssc.sc) diff --git a/streaming/src/main/scala/spark/streaming/examples/QueueStream.scala b/streaming/src/main/scala/spark/streaming/examples/QueueStream.scala index 2af51bad28..2a265d021d 100644 --- a/streaming/src/main/scala/spark/streaming/examples/QueueStream.scala +++ b/streaming/src/main/scala/spark/streaming/examples/QueueStream.scala @@ -1,9 +1,8 @@ package spark.streaming.examples import spark.RDD -import spark.streaming.StreamingContext +import spark.streaming.{Seconds, StreamingContext} import spark.streaming.StreamingContext._ -import spark.streaming.Seconds import scala.collection.mutable.SynchronizedQueue @@ -15,10 +14,9 @@ object QueueStream { System.exit(1) } - // Create the context and set the batch size - val ssc = new StreamingContext(args(0), "QueueStream") - ssc.setBatchDuration(Seconds(1)) - + // Create the context + val ssc = new StreamingContext(args(0), "QueueStream", Seconds(1)) + // Create the queue through which RDDs can be pushed to // a QueueInputDStream val rddQueue = new SynchronizedQueue[RDD[Int]]() diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala index 0411bde1a7..fe4c2bf155 100644 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala @@ -20,12 +20,11 @@ object TopKWordCountRaw { val Array(master, IntParam(numStreams), IntParam(port), checkpointDir) = args val k = 10 - // Create the context, set the batch size and checkpoint directory. + // Create the context, and set the checkpoint directory. // Checkpoint directory is necessary for achieving fault-tolerance, by saving counts // periodically to HDFS - val ssc = new StreamingContext(master, "TopKWordCountRaw") - ssc.setBatchDuration(Seconds(1)) - ssc.checkpoint(checkpointDir + "/" + UUID.randomUUID.toString, Seconds(1)) + val ssc = new StreamingContext(master, "TopKWordCountRaw", Seconds(1)) + ssc.checkpoint(checkpointDir + "/" + UUID.randomUUID.toString, Seconds(1)) // Warm up the JVMs on master and slave for JIT compilation to kick in /*warmUp(ssc.sc)*/ diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala index 591cb141c3..867a8f42c4 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala @@ -10,9 +10,8 @@ object WordCountHdfs { System.exit(1) } - // Create the context and set the batch size - val ssc = new StreamingContext(args(0), "WordCountHdfs") - ssc.setBatchDuration(Seconds(2)) + // Create the context + val ssc = new StreamingContext(args(0), "WordCountHdfs", Seconds(2)) // Create the FileInputDStream on the directory and use the // stream to count words in new files created diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala index ba1bd1de7c..eadda60563 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala @@ -6,13 +6,13 @@ import spark.streaming.StreamingContext._ object WordCountNetwork { def main(args: Array[String]) { if (args.length < 2) { - System.err.println("Usage: WordCountNetwork ") + System.err.println("Usage: WordCountNetwork \n" + + "In local mode, should be 'local[n]' with n > 1") System.exit(1) } // Create the context and set the batch size - val ssc = new StreamingContext(args(0), "WordCountNetwork") - ssc.setBatchDuration(Seconds(2)) + val ssc = new StreamingContext(args(0), "WordCountNetwork", Seconds(1)) // Create a NetworkInputDStream on target ip:port and count the // words in input stream of \n delimited test (eg. generated by 'nc') diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala index 571428c0fe..a29c81d437 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -19,12 +19,11 @@ object WordCountRaw { val Array(master, IntParam(numStreams), IntParam(port), checkpointDir) = args - // Create the context, set the batch size and checkpoint directory. + // Create the context, and set the checkpoint directory. // Checkpoint directory is necessary for achieving fault-tolerance, by saving counts // periodically to HDFS - val ssc = new StreamingContext(master, "WordCountRaw") - ssc.setBatchDuration(Seconds(1)) - ssc.checkpoint(checkpointDir + "/" + UUID.randomUUID.toString, Seconds(1)) + val ssc = new StreamingContext(master, "WordCountRaw", Seconds(1)) + ssc.checkpoint(checkpointDir + "/" + UUID.randomUUID.toString, Seconds(1)) // Warm up the JVMs on master and slave for JIT compilation to kick in warmUp(ssc.sc) diff --git a/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala b/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala index 1a51fb66cd..68be6b7893 100644 --- a/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala +++ b/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala @@ -23,9 +23,8 @@ object PageViewStream { val host = args(1) val port = args(2).toInt - // Create the context and set the batch size - val ssc = new StreamingContext("local[2]", "PageViewStream") - ssc.setBatchDuration(Seconds(1)) + // Create the context + val ssc = new StreamingContext("local[2]", "PageViewStream", Seconds(1)) // Create a NetworkInputDStream on target host:port and convert each line to a PageView val pageViews = ssc.networkTextStream(host, port) diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala index d0aaac0f2e..dc38ef4912 100644 --- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala @@ -175,7 +175,7 @@ class BasicOperationsSuite extends TestSuiteBase { } val ssc = setupStreams(input, operation _) - ssc.setRememberDuration(rememberDuration) + ssc.remember(rememberDuration) runStreams[(Int, Int)](ssc, input.size, input.size / 2) val windowedStream2 = ssc.graph.getOutputStreams().head.dependencies.head diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index 3e99440226..e98c096725 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -40,8 +40,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { testServer.start() // Set up the streaming context and input streams - val ssc = new StreamingContext(master, framework) - ssc.setBatchDuration(batchDuration) + val ssc = new StreamingContext(master, framework, batchDuration) val networkStream = ssc.networkTextStream("localhost", testPort, StorageLevel.MEMORY_AND_DISK) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String ]] val outputStream = new TestOutputStream(networkStream, outputBuffer) @@ -89,8 +88,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { testServer.start() // Set up the streaming context and input streams - var ssc = new StreamingContext(master, framework) - ssc.setBatchDuration(batchDuration) + var ssc = new StreamingContext(master, framework, batchDuration) ssc.checkpoint(checkpointDir, checkpointInterval) val networkStream = ssc.networkTextStream("localhost", testPort, StorageLevel.MEMORY_AND_DISK) var outputStream = new TestOutputStream(networkStream, new ArrayBuffer[Seq[String]]) @@ -137,8 +135,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } // Set up the streaming context and input streams - val ssc = new StreamingContext(master, framework) - ssc.setBatchDuration(batchDuration) + val ssc = new StreamingContext(master, framework, batchDuration) val filestream = ssc.textFileStream(testDir.toString) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] def output = outputBuffer.flatMap(x => x) @@ -198,8 +195,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } // Set up the streaming context and input streams - var ssc = new StreamingContext(master, framework) - ssc.setBatchDuration(batchDuration) + var ssc = new StreamingContext(master, framework, batchDuration) ssc.checkpoint(checkpointDir, checkpointInterval) val filestream = ssc.textFileStream(testDir.toString) var outputStream = new TestOutputStream(filestream, new ArrayBuffer[Seq[String]]) diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index 5fb5cc504c..8cc2f8ccfc 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -76,8 +76,7 @@ trait TestSuiteBase extends FunSuite with Logging { ): StreamingContext = { // Create StreamingContext - val ssc = new StreamingContext(master, framework) - ssc.setBatchDuration(batchDuration) + val ssc = new StreamingContext(master, framework, batchDuration) if (checkpointDir != null) { ssc.checkpoint(checkpointDir, checkpointInterval) } @@ -98,8 +97,7 @@ trait TestSuiteBase extends FunSuite with Logging { ): StreamingContext = { // Create StreamingContext - val ssc = new StreamingContext(master, framework) - ssc.setBatchDuration(batchDuration) + val ssc = new StreamingContext(master, framework, batchDuration) if (checkpointDir != null) { ssc.checkpoint(checkpointDir, checkpointInterval) } -- cgit v1.2.3 From b18d70870a33a4783c6b3b787bef9b0eec30bce0 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 27 Nov 2012 15:08:49 -0800 Subject: Modified bunch HashMaps in Spark to use TimeStampedHashMap and made various modules use CleanupTask to periodically clean up metadata. --- core/src/main/scala/spark/CacheTracker.scala | 6 +- core/src/main/scala/spark/MapOutputTracker.scala | 27 ++++--- .../main/scala/spark/scheduler/DAGScheduler.scala | 13 +++- .../scala/spark/scheduler/ShuffleMapTask.scala | 6 +- core/src/main/scala/spark/util/CleanupTask.scala | 31 ++++++++ .../main/scala/spark/util/TimeStampedHashMap.scala | 87 ++++++++++++++++++++++ .../scala/spark/streaming/StreamingContext.scala | 13 +++- 7 files changed, 165 insertions(+), 18 deletions(-) create mode 100644 core/src/main/scala/spark/util/CleanupTask.scala create mode 100644 core/src/main/scala/spark/util/TimeStampedHashMap.scala diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index c5db6ce63a..0ee59bee0f 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -14,6 +14,7 @@ import scala.collection.mutable.HashSet import spark.storage.BlockManager import spark.storage.StorageLevel +import util.{CleanupTask, TimeStampedHashMap} private[spark] sealed trait CacheTrackerMessage @@ -30,7 +31,7 @@ private[spark] case object StopCacheTracker extends CacheTrackerMessage private[spark] class CacheTrackerActor extends Actor with Logging { // TODO: Should probably store (String, CacheType) tuples - private val locs = new HashMap[Int, Array[List[String]]] + private val locs = new TimeStampedHashMap[Int, Array[List[String]]] /** * A map from the slave's host name to its cache size. @@ -38,6 +39,8 @@ private[spark] class CacheTrackerActor extends Actor with Logging { private val slaveCapacity = new HashMap[String, Long] private val slaveUsage = new HashMap[String, Long] + private val cleanupTask = new CleanupTask("CacheTracker", locs.cleanup) + private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L) private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L) private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host) @@ -86,6 +89,7 @@ private[spark] class CacheTrackerActor extends Actor with Logging { case StopCacheTracker => logInfo("Stopping CacheTrackerActor") sender ! true + cleanupTask.cancel() context.stop(self) } } diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 45441aa5e5..d0be1bb913 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -17,6 +17,7 @@ import scala.collection.mutable.HashSet import scheduler.MapStatus import spark.storage.BlockManagerId import java.util.zip.{GZIPInputStream, GZIPOutputStream} +import util.{CleanupTask, TimeStampedHashMap} private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String) @@ -43,7 +44,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea val timeout = 10.seconds - var mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]] + var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] // Incremented every time a fetch fails so that client nodes know to clear // their cache of map output locations if this happens. @@ -52,7 +53,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea // Cache a serialized version of the output statuses for each shuffle to send them out faster var cacheGeneration = generation - val cachedSerializedStatuses = new HashMap[Int, Array[Byte]] + val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] var trackerActor: ActorRef = if (isMaster) { val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName) @@ -63,6 +64,8 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea actorSystem.actorFor(url) } + val cleanupTask = new CleanupTask("MapOutputTracker", this.cleanup) + // Send a message to the trackerActor and get its result within a default timeout, or // throw a SparkException if this fails. def askTracker(message: Any): Any = { @@ -83,14 +86,14 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea } def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.get(shuffleId) != null) { + if (mapStatuses.get(shuffleId) != None) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)) } def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { - var array = mapStatuses.get(shuffleId) + var array = mapStatuses(shuffleId) array.synchronized { array(mapId) = status } @@ -107,7 +110,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea } def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - var array = mapStatuses.get(shuffleId) + var array = mapStatuses(shuffleId) if (array != null) { array.synchronized { if (array(mapId).address == bmAddress) { @@ -125,7 +128,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { - val statuses = mapStatuses.get(shuffleId) + val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") fetching.synchronized { @@ -138,7 +141,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea case e: InterruptedException => } } - return mapStatuses.get(shuffleId).map(status => + return mapStatuses(shuffleId).map(status => (status.address, MapOutputTracker.decompressSize(status.compressedSizes(reduceId)))) } else { fetching += shuffleId @@ -164,9 +167,15 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea } } + def cleanup(cleanupTime: Long) { + mapStatuses.cleanup(cleanupTime) + cachedSerializedStatuses.cleanup(cleanupTime) + } + def stop() { communicate(StopMapOutputTracker) mapStatuses.clear() + cleanupTask.cancel() trackerActor = null } @@ -192,7 +201,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea generationLock.synchronized { if (newGen > generation) { logInfo("Updating generation to " + newGen + " and clearing cache") - mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]] + mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] generation = newGen } } @@ -210,7 +219,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea case Some(bytes) => return bytes case None => - statuses = mapStatuses.get(shuffleId) + statuses = mapStatuses(shuffleId) generationGotten = generation } } diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index aaaed59c4a..3af877b817 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -14,6 +14,7 @@ import spark.partial.ApproximateEvaluator import spark.partial.PartialResult import spark.storage.BlockManagerMaster import spark.storage.BlockManagerId +import util.{CleanupTask, TimeStampedHashMap} /** * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for @@ -61,9 +62,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val nextStageId = new AtomicInteger(0) - val idToStage = new HashMap[Int, Stage] + val idToStage = new TimeStampedHashMap[Int, Stage] - val shuffleToMapStage = new HashMap[Int, Stage] + val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] var cacheLocs = new HashMap[Int, Array[List[String]]] @@ -83,6 +84,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val activeJobs = new HashSet[ActiveJob] val resultStageToJob = new HashMap[Stage, ActiveJob] + val cleanupTask = new CleanupTask("DAGScheduler", this.cleanup) + // Start a thread to run the DAGScheduler event loop new Thread("DAGScheduler") { setDaemon(true) @@ -591,8 +594,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with return Nil } + def cleanup(cleanupTime: Long) { + idToStage.cleanup(cleanupTime) + shuffleToMapStage.cleanup(cleanupTime) + } + def stop() { eventQueue.put(StopDAGScheduler) + cleanupTask.cancel() taskSched.stop() } } diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 60105c42b6..fbf618c906 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -14,17 +14,19 @@ import com.ning.compress.lzf.LZFOutputStream import spark._ import spark.storage._ +import util.{TimeStampedHashMap, CleanupTask} private[spark] object ShuffleMapTask { // A simple map between the stage id to the serialized byte array of a task. // Served as a cache for task serialization because serialization can be // expensive on the master node if it needs to launch thousands of tasks. - val serializedInfoCache = new JHashMap[Int, Array[Byte]] + val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] + val cleanupTask = new CleanupTask("ShuffleMapTask", serializedInfoCache.cleanup) def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { synchronized { - val old = serializedInfoCache.get(stageId) + val old = serializedInfoCache.get(stageId).orNull if (old != null) { return old } else { diff --git a/core/src/main/scala/spark/util/CleanupTask.scala b/core/src/main/scala/spark/util/CleanupTask.scala new file mode 100644 index 0000000000..ccc28803e0 --- /dev/null +++ b/core/src/main/scala/spark/util/CleanupTask.scala @@ -0,0 +1,31 @@ +package spark.util + +import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors} +import java.util.{TimerTask, Timer} +import spark.Logging + +class CleanupTask(name: String, cleanupFunc: (Long) => Unit) extends Logging { + val delayMins = System.getProperty("spark.cleanup.delay", "-100").toInt + val periodMins = System.getProperty("spark.cleanup.period", (delayMins / 10).toString).toInt + val timer = new Timer(name + " cleanup timer", true) + val task = new TimerTask { + def run() { + try { + if (delayMins > 0) { + + cleanupFunc(System.currentTimeMillis() - (delayMins * 60 * 1000)) + logInfo("Ran cleanup task for " + name) + } + } catch { + case e: Exception => logError("Error running cleanup task for " + name, e) + } + } + } + if (periodMins > 0) { + timer.schedule(task, periodMins * 60 * 1000, periodMins * 60 * 1000) + } + + def cancel() { + timer.cancel() + } +} diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala new file mode 100644 index 0000000000..7a22b80a20 --- /dev/null +++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala @@ -0,0 +1,87 @@ +package spark.util + +import scala.collection.JavaConversions._ +import scala.collection.mutable.{HashMap, Map} +import java.util.concurrent.ConcurrentHashMap + +/** + * This is a custom implementation of scala.collection.mutable.Map which stores the insertion + * time stamp along with each key-value pair. Key-value pairs that are older than a particular + * threshold time can them be removed using the cleanup method. This is intended to be a drop-in + * replacement of scala.collection.mutable.HashMap. + */ +class TimeStampedHashMap[A, B] extends Map[A, B]() { + val internalMap = new ConcurrentHashMap[A, (B, Long)]() + + def get(key: A): Option[B] = { + val value = internalMap.get(key) + if (value != null) Some(value._1) else None + } + + def iterator: Iterator[(A, B)] = { + val jIterator = internalMap.entrySet().iterator() + jIterator.map(kv => (kv.getKey, kv.getValue._1)) + } + + override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = { + val newMap = new TimeStampedHashMap[A, B1] + newMap.internalMap.putAll(this.internalMap) + newMap.internalMap.put(kv._1, (kv._2, currentTime)) + newMap + } + + override def - (key: A): Map[A, B] = { + internalMap.remove(key) + this + } + + override def += (kv: (A, B)): this.type = { + internalMap.put(kv._1, (kv._2, currentTime)) + this + } + + override def -= (key: A): this.type = { + internalMap.remove(key) + this + } + + override def update(key: A, value: B) { + this += ((key, value)) + } + + override def apply(key: A): B = { + val value = internalMap.get(key) + if (value == null) throw new NoSuchElementException() + value._1 + } + + override def filter(p: ((A, B)) => Boolean): Map[A, B] = { + internalMap.map(kv => (kv._1, kv._2._1)).filter(p) + } + + override def empty: Map[A, B] = new TimeStampedHashMap[A, B]() + + override def size(): Int = internalMap.size() + + override def foreach[U](f: ((A, B)) => U): Unit = { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + val kv = (entry.getKey, entry.getValue._1) + f(kv) + } + } + + def cleanup(threshTime: Long) { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + if (entry.getValue._2 < threshTime) { + iterator.remove() + } + } + } + + private def currentTime: Long = System.currentTimeMillis() + +} diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 4a41f2f516..58123dc82c 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -43,7 +43,7 @@ class StreamingContext private ( * @param batchDuration The time interval at which streaming data will be divided into batches */ def this(master: String, frameworkName: String, batchDuration: Time) = - this(new SparkContext(master, frameworkName), null, batchDuration) + this(StreamingContext.createNewSparkContext(master, frameworkName), null, batchDuration) /** * Recreates the StreamingContext from a checkpoint file. @@ -214,11 +214,8 @@ class StreamingContext private ( "Checkpoint directory has been set, but the graph checkpointing interval has " + "not been set. Please use StreamingContext.checkpoint() to set the interval." ) - - } - /** * This function starts the execution of the streams. */ @@ -265,6 +262,14 @@ class StreamingContext private ( object StreamingContext { + + def createNewSparkContext(master: String, frameworkName: String): SparkContext = { + if (System.getProperty("spark.cleanup.delay", "-1").toInt < 0) { + System.setProperty("spark.cleanup.delay", "60") + } + new SparkContext(master, frameworkName) + } + implicit def toPairDStreamFunctions[K: ClassManifest, V: ClassManifest](stream: DStream[(K,V)]) = { new PairDStreamFunctions[K, V](stream) } -- cgit v1.2.3 From d5e7aad039603a8a02d11f9ebda001422ca4c341 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 28 Nov 2012 08:36:55 +0000 Subject: Bug fixes --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 11 ++++++++++- core/src/main/scala/spark/util/CleanupTask.scala | 17 +++++++++-------- .../main/scala/spark/streaming/StreamingContext.scala | 2 +- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 3af877b817..affacb43ca 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -78,7 +78,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done val running = new HashSet[Stage] // Stages we are running right now val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures - val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage + val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage var lastFetchFailureTime: Long = 0 // Used to wait a bit to avoid repeated resubmits val activeJobs = new HashSet[ActiveJob] @@ -595,8 +595,17 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } def cleanup(cleanupTime: Long) { + var sizeBefore = idToStage.size idToStage.cleanup(cleanupTime) + logInfo("idToStage " + sizeBefore + " --> " + idToStage.size) + + sizeBefore = shuffleToMapStage.size shuffleToMapStage.cleanup(cleanupTime) + logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size) + + sizeBefore = pendingTasks.size + pendingTasks.cleanup(cleanupTime) + logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size) } def stop() { diff --git a/core/src/main/scala/spark/util/CleanupTask.scala b/core/src/main/scala/spark/util/CleanupTask.scala index ccc28803e0..a4357c62c6 100644 --- a/core/src/main/scala/spark/util/CleanupTask.scala +++ b/core/src/main/scala/spark/util/CleanupTask.scala @@ -5,24 +5,25 @@ import java.util.{TimerTask, Timer} import spark.Logging class CleanupTask(name: String, cleanupFunc: (Long) => Unit) extends Logging { - val delayMins = System.getProperty("spark.cleanup.delay", "-100").toInt - val periodMins = System.getProperty("spark.cleanup.period", (delayMins / 10).toString).toInt + val delaySeconds = (System.getProperty("spark.cleanup.delay", "-100").toDouble * 60).toInt + val periodSeconds = math.max(10, delaySeconds / 10) val timer = new Timer(name + " cleanup timer", true) val task = new TimerTask { def run() { try { - if (delayMins > 0) { - - cleanupFunc(System.currentTimeMillis() - (delayMins * 60 * 1000)) + if (delaySeconds > 0) { + cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) logInfo("Ran cleanup task for " + name) - } + } } catch { case e: Exception => logError("Error running cleanup task for " + name, e) } } } - if (periodMins > 0) { - timer.schedule(task, periodMins * 60 * 1000, periodMins * 60 * 1000) + if (periodSeconds > 0) { + logInfo("Starting cleanup task for " + name + " with delay of " + delaySeconds + " seconds and " + + "period of " + periodSeconds + " secs") + timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) } def cancel() { diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 58123dc82c..90dd560752 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -264,7 +264,7 @@ class StreamingContext private ( object StreamingContext { def createNewSparkContext(master: String, frameworkName: String): SparkContext = { - if (System.getProperty("spark.cleanup.delay", "-1").toInt < 0) { + if (System.getProperty("spark.cleanup.delay", "-1").toDouble < 0) { System.setProperty("spark.cleanup.delay", "60") } new SparkContext(master, frameworkName) -- cgit v1.2.3 From e463ae492068d2922e1d50c051a87f8010953dff Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 28 Nov 2012 14:05:01 -0800 Subject: Modified StorageLevel and BlockManagerId to cache common objects and use cached object while deserializing. --- .../main/scala/spark/storage/BlockManager.scala | 28 +------------ .../main/scala/spark/storage/BlockManagerId.scala | 48 ++++++++++++++++++++++ .../main/scala/spark/storage/StorageLevel.scala | 28 ++++++++++++- .../scala/spark/storage/BlockManagerSuite.scala | 26 ++++++++++++ 4 files changed, 101 insertions(+), 29 deletions(-) create mode 100644 core/src/main/scala/spark/storage/BlockManagerId.scala diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 70d6d8369d..e4aa9247a3 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -20,33 +20,7 @@ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import sun.nio.ch.DirectBuffer -private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable { - def this() = this(null, 0) // For deserialization only - - def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) - - override def writeExternal(out: ObjectOutput) { - out.writeUTF(ip) - out.writeInt(port) - } - - override def readExternal(in: ObjectInput) { - ip = in.readUTF() - port = in.readInt() - } - - override def toString = "BlockManagerId(" + ip + ", " + port + ")" - - override def hashCode = ip.hashCode * 41 + port - - override def equals(that: Any) = that match { - case id: BlockManagerId => port == id.port && ip == id.ip - case _ => false - } -} - - -private[spark] +private[spark] case class BlockException(blockId: String, message: String, ex: Exception = null) extends Exception(message) diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala new file mode 100644 index 0000000000..4933cc6606 --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -0,0 +1,48 @@ +package spark.storage + +import java.io.{IOException, ObjectOutput, ObjectInput, Externalizable} +import java.util.concurrent.ConcurrentHashMap + +private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable { + def this() = this(null, 0) // For deserialization only + + def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) + + override def writeExternal(out: ObjectOutput) { + out.writeUTF(ip) + out.writeInt(port) + } + + override def readExternal(in: ObjectInput) { + ip = in.readUTF() + port = in.readInt() + } + + @throws(classOf[IOException]) + private def readResolve(): Object = { + BlockManagerId.getCachedBlockManagerId(this) + } + + + override def toString = "BlockManagerId(" + ip + ", " + port + ")" + + override def hashCode = ip.hashCode * 41 + port + + override def equals(that: Any) = that match { + case id: BlockManagerId => port == id.port && ip == id.ip + case _ => false + } +} + +object BlockManagerId { + val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() + + def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { + if (blockManagerIdCache.containsKey(id)) { + blockManagerIdCache.get(id) + } else { + blockManagerIdCache.put(id, id) + id + } + } +} \ No newline at end of file diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index c497f03e0c..eb88eb2759 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -1,6 +1,9 @@ package spark.storage -import java.io.{Externalizable, ObjectInput, ObjectOutput} +import java.io.{IOException, Externalizable, ObjectInput, ObjectOutput} +import collection.mutable +import util.Random +import collection.mutable.ArrayBuffer /** * Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory, @@ -17,7 +20,8 @@ class StorageLevel( extends Externalizable { // TODO: Also add fields for caching priority, dataset ID, and flushing. - + assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes") + def this(flags: Int, replication: Int) { this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) } @@ -27,6 +31,10 @@ class StorageLevel( override def clone(): StorageLevel = new StorageLevel( this.useDisk, this.useMemory, this.deserialized, this.replication) + override def hashCode(): Int = { + toInt * 41 + replication + } + override def equals(other: Any): Boolean = other match { case s: StorageLevel => s.useDisk == useDisk && @@ -66,6 +74,11 @@ class StorageLevel( replication = in.readByte() } + @throws(classOf[IOException]) + private def readResolve(): Object = { + StorageLevel.getCachedStorageLevel(this) + } + override def toString: String = "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication) } @@ -82,4 +95,15 @@ object StorageLevel { val MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2) val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false) val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2) + + val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() + + def getCachedStorageLevel(level: StorageLevel): StorageLevel = { + if (storageLevelCache.containsKey(level)) { + storageLevelCache.get(level) + } else { + storageLevelCache.put(level, level) + level + } + } } diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index 0e78228134..a2d5e39859 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -57,6 +57,32 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } } + test("StorageLevel object caching") { + val level1 = new StorageLevel(false, false, false, 3) + val level2 = new StorageLevel(false, false, false, 3) + val bytes1 = spark.Utils.serialize(level1) + val level1_ = spark.Utils.deserialize[StorageLevel](bytes1) + val bytes2 = spark.Utils.serialize(level2) + val level2_ = spark.Utils.deserialize[StorageLevel](bytes2) + assert(level1_ === level1, "Deserialized level1 not same as original level1") + assert(level2_ === level2, "Deserialized level2 not same as original level1") + assert(level1_ === level2_, "Deserialized level1 not same as deserialized level2") + assert(level2_.eq(level1_), "Deserialized level2 not the same object as deserialized level1") + } + + test("BlockManagerId object caching") { + val id1 = new StorageLevel(false, false, false, 3) + val id2 = new StorageLevel(false, false, false, 3) + val bytes1 = spark.Utils.serialize(id1) + val id1_ = spark.Utils.deserialize[StorageLevel](bytes1) + val bytes2 = spark.Utils.serialize(id2) + val id2_ = spark.Utils.deserialize[StorageLevel](bytes2) + assert(id1_ === id1, "Deserialized id1 not same as original id1") + assert(id2_ === id2, "Deserialized id2 not same as original id1") + assert(id1_ === id2_, "Deserialized id1 not same as deserialized id2") + assert(id2_.eq(id1_), "Deserialized id2 not the same object as deserialized level1") + } + test("master + 1 manager interaction") { store = new BlockManager(master, serializer, 2000) val a1 = new Array[Byte](400) -- cgit v1.2.3 From 9e9e9e1d898387a1996e4c57128bafadb5938a9b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 28 Nov 2012 18:48:14 -0800 Subject: Renamed CleanupTask to MetadataCleaner. --- core/src/main/scala/spark/CacheTracker.scala | 6 ++-- core/src/main/scala/spark/MapOutputTracker.scala | 6 ++-- .../main/scala/spark/scheduler/DAGScheduler.scala | 6 ++-- .../scala/spark/scheduler/ShuffleMapTask.scala | 5 ++-- core/src/main/scala/spark/util/CleanupTask.scala | 32 ---------------------- .../main/scala/spark/util/MetadataCleaner.scala | 32 ++++++++++++++++++++++ 6 files changed, 44 insertions(+), 43 deletions(-) delete mode 100644 core/src/main/scala/spark/util/CleanupTask.scala create mode 100644 core/src/main/scala/spark/util/MetadataCleaner.scala diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index 0ee59bee0f..9888f061d9 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -14,7 +14,7 @@ import scala.collection.mutable.HashSet import spark.storage.BlockManager import spark.storage.StorageLevel -import util.{CleanupTask, TimeStampedHashMap} +import util.{MetadataCleaner, TimeStampedHashMap} private[spark] sealed trait CacheTrackerMessage @@ -39,7 +39,7 @@ private[spark] class CacheTrackerActor extends Actor with Logging { private val slaveCapacity = new HashMap[String, Long] private val slaveUsage = new HashMap[String, Long] - private val cleanupTask = new CleanupTask("CacheTracker", locs.cleanup) + private val metadataCleaner = new MetadataCleaner("CacheTracker", locs.cleanup) private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L) private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L) @@ -89,7 +89,7 @@ private[spark] class CacheTrackerActor extends Actor with Logging { case StopCacheTracker => logInfo("Stopping CacheTrackerActor") sender ! true - cleanupTask.cancel() + metadataCleaner.cancel() context.stop(self) } } diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index d0be1bb913..20ff5431af 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -17,7 +17,7 @@ import scala.collection.mutable.HashSet import scheduler.MapStatus import spark.storage.BlockManagerId import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import util.{CleanupTask, TimeStampedHashMap} +import util.{MetadataCleaner, TimeStampedHashMap} private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String) @@ -64,7 +64,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea actorSystem.actorFor(url) } - val cleanupTask = new CleanupTask("MapOutputTracker", this.cleanup) + val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup) // Send a message to the trackerActor and get its result within a default timeout, or // throw a SparkException if this fails. @@ -175,7 +175,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea def stop() { communicate(StopMapOutputTracker) mapStatuses.clear() - cleanupTask.cancel() + metadataCleaner.cancel() trackerActor = null } diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index affacb43ca..4b2570fa2b 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -14,7 +14,7 @@ import spark.partial.ApproximateEvaluator import spark.partial.PartialResult import spark.storage.BlockManagerMaster import spark.storage.BlockManagerId -import util.{CleanupTask, TimeStampedHashMap} +import util.{MetadataCleaner, TimeStampedHashMap} /** * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for @@ -84,7 +84,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val activeJobs = new HashSet[ActiveJob] val resultStageToJob = new HashMap[Stage, ActiveJob] - val cleanupTask = new CleanupTask("DAGScheduler", this.cleanup) + val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) // Start a thread to run the DAGScheduler event loop new Thread("DAGScheduler") { @@ -610,7 +610,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with def stop() { eventQueue.put(StopDAGScheduler) - cleanupTask.cancel() + metadataCleaner.cancel() taskSched.stop() } } diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index fbf618c906..683f5ebec3 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -14,7 +14,7 @@ import com.ning.compress.lzf.LZFOutputStream import spark._ import spark.storage._ -import util.{TimeStampedHashMap, CleanupTask} +import util.{TimeStampedHashMap, MetadataCleaner} private[spark] object ShuffleMapTask { @@ -22,7 +22,8 @@ private[spark] object ShuffleMapTask { // Served as a cache for task serialization because serialization can be // expensive on the master node if it needs to launch thousands of tasks. val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - val cleanupTask = new CleanupTask("ShuffleMapTask", serializedInfoCache.cleanup) + + val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.cleanup) def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { synchronized { diff --git a/core/src/main/scala/spark/util/CleanupTask.scala b/core/src/main/scala/spark/util/CleanupTask.scala deleted file mode 100644 index a4357c62c6..0000000000 --- a/core/src/main/scala/spark/util/CleanupTask.scala +++ /dev/null @@ -1,32 +0,0 @@ -package spark.util - -import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors} -import java.util.{TimerTask, Timer} -import spark.Logging - -class CleanupTask(name: String, cleanupFunc: (Long) => Unit) extends Logging { - val delaySeconds = (System.getProperty("spark.cleanup.delay", "-100").toDouble * 60).toInt - val periodSeconds = math.max(10, delaySeconds / 10) - val timer = new Timer(name + " cleanup timer", true) - val task = new TimerTask { - def run() { - try { - if (delaySeconds > 0) { - cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) - logInfo("Ran cleanup task for " + name) - } - } catch { - case e: Exception => logError("Error running cleanup task for " + name, e) - } - } - } - if (periodSeconds > 0) { - logInfo("Starting cleanup task for " + name + " with delay of " + delaySeconds + " seconds and " - + "period of " + periodSeconds + " secs") - timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) - } - - def cancel() { - timer.cancel() - } -} diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala new file mode 100644 index 0000000000..71ac39864e --- /dev/null +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -0,0 +1,32 @@ +package spark.util + +import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors} +import java.util.{TimerTask, Timer} +import spark.Logging + +class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { + val delaySeconds = (System.getProperty("spark.cleanup.delay", "-100").toDouble * 60).toInt + val periodSeconds = math.max(10, delaySeconds / 10) + val timer = new Timer(name + " cleanup timer", true) + val task = new TimerTask { + def run() { + try { + if (delaySeconds > 0) { + cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) + logInfo("Ran metadata cleaner for " + name) + } + } catch { + case e: Exception => logError("Error running cleanup task for " + name, e) + } + } + } + if (periodSeconds > 0) { + logInfo("Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds and " + + "period of " + periodSeconds + " secs") + timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) + } + + def cancel() { + timer.cancel() + } +} -- cgit v1.2.3 From c9789751bfc496d24e8369a0035d57f0ed8dcb58 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 28 Nov 2012 23:18:24 -0800 Subject: Added metadata cleaner to BlockManager to remove old blocks completely. --- .../main/scala/spark/storage/BlockManager.scala | 47 ++++++++++++++++------ .../scala/spark/storage/BlockManagerMaster.scala | 1 + 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index e4aa9247a3..1e36578e1a 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -10,12 +10,12 @@ import java.nio.{MappedByteBuffer, ByteBuffer} import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} -import scala.collection.JavaConversions._ import spark.{CacheTracker, Logging, SizeEstimator, SparkException, Utils} import spark.network._ import spark.serializer.Serializer -import spark.util.ByteBufferInputStream +import spark.util.{MetadataCleaner, TimeStampedHashMap, ByteBufferInputStream} + import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import sun.nio.ch.DirectBuffer @@ -51,7 +51,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } - private val blockInfo = new ConcurrentHashMap[String, BlockInfo](1000) + private val blockInfo = new TimeStampedHashMap[String, BlockInfo]() private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) private[storage] val diskStore: BlockStore = @@ -80,6 +80,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m val host = System.getProperty("spark.hostname", Utils.localHostName()) + val metadataCleaner = new MetadataCleaner("BlockManager", this.dropOldBlocks) initialize() /** @@ -102,8 +103,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m * Get storage level of local block. If no info exists for the block, then returns null. */ def getLevel(blockId: String): StorageLevel = { - val info = blockInfo.get(blockId) - if (info != null) info.level else null + blockInfo.get(blockId).map(_.level).orNull } /** @@ -113,9 +113,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m */ def reportBlockStatus(blockId: String) { val (curLevel, inMemSize, onDiskSize) = blockInfo.get(blockId) match { - case null => + case None => (StorageLevel.NONE, 0L, 0L) - case info => + case Some(info) => info.synchronized { info.level match { case null => @@ -173,7 +173,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } - val info = blockInfo.get(blockId) + val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { info.waitForReady() // In case the block is still being put() by another thread @@ -258,7 +258,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } - val info = blockInfo.get(blockId) + val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { info.waitForReady() // In case the block is still being put() by another thread @@ -517,7 +517,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m throw new IllegalArgumentException("Storage level is null or invalid") } - val oldBlock = blockInfo.get(blockId) + val oldBlock = blockInfo.get(blockId).orNull if (oldBlock != null) { logWarning("Block " + blockId + " already exists on this machine; not re-adding it") oldBlock.waitForReady() @@ -618,7 +618,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m throw new IllegalArgumentException("Storage level is null or invalid") } - if (blockInfo.containsKey(blockId)) { + if (blockInfo.contains(blockId)) { logWarning("Block " + blockId + " already exists on this machine; not re-adding it") return } @@ -740,7 +740,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m */ def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) { logInfo("Dropping block " + blockId + " from memory") - val info = blockInfo.get(blockId) + val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { val level = info.level @@ -767,6 +767,29 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } + def dropOldBlocks(cleanupTime: Long) { + logInfo("Dropping blocks older than " + cleanupTime) + val iterator = blockInfo.internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2) + if (time < cleanupTime) { + info.synchronized { + val level = info.level + if (level.useMemory) { + memoryStore.remove(id) + } + if (level.useDisk) { + diskStore.remove(id) + } + iterator.remove() + logInfo("Dropped block " + id) + } + reportBlockStatus(id) + } + } + } + def shouldCompress(blockId: String): Boolean = { if (blockId.startsWith("shuffle_")) { compressShuffle diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 397395a65b..af15663621 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -341,6 +341,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor throw new Exception("Self index for " + blockManagerId + " not found") } + // Note that this logic will select the same node multiple times if there aren't enough peers var index = selfIndex while (res.size < size) { index += 1 -- cgit v1.2.3 From 6fcd09f499dca66d255aa7196839156433aae442 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 29 Nov 2012 02:06:33 -0800 Subject: Added TimeStampedHashSet and used that to cleanup the list of registered RDD IDs in CacheTracker. --- core/src/main/scala/spark/CacheTracker.scala | 10 +++- .../main/scala/spark/util/TimeStampedHashMap.scala | 14 +++-- .../main/scala/spark/util/TimeStampedHashSet.scala | 66 ++++++++++++++++++++++ 3 files changed, 81 insertions(+), 9 deletions(-) create mode 100644 core/src/main/scala/spark/util/TimeStampedHashSet.scala diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index 9888f061d9..cb54e12257 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -14,7 +14,7 @@ import scala.collection.mutable.HashSet import spark.storage.BlockManager import spark.storage.StorageLevel -import util.{MetadataCleaner, TimeStampedHashMap} +import util.{TimeStampedHashSet, MetadataCleaner, TimeStampedHashMap} private[spark] sealed trait CacheTrackerMessage @@ -39,7 +39,7 @@ private[spark] class CacheTrackerActor extends Actor with Logging { private val slaveCapacity = new HashMap[String, Long] private val slaveUsage = new HashMap[String, Long] - private val metadataCleaner = new MetadataCleaner("CacheTracker", locs.cleanup) + private val metadataCleaner = new MetadataCleaner("CacheTrackerActor", locs.cleanup) private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L) private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L) @@ -113,11 +113,15 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b actorSystem.actorFor(url) } - val registeredRddIds = new HashSet[Int] + // TODO: Consider removing this HashSet completely as locs CacheTrackerActor already + // keeps track of registered RDDs + val registeredRddIds = new TimeStampedHashSet[Int] // Remembers which splits are currently being loaded (on worker nodes) val loading = new HashSet[String] + val metadataCleaner = new MetadataCleaner("CacheTracker", registeredRddIds.cleanup) + // Send a message to the trackerActor and get its result within a default timeout, or // throw a SparkException if this fails. def askTracker(message: Any): Any = { diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala index 7a22b80a20..9bcc9245c0 100644 --- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala @@ -1,7 +1,7 @@ package spark.util -import scala.collection.JavaConversions._ -import scala.collection.mutable.{HashMap, Map} +import scala.collection.JavaConversions +import scala.collection.mutable.Map import java.util.concurrent.ConcurrentHashMap /** @@ -20,7 +20,7 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() { def iterator: Iterator[(A, B)] = { val jIterator = internalMap.entrySet().iterator() - jIterator.map(kv => (kv.getKey, kv.getValue._1)) + JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue._1)) } override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = { @@ -31,8 +31,10 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() { } override def - (key: A): Map[A, B] = { - internalMap.remove(key) - this + val newMap = new TimeStampedHashMap[A, B] + newMap.internalMap.putAll(this.internalMap) + newMap.internalMap.remove(key) + newMap } override def += (kv: (A, B)): this.type = { @@ -56,7 +58,7 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() { } override def filter(p: ((A, B)) => Boolean): Map[A, B] = { - internalMap.map(kv => (kv._1, kv._2._1)).filter(p) + JavaConversions.asScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p) } override def empty: Map[A, B] = new TimeStampedHashMap[A, B]() diff --git a/core/src/main/scala/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/spark/util/TimeStampedHashSet.scala new file mode 100644 index 0000000000..539dd75844 --- /dev/null +++ b/core/src/main/scala/spark/util/TimeStampedHashSet.scala @@ -0,0 +1,66 @@ +package spark.util + +import scala.collection.mutable.Set +import scala.collection.JavaConversions +import java.util.concurrent.ConcurrentHashMap + + +class TimeStampedHashSet[A] extends Set[A] { + val internalMap = new ConcurrentHashMap[A, Long]() + + def contains(key: A): Boolean = { + internalMap.contains(key) + } + + def iterator: Iterator[A] = { + val jIterator = internalMap.entrySet().iterator() + JavaConversions.asScalaIterator(jIterator).map(_.getKey) + } + + override def + (elem: A): Set[A] = { + val newSet = new TimeStampedHashSet[A] + newSet ++= this + newSet += elem + newSet + } + + override def - (elem: A): Set[A] = { + val newSet = new TimeStampedHashSet[A] + newSet ++= this + newSet -= elem + newSet + } + + override def += (key: A): this.type = { + internalMap.put(key, currentTime) + this + } + + override def -= (key: A): this.type = { + internalMap.remove(key) + this + } + + override def empty: Set[A] = new TimeStampedHashSet[A]() + + override def size(): Int = internalMap.size() + + override def foreach[U](f: (A) => U): Unit = { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + f(iterator.next.getKey) + } + } + + def cleanup(threshTime: Long) { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + if (entry.getValue < threshTime) { + iterator.remove() + } + } + } + + private def currentTime: Long = System.currentTimeMillis() +} -- cgit v1.2.3 From 62965c5d8e3f4f0246ac2c8814ac75ea82b3f238 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 1 Dec 2012 08:26:10 -0800 Subject: Added ssc.union --- streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala | 3 ++- streaming/src/main/scala/spark/streaming/StreamingContext.scala | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala index 8b484e6acf..bb852cbcca 100644 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -118,7 +118,8 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( if (seqOfValues(0).isEmpty) { // If previous window's reduce value does not exist, then at least new values should exist if (newValues.isEmpty) { - throw new Exception("Neither previous window has value for key, nor new values found") + val info = "seqOfValues =\n" + seqOfValues.map(x => "[" + x.mkString(",") + "]").mkString("\n") + throw new Exception("Neither previous window has value for key, nor new values found\n" + info) } // Reduce the new values newValues.reduce(reduceF) // return diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 90dd560752..63d8766749 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -189,6 +189,10 @@ class StreamingContext private ( inputStream } + def union[T: ClassManifest](streams: Seq[DStream[T]]): DStream[T] = { + new UnionDStream[T](streams.toArray) + } + /** * This function registers a InputDStream as an input stream that will be * started (InputDStream.start() called) to get the input data streams. -- cgit v1.2.3 From 477de94894b7d8eeed281d33c12bcb2269d117c7 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 1 Dec 2012 13:15:06 -0800 Subject: Minor modifications. --- core/src/main/scala/spark/util/MetadataCleaner.scala | 7 ++++++- streaming/src/main/scala/spark/streaming/DStream.scala | 15 ++++++++++++++- .../scala/spark/streaming/ReducedWindowedDStream.scala | 4 ++-- .../src/main/scala/spark/streaming/StreamingContext.scala | 8 ++++++-- 4 files changed, 28 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala index 71ac39864e..2541b26255 100644 --- a/core/src/main/scala/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -5,7 +5,7 @@ import java.util.{TimerTask, Timer} import spark.Logging class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { - val delaySeconds = (System.getProperty("spark.cleanup.delay", "-100").toDouble * 60).toInt + val delaySeconds = MetadataCleaner.getDelaySeconds val periodSeconds = math.max(10, delaySeconds / 10) val timer = new Timer(name + " cleanup timer", true) val task = new TimerTask { @@ -30,3 +30,8 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging timer.cancel() } } + +object MetadataCleaner { + def getDelaySeconds = (System.getProperty("spark.cleaner.delay", "-100").toDouble * 60).toInt + def setDelaySeconds(delay: Long) { System.setProperty("spark.cleaner.delay", delay.toString) } +} diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 8efda2074d..28a3e2dfc7 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -146,6 +146,8 @@ extends Serializable with Logging { } protected[streaming] def validate() { + assert(rememberDuration != null, "Remember duration is set to null") + assert( !mustCheckpoint || checkpointInterval != null, "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set. " + @@ -180,13 +182,24 @@ extends Serializable with Logging { checkpointInterval + "). Please set it to higher than " + checkpointInterval + "." ) + val metadataCleanupDelay = System.getProperty("spark.cleanup.delay", "-1").toDouble + assert( + metadataCleanupDelay < 0 || rememberDuration < metadataCleanupDelay * 60 * 1000, + "It seems you are doing some DStream window operation or setting a checkpoint interval " + + "which requires " + this.getClass.getSimpleName + " to remember generated RDDs for more " + + "than " + rememberDuration.milliseconds + " milliseconds. But the Spark's metadata cleanup" + + "delay is set to " + metadataCleanupDelay + " minutes, which is not sufficient. Please set " + + "the Java property 'spark.cleanup.delay' to more than " + + math.ceil(rememberDuration.millis.toDouble / 60000.0).toInt + " minutes." + ) + dependencies.foreach(_.validate()) logInfo("Slide time = " + slideTime) logInfo("Storage level = " + storageLevel) logInfo("Checkpoint interval = " + checkpointInterval) logInfo("Remember duration = " + rememberDuration) - logInfo("Initialized " + this) + logInfo("Initialized and validated " + this) } protected[streaming] def setContext(s: StreamingContext) { diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala index bb852cbcca..f63a9e0011 100644 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -118,8 +118,8 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( if (seqOfValues(0).isEmpty) { // If previous window's reduce value does not exist, then at least new values should exist if (newValues.isEmpty) { - val info = "seqOfValues =\n" + seqOfValues.map(x => "[" + x.mkString(",") + "]").mkString("\n") - throw new Exception("Neither previous window has value for key, nor new values found\n" + info) + throw new Exception("Neither previous window has value for key, nor new values found. " + + "Are you sure your key class hashes consistently?") } // Reduce the new values newValues.reduce(reduceF) // return diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 63d8766749..9c19f6588d 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -17,6 +17,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.hadoop.fs.Path import java.util.UUID +import spark.util.MetadataCleaner /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -268,8 +269,11 @@ class StreamingContext private ( object StreamingContext { def createNewSparkContext(master: String, frameworkName: String): SparkContext = { - if (System.getProperty("spark.cleanup.delay", "-1").toDouble < 0) { - System.setProperty("spark.cleanup.delay", "60") + + // Set the default cleaner delay to an hour if not already set. + // This should be sufficient for even 1 second interval. + if (MetadataCleaner.getDelaySeconds < 0) { + MetadataCleaner.setDelaySeconds(60) } new SparkContext(master, frameworkName) } -- cgit v1.2.3 From b4dba55f78b0dfda728cf69c9c17e4863010d28d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 2 Dec 2012 02:03:05 +0000 Subject: Made RDD checkpoint not create a new thread. Fixed bug in detecting when spark.cleaner.delay is insufficient. --- core/src/main/scala/spark/RDD.scala | 31 +++++++--------------- .../main/scala/spark/util/TimeStampedHashMap.scala | 3 ++- .../src/main/scala/spark/streaming/DStream.scala | 9 ++++--- 3 files changed, 17 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 8af6c9bd6a..fbfcfbd704 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -211,28 +211,17 @@ abstract class RDD[T: ClassManifest]( if (startCheckpoint) { val rdd = this - val env = SparkEnv.get - - // Spawn a new thread to do the checkpoint as it takes sometime to write the RDD to file - val th = new Thread() { - override def run() { - // Save the RDD to a file, create a new HadoopRDD from it, - // and change the dependencies from the original parents to the new RDD - SparkEnv.set(env) - rdd.checkpointFile = new Path(context.checkpointDir, "rdd-" + id).toString - rdd.saveAsObjectFile(checkpointFile) - rdd.synchronized { - rdd.checkpointRDD = context.objectFile[T](checkpointFile, rdd.splits.size) - rdd.checkpointRDDSplits = rdd.checkpointRDD.splits - rdd.changeDependencies(rdd.checkpointRDD) - rdd.shouldCheckpoint = false - rdd.isCheckpointInProgress = false - rdd.isCheckpointed = true - println("Done checkpointing RDD " + rdd.id + ", " + rdd) - } - } + rdd.checkpointFile = new Path(context.checkpointDir, "rdd-" + id).toString + rdd.saveAsObjectFile(checkpointFile) + rdd.synchronized { + rdd.checkpointRDD = context.objectFile[T](checkpointFile, rdd.splits.size) + rdd.checkpointRDDSplits = rdd.checkpointRDD.splits + rdd.changeDependencies(rdd.checkpointRDD) + rdd.shouldCheckpoint = false + rdd.isCheckpointInProgress = false + rdd.isCheckpointed = true + println("Done checkpointing RDD " + rdd.id + ", " + rdd + ", created RDD " + rdd.checkpointRDD.id + ", " + rdd.checkpointRDD) } - th.start() } else { // Recursively call doCheckpoint() to perform checkpointing on parent RDD if they are marked dependencies.foreach(_.rdd.doCheckpoint()) diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala index 9bcc9245c0..52f03784db 100644 --- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala @@ -10,7 +10,7 @@ import java.util.concurrent.ConcurrentHashMap * threshold time can them be removed using the cleanup method. This is intended to be a drop-in * replacement of scala.collection.mutable.HashMap. */ -class TimeStampedHashMap[A, B] extends Map[A, B]() { +class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging { val internalMap = new ConcurrentHashMap[A, (B, Long)]() def get(key: A): Option[B] = { @@ -79,6 +79,7 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() { while(iterator.hasNext) { val entry = iterator.next() if (entry.getValue._2 < threshTime) { + logDebug("Removing key " + entry.getKey) iterator.remove() } } diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 28a3e2dfc7..d2e9de110e 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -182,14 +182,15 @@ extends Serializable with Logging { checkpointInterval + "). Please set it to higher than " + checkpointInterval + "." ) - val metadataCleanupDelay = System.getProperty("spark.cleanup.delay", "-1").toDouble + val metadataCleanerDelay = spark.util.MetadataCleaner.getDelaySeconds + logInfo("metadataCleanupDelay = " + metadataCleanerDelay) assert( - metadataCleanupDelay < 0 || rememberDuration < metadataCleanupDelay * 60 * 1000, + metadataCleanerDelay < 0 || rememberDuration < metadataCleanerDelay * 1000, "It seems you are doing some DStream window operation or setting a checkpoint interval " + "which requires " + this.getClass.getSimpleName + " to remember generated RDDs for more " + "than " + rememberDuration.milliseconds + " milliseconds. But the Spark's metadata cleanup" + - "delay is set to " + metadataCleanupDelay + " minutes, which is not sufficient. Please set " + - "the Java property 'spark.cleanup.delay' to more than " + + "delay is set to " + (metadataCleanerDelay / 60.0) + " minutes, which is not sufficient. Please set " + + "the Java property 'spark.cleaner.delay' to more than " + math.ceil(rememberDuration.millis.toDouble / 60000.0).toInt + " minutes." ) -- cgit v1.2.3 From 609e00d599d3f429a838f598b3f32c5fdbd7ec5e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 2 Dec 2012 02:39:08 +0000 Subject: Minor mods --- streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index d0fef70f7e..ae6692290e 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -58,7 +58,7 @@ class NetworkInputTracker( throw new Exception("Register received for unexpected id " + streamId) } receiverInfo += ((streamId, receiverActor)) - logInfo("Registered receiver for network stream " + streamId) + logInfo("Registered receiver for network stream " + streamId + " from " + sender.path.address) sender ! true } case AddBlocks(streamId, blockIds) => { -- cgit v1.2.3 From a69a82be2682148f5d1ebbdede15a47c90eea73d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 3 Dec 2012 22:37:31 -0800 Subject: Added metadata cleaner to HttpBroadcast to clean up old broacast files. --- .../main/scala/spark/broadcast/HttpBroadcast.scala | 24 ++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala index 7eb4ddb74f..fef264aab1 100644 --- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala @@ -11,6 +11,7 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream import spark._ import spark.storage.StorageLevel +import util.{MetadataCleaner, TimeStampedHashSet} private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { @@ -64,6 +65,10 @@ private object HttpBroadcast extends Logging { private var serverUri: String = null private var server: HttpServer = null + private val files = new TimeStampedHashSet[String] + private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup) + + def initialize(isMaster: Boolean) { synchronized { if (!initialized) { @@ -85,6 +90,7 @@ private object HttpBroadcast extends Logging { server = null } initialized = false + cleaner.cancel() } } @@ -108,6 +114,7 @@ private object HttpBroadcast extends Logging { val serOut = ser.serializeStream(out) serOut.writeObject(value) serOut.close() + files += file.getAbsolutePath } def read[T](id: Long): T = { @@ -123,4 +130,21 @@ private object HttpBroadcast extends Logging { serIn.close() obj } + + def cleanup(cleanupTime: Long) { + val iterator = files.internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + val (file, time) = (entry.getKey, entry.getValue) + if (time < cleanupTime) { + try { + iterator.remove() + new File(file.toString).delete() + logInfo("Deleted broadcast file '" + file + "'") + } catch { + case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e) + } + } + } + } } -- cgit v1.2.3 From 21a08529768a5073bc5c15b6c2642ceef2acd0d5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 4 Dec 2012 22:10:25 -0800 Subject: Refactored RDD checkpointing to minimize extra fields in RDD class. --- core/src/main/scala/spark/RDD.scala | 149 ++++++++------------- core/src/main/scala/spark/RDDCheckpointData.scala | 68 ++++++++++ core/src/main/scala/spark/rdd/BlockRDD.scala | 9 +- core/src/main/scala/spark/rdd/CartesianRDD.scala | 10 +- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 2 +- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 2 +- core/src/main/scala/spark/rdd/HadoopRDD.scala | 2 - core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 2 - core/src/main/scala/spark/rdd/UnionRDD.scala | 12 +- .../main/scala/spark/scheduler/DAGScheduler.scala | 2 +- core/src/test/scala/spark/CheckpointSuite.scala | 73 +--------- .../src/main/scala/spark/streaming/DStream.scala | 7 +- 12 files changed, 144 insertions(+), 194 deletions(-) create mode 100644 core/src/main/scala/spark/RDDCheckpointData.scala diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index fbfcfbd704..e9bd131e61 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -99,13 +99,7 @@ abstract class RDD[T: ClassManifest]( val partitioner: Option[Partitioner] = None /** Optionally overridden by subclasses to specify placement preferences. */ - def preferredLocations(split: Split): Seq[String] = { - if (isCheckpointed) { - checkpointRDD.preferredLocations(split) - } else { - Nil - } - } + def preferredLocations(split: Split): Seq[String] = Nil /** The [[spark.SparkContext]] that this RDD was created on. */ def context = sc @@ -118,6 +112,8 @@ abstract class RDD[T: ClassManifest]( // Variables relating to persistence private var storageLevel: StorageLevel = StorageLevel.NONE + protected[spark] val checkpointData = new RDDCheckpointData(this) + /** Returns the first parent RDD */ protected[spark] def firstParent[U: ClassManifest] = { dependencies.head.rdd.asInstanceOf[RDD[U]] @@ -126,17 +122,6 @@ abstract class RDD[T: ClassManifest]( /** Returns the `i` th parent RDD */ protected[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]] - // Variables relating to checkpointing - protected val isCheckpointable = true // override to set this to false to avoid checkpointing an RDD - - protected var shouldCheckpoint = false // set to true when an RDD is marked for checkpointing - protected var isCheckpointInProgress = false // set to true when checkpointing is in progress - protected[spark] var isCheckpointed = false // set to true after checkpointing is completed - - protected[spark] var checkpointFile: String = null // set to the checkpoint file after checkpointing is completed - protected var checkpointRDD: RDD[T] = null // set to the HadoopRDD of the checkpoint file - protected var checkpointRDDSplits: Seq[Split] = null // set to the splits of the Hadoop RDD - // Methods available on all RDDs: /** @@ -162,83 +147,14 @@ abstract class RDD[T: ClassManifest]( /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel - /** - * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir` - * (set using setCheckpointDir()) and all references to its parent RDDs will be removed. - * This is used to truncate very long lineages. In the current implementation, Spark will save - * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. - * Hence, it is strongly recommended to use checkpoint() on RDDs when - * (i) Checkpoint() is called before the any job has been executed on this RDD. - * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will - * require recomputation. - */ - protected[spark] def checkpoint() { - synchronized { - if (isCheckpointed || shouldCheckpoint || isCheckpointInProgress) { - // do nothing - } else if (isCheckpointable) { - if (sc.checkpointDir == null) { - throw new Exception("Checkpoint directory has not been set in the SparkContext.") - } - shouldCheckpoint = true - } else { - throw new Exception(this + " cannot be checkpointed") - } - } - } - - def getCheckpointData(): Any = { - synchronized { - checkpointFile - } - } - - /** - * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler after a job - * using this RDD has completed (therefore the RDD has been materialized and - * potentially stored in memory). In case this RDD is not marked for checkpointing, - * doCheckpoint() is called recursively on the parent RDDs. - */ - private[spark] def doCheckpoint() { - val startCheckpoint = synchronized { - if (isCheckpointable && shouldCheckpoint && !isCheckpointInProgress) { - isCheckpointInProgress = true - true - } else { - false - } - } - - if (startCheckpoint) { - val rdd = this - rdd.checkpointFile = new Path(context.checkpointDir, "rdd-" + id).toString - rdd.saveAsObjectFile(checkpointFile) - rdd.synchronized { - rdd.checkpointRDD = context.objectFile[T](checkpointFile, rdd.splits.size) - rdd.checkpointRDDSplits = rdd.checkpointRDD.splits - rdd.changeDependencies(rdd.checkpointRDD) - rdd.shouldCheckpoint = false - rdd.isCheckpointInProgress = false - rdd.isCheckpointed = true - println("Done checkpointing RDD " + rdd.id + ", " + rdd + ", created RDD " + rdd.checkpointRDD.id + ", " + rdd.checkpointRDD) - } + def getPreferredLocations(split: Split) = { + if (isCheckpointed) { + checkpointData.preferredLocations(split) } else { - // Recursively call doCheckpoint() to perform checkpointing on parent RDD if they are marked - dependencies.foreach(_.rdd.doCheckpoint()) + preferredLocations(split) } } - /** - * Changes the dependencies of this RDD from its original parents to the new [[spark.rdd.HadoopRDD]] - * (`newRDD`) created from the checkpoint file. This method must ensure that all references - * to the original parent RDDs must be removed to enable the parent RDDs to be garbage - * collected. Subclasses of RDD may override this method for implementing their own changing - * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea. - */ - protected def changeDependencies(newRDD: RDD[_]) { - dependencies_ = List(new OneToOneDependency(newRDD)) - } - /** * Internal method to this RDD; will read from cache if applicable, or otherwise compute it. * This should ''not'' be called by users directly, but is available for implementors of custom @@ -247,7 +163,7 @@ abstract class RDD[T: ClassManifest]( final def iterator(split: Split): Iterator[T] = { if (isCheckpointed) { // ASSUMPTION: Checkpoint Hadoop RDD will have same number of splits as original - checkpointRDD.iterator(checkpointRDDSplits(split.index)) + checkpointData.iterator(split.index) } else if (storageLevel != StorageLevel.NONE) { SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel) } else { @@ -589,6 +505,55 @@ abstract class RDD[T: ClassManifest]( sc.runJob(this, (iter: Iterator[T]) => iter.toArray) } + /** + * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir` + * (set using setCheckpointDir()) and all references to its parent RDDs will be removed. + * This is used to truncate very long lineages. In the current implementation, Spark will save + * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. + * Hence, it is strongly recommended to use checkpoint() on RDDs when + * (i) checkpoint() is called before the any job has been executed on this RDD. + * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will + * require recomputation. + */ + def checkpoint() { + checkpointData.markForCheckpoint() + } + + /** + * Return whether this RDD has been checkpointed or not + */ + def isCheckpointed(): Boolean = { + checkpointData.isCheckpointed() + } + + /** + * Gets the name of the file to which this RDD was checkpointed + */ + def getCheckpointFile(): Option[String] = { + checkpointData.getCheckpointFile() + } + + /** + * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler + * after a job using this RDD has completed (therefore the RDD has been materialized and + * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs. + */ + protected[spark] def doCheckpoint() { + checkpointData.doCheckpoint() + dependencies.foreach(_.rdd.doCheckpoint()) + } + + /** + * Changes the dependencies of this RDD from its original parents to the new [[spark.rdd.HadoopRDD]] + * (`newRDD`) created from the checkpoint file. This method must ensure that all references + * to the original parent RDDs must be removed to enable the parent RDDs to be garbage + * collected. Subclasses of RDD may override this method for implementing their own changing + * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea. + */ + protected[spark] def changeDependencies(newRDD: RDD[_]) { + dependencies_ = List(new OneToOneDependency(newRDD)) + } + @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { synchronized { diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala new file mode 100644 index 0000000000..eb4482acee --- /dev/null +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -0,0 +1,68 @@ +package spark + +import org.apache.hadoop.fs.Path + + + +private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) +extends Serializable { + + class CheckpointState extends Serializable { + var state = 0 + + def mark() { if (state == 0) state = 1 } + def start() { assert(state == 1); state = 2 } + def finish() { assert(state == 2); state = 3 } + + def isMarked() = { state == 1 } + def isInProgress = { state == 2 } + def isCheckpointed = { state == 3 } + } + + val cpState = new CheckpointState() + var cpFile: Option[String] = None + var cpRDD: Option[RDD[T]] = None + var cpRDDSplits: Seq[Split] = Nil + + def markForCheckpoint() = { + rdd.synchronized { cpState.mark() } + } + + def isCheckpointed() = { + rdd.synchronized { cpState.isCheckpointed } + } + + def getCheckpointFile() = { + rdd.synchronized { cpFile } + } + + def doCheckpoint() { + rdd.synchronized { + if (cpState.isMarked && !cpState.isInProgress) { + cpState.start() + } else { + return + } + } + + val file = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString + rdd.saveAsObjectFile(file) + val newRDD = rdd.context.objectFile[T](file, rdd.splits.size) + + rdd.synchronized { + rdd.changeDependencies(newRDD) + cpFile = Some(file) + cpRDD = Some(newRDD) + cpRDDSplits = newRDD.splits + cpState.finish() + } + } + + def preferredLocations(split: Split) = { + cpRDD.get.preferredLocations(split) + } + + def iterator(splitIndex: Int): Iterator[T] = { + cpRDD.get.iterator(cpRDDSplits(splitIndex)) + } +} diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index f4c3f99011..590f9eb738 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -41,12 +41,7 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St } } - override def preferredLocations(split: Split) = { - if (isCheckpointed) { - checkpointRDD.preferredLocations(split) - } else { - locations_(split.asInstanceOf[BlockRDDSplit].blockId) - } - } + override def preferredLocations(split: Split) = + locations_(split.asInstanceOf[BlockRDDSplit].blockId) } diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 458ad38d55..9bfc3f8ca3 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -32,12 +32,8 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( override def splits = splits_ override def preferredLocations(split: Split) = { - if (isCheckpointed) { - checkpointRDD.preferredLocations(split) - } else { - val currSplit = split.asInstanceOf[CartesianSplit] - rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) - } + val currSplit = split.asInstanceOf[CartesianSplit] + rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) } override def compute(split: Split) = { @@ -56,7 +52,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( override def dependencies = deps_ - override protected def changeDependencies(newRDD: RDD[_]) { + override def changeDependencies(newRDD: RDD[_]) { deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) splits_ = newRDD.splits rdd1 = null diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 94ef1b56e8..adfecea966 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -112,7 +112,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) map.iterator } - override protected def changeDependencies(newRDD: RDD[_]) { + override def changeDependencies(newRDD: RDD[_]) { deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) splits_ = newRDD.splits rdds = null diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 5b5f72ddeb..90c3b8bfd8 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -48,7 +48,7 @@ class CoalescedRDD[T: ClassManifest]( override def dependencies = deps_ - override protected def changeDependencies(newRDD: RDD[_]) { + override def changeDependencies(newRDD: RDD[_]) { deps_ = List(new OneToOneDependency(newRDD)) splits_ = newRDD.splits prev = null diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index 19ed56d9c0..a12531ea89 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -115,6 +115,4 @@ class HadoopRDD[K, V]( val hadoopSplit = split.asInstanceOf[HadoopSplit] hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") } - - override val isCheckpointable = false } diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index 2875abb2db..c12df5839e 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -93,6 +93,4 @@ class NewHadoopRDD[K, V]( val theSplit = split.asInstanceOf[NewHadoopSplit] theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") } - - override val isCheckpointable = false } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index 643a174160..30eb8483b6 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -49,15 +49,11 @@ class UnionRDD[T: ClassManifest]( override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator() - override def preferredLocations(s: Split): Seq[String] = { - if (isCheckpointed) { - checkpointRDD.preferredLocations(s) - } else { - s.asInstanceOf[UnionSplit[T]].preferredLocations() - } - } + override def preferredLocations(s: Split): Seq[String] = + s.asInstanceOf[UnionSplit[T]].preferredLocations() + - override protected def changeDependencies(newRDD: RDD[_]) { + override def changeDependencies(newRDD: RDD[_]) { deps_ = List(new OneToOneDependency(newRDD)) splits_ = newRDD.splits rdds = null diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 4b2570fa2b..33d35b35d1 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -575,7 +575,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with return cached } // If the RDD has some placement preferences (as is the case for input RDDs), get those - val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList + val rddPrefs = rdd.getPreferredLocations(rdd.splits(partition)).toList if (rddPrefs != Nil) { return rddPrefs } diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 8622ce92aa..2cafef444c 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -41,7 +41,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(parCollection.dependencies === Nil) val result = parCollection.collect() sleep(parCollection) // slightly extra time as loading classes for the first can take some time - assert(sc.objectFile[Int](parCollection.checkpointFile).collect() === result) + assert(sc.objectFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) assert(parCollection.collect() === result) } @@ -54,7 +54,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { blockRDD.checkpoint() val result = blockRDD.collect() sleep(blockRDD) - assert(sc.objectFile[String](blockRDD.checkpointFile).collect() === result) + assert(sc.objectFile[String](blockRDD.getCheckpointFile.get).collect() === result) assert(blockRDD.dependencies != Nil) assert(blockRDD.collect() === result) } @@ -122,35 +122,6 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { "CoGroupedSplits still holds on to the splits of its parent RDDs") } - /** - * This test forces two ResultTasks of the same job to be launched before and after - * the checkpointing of job's RDD is completed. - */ - test("Threading - ResultTasks") { - val op1 = (parCollection: RDD[Int]) => { - parCollection.map(x => { println("1st map running on " + x); Thread.sleep(500); (x % 2, x) }) - } - val op2 = (firstRDD: RDD[(Int, Int)]) => { - firstRDD.map(x => { println("2nd map running on " + x); Thread.sleep(500); x }) - } - testThreading(op1, op2) - } - - /** - * This test forces two ShuffleMapTasks of the same job to be launched before and after - * the checkpointing of job's RDD is completed. - */ - test("Threading - ShuffleMapTasks") { - val op1 = (parCollection: RDD[Int]) => { - parCollection.map(x => { println("1st map running on " + x); Thread.sleep(500); (x % 2, x) }) - } - val op2 = (firstRDD: RDD[(Int, Int)]) => { - firstRDD.groupByKey(2).map(x => { println("2nd map running on " + x); Thread.sleep(500); x }) - } - testThreading(op1, op2) - } - - def testCheckpointing[U: ClassManifest](op: (RDD[Int]) => RDD[U], sleepTime: Long = 500) { val parCollection = sc.makeRDD(1 to 4, 4) val operatedRDD = op(parCollection) @@ -159,49 +130,11 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val result = operatedRDD.collect() sleep(operatedRDD) //println(parentRDD + ", " + operatedRDD.dependencies.head.rdd ) - assert(sc.objectFile[U](operatedRDD.checkpointFile).collect() === result) + assert(sc.objectFile[U](operatedRDD.getCheckpointFile.get).collect() === result) assert(operatedRDD.dependencies.head.rdd != parentRDD) assert(operatedRDD.collect() === result) } - def testThreading[U: ClassManifest, V: ClassManifest](op1: (RDD[Int]) => RDD[U], op2: (RDD[U]) => RDD[V]) { - - val parCollection = sc.makeRDD(1 to 2, 2) - - // This is the RDD that is to be checkpointed - val firstRDD = op1(parCollection) - val parentRDD = firstRDD.dependencies.head.rdd - firstRDD.checkpoint() - - // This the RDD that uses firstRDD. This is designed to launch a - // ShuffleMapTask that uses firstRDD. - val secondRDD = op2(firstRDD) - - // Starting first job, to initiate the checkpointing - logInfo("\nLaunching 1st job to initiate checkpointing\n") - firstRDD.collect() - - // Checkpointing has started but not completed yet - Thread.sleep(100) - assert(firstRDD.dependencies.head.rdd === parentRDD) - - // Starting second job; first task of this job will be - // launched _before_ firstRDD is marked as checkpointed - // and the second task will be launched _after_ firstRDD - // is marked as checkpointed - logInfo("\nLaunching 2nd job that is designed to launch tasks " + - "before and after checkpointing is complete\n") - val result = secondRDD.collect() - - // Check whether firstRDD has been successfully checkpointed - assert(firstRDD.dependencies.head.rdd != parentRDD) - - logInfo("\nRecomputing 2nd job to verify the results of the previous computation\n") - // Check whether the result in the previous job was correct or not - val correctResult = secondRDD.collect() - assert(result === correctResult) - } - def sleep(rdd: RDD[_]) { val startTime = System.currentTimeMillis() val maxWaitTime = 5000 diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index d2e9de110e..d290c5927e 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -325,8 +325,9 @@ extends Serializable with Logging { logInfo("Updating checkpoint data for time " + currentTime) // Get the checkpointed RDDs from the generated RDDs - val newCheckpointData = generatedRDDs.filter(_._2.getCheckpointData() != null) - .map(x => (x._1, x._2.getCheckpointData())) + + val newCheckpointData = generatedRDDs.filter(_._2.getCheckpointFile.isDefined) + .map(x => (x._1, x._2.getCheckpointFile.get)) // Make a copy of the existing checkpoint data val oldCheckpointData = checkpointData.clone() @@ -373,7 +374,7 @@ extends Serializable with Logging { logInfo("Restoring checkpointed RDD for time " + time + " from file '" + data.toString + "'") val rdd = ssc.sc.objectFile[T](data.toString) // Set the checkpoint file name to identify this RDD as a checkpointed RDD by updateCheckpointData() - rdd.checkpointFile = data.toString + rdd.checkpointData.cpFile = Some(data.toString) generatedRDDs += ((time, rdd)) } } -- cgit v1.2.3 From a23462191f9ad492d14f9efc3e915b1f522f543a Mon Sep 17 00:00:00 2001 From: Denny Date: Wed, 5 Dec 2012 10:30:40 -0800 Subject: Adjust Kafka code to work with new streaming changes. --- streaming/src/main/scala/spark/streaming/DStream.scala | 4 ++-- .../src/main/scala/spark/streaming/examples/KafkaWordCount.scala | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 85106b3ad8..792c129be8 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -365,8 +365,8 @@ extends Serializable with Logging { } } } - logInfo("Updated checkpoint data for time " + currentTime + ", " + checkpointData.size + " checkpoints, " - + "[" + checkpointData.mkString(",") + "]") + logInfo("Updated checkpoint data for time " + currentTime + ", " + checkpointData.rdds.size + " checkpoints, " + + "[" + checkpointData.rdds.mkString(",") + "]") } /** diff --git a/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala index 12e3f49fe9..fe55db6e2c 100644 --- a/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala +++ b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala @@ -4,6 +4,7 @@ import java.util.Properties import kafka.message.Message import kafka.producer.SyncProducerConfig import kafka.producer._ +import spark.SparkContext import spark.streaming._ import spark.streaming.StreamingContext._ import spark.storage.StorageLevel @@ -19,9 +20,9 @@ object KafkaWordCount { val Array(master, hostname, port, group, topics, numThreads) = args - val ssc = new StreamingContext(master, "KafkaWordCount") + val sc = new SparkContext(master, "KafkaWordCount") + val ssc = new StreamingContext(sc, Seconds(2)) ssc.checkpoint("checkpoint") - ssc.setBatchDuration(Seconds(2)) val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap val lines = ssc.kafkaStream[String](hostname, port.toInt, group, topicpMap) -- cgit v1.2.3 From 556c38ed91a405e0665897873e025e94971226af Mon Sep 17 00:00:00 2001 From: Denny Date: Wed, 5 Dec 2012 11:54:42 -0800 Subject: Added kafka JAR --- project/SparkBuild.scala | 2 +- streaming/lib/kafka-0.7.2.jar | Bin 0 -> 1358063 bytes 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 streaming/lib/kafka-0.7.2.jar diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f34736b1c4..6ef2ac477a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -158,7 +158,7 @@ object SparkBuild extends Build { def streamingSettings = sharedSettings ++ Seq( name := "spark-streaming", libraryDependencies ++= Seq( - "kafka" % "core-kafka_2.9.1" % "0.7.2") + "com.github.sgroschupf" % "zkclient" % "0.1") ) ++ assemblySettings ++ extraAssemblySettings def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq( diff --git a/streaming/lib/kafka-0.7.2.jar b/streaming/lib/kafka-0.7.2.jar new file mode 100644 index 0000000000..65f79925a4 Binary files /dev/null and b/streaming/lib/kafka-0.7.2.jar differ -- cgit v1.2.3 From 1f3a75ae9e518c003d84fa38a54583ecd841ffdc Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 7 Dec 2012 13:45:52 -0800 Subject: Modified checkpoint testsuite to more comprehensively test checkpointing of various RDDs. Fixed checkpoint bug (splits referring to parent RDDs or parent splits) in UnionRDD and CoalescedRDD. Fixed bug in testing ShuffledRDD. Removed unnecessary and useless map-side combining step for narrow dependencies in CoGroupedRDD. Removed unncessary WeakReference stuff from many other RDDs. --- core/src/main/scala/spark/rdd/CartesianRDD.scala | 1 - core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 9 +- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 25 +- core/src/main/scala/spark/rdd/FilteredRDD.scala | 6 +- core/src/main/scala/spark/rdd/FlatMappedRDD.scala | 6 +- core/src/main/scala/spark/rdd/GlommedRDD.scala | 6 +- .../main/scala/spark/rdd/MapPartitionsRDD.scala | 6 +- .../spark/rdd/MapPartitionsWithSplitRDD.scala | 6 +- core/src/main/scala/spark/rdd/MappedRDD.scala | 6 +- core/src/main/scala/spark/rdd/PipedRDD.scala | 10 +- core/src/main/scala/spark/rdd/SampledRDD.scala | 12 +- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 16 +- core/src/main/scala/spark/rdd/UnionRDD.scala | 19 +- core/src/test/scala/spark/CheckpointSuite.scala | 267 +++++++++++++++++---- 14 files changed, 285 insertions(+), 110 deletions(-) diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 9bfc3f8ca3..1d753a5168 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -1,7 +1,6 @@ package spark.rdd import spark._ -import java.lang.ref.WeakReference private[spark] class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable { diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index adfecea966..57d472666b 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -17,6 +17,7 @@ import java.io.{ObjectOutputStream, IOException} private[spark] sealed trait CoGroupSplitDep extends Serializable private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], splitIndex: Int, var split: Split = null) extends CoGroupSplitDep { + @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { rdd.synchronized { @@ -50,12 +51,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) var deps_ = { val deps = new ArrayBuffer[Dependency[_]] for ((rdd, index) <- rdds.zipWithIndex) { - val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true) - if (mapSideCombinedRDD.partitioner == Some(part)) { - logInfo("Adding one-to-one dependency with " + mapSideCombinedRDD) - deps += new OneToOneDependency(mapSideCombinedRDD) + if (rdd.partitioner == Some(part)) { + logInfo("Adding one-to-one dependency with " + rdd) + deps += new OneToOneDependency(rdd) } else { logInfo("Adding shuffle dependency with " + rdd) + val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true) deps += new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part) } } diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 90c3b8bfd8..0b4499e2eb 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -1,9 +1,24 @@ package spark.rdd import spark._ -import java.lang.ref.WeakReference +import java.io.{ObjectOutputStream, IOException} -private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split +private[spark] case class CoalescedRDDSplit( + index: Int, + @transient rdd: RDD[_], + parentsIndices: Array[Int] + ) extends Split { + var parents: Seq[Split] = parentsIndices.map(rdd.splits(_)) + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + rdd.synchronized { + // Update the reference to parent split at the time of task serialization + parents = parentsIndices.map(rdd.splits(_)) + oos.defaultWriteObject() + } + } +} /** * Coalesce the partitions of a parent RDD (`prev`) into fewer partitions, so that each partition of @@ -21,12 +36,12 @@ class CoalescedRDD[T: ClassManifest]( @transient var splits_ : Array[Split] = { val prevSplits = prev.splits if (prevSplits.length < maxPartitions) { - prevSplits.zipWithIndex.map{ case (s, idx) => new CoalescedRDDSplit(idx, Array(s)) } + prevSplits.map(_.index).map{idx => new CoalescedRDDSplit(idx, prev, Array(idx)) } } else { (0 until maxPartitions).map { i => val rangeStart = (i * prevSplits.length) / maxPartitions val rangeEnd = ((i + 1) * prevSplits.length) / maxPartitions - new CoalescedRDDSplit(i, prevSplits.slice(rangeStart, rangeEnd)) + new CoalescedRDDSplit(i, prev, (rangeStart until rangeEnd).toArray) }.toArray } } @@ -42,7 +57,7 @@ class CoalescedRDD[T: ClassManifest]( var deps_ : List[Dependency[_]] = List( new NarrowDependency(prev) { def getParents(id: Int): Seq[Int] = - splits(id).asInstanceOf[CoalescedRDDSplit].parents.map(_.index) + splits(id).asInstanceOf[CoalescedRDDSplit].parentsIndices } ) diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala index 1370cf6faf..02f2e7c246 100644 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala @@ -1,15 +1,13 @@ package spark.rdd -import spark.OneToOneDependency import spark.RDD import spark.Split -import java.lang.ref.WeakReference private[spark] class FilteredRDD[T: ClassManifest]( - prev: WeakReference[RDD[T]], + prev: RDD[T], f: T => Boolean) - extends RDD[T](prev.get) { + extends RDD[T](prev) { override def splits = firstParent[T].splits override def compute(split: Split) = firstParent[T].iterator(split).filter(f) diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala index 6b2cc67568..cdc8ecdcfe 100644 --- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala +++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala @@ -1,15 +1,13 @@ package spark.rdd -import spark.OneToOneDependency import spark.RDD import spark.Split -import java.lang.ref.WeakReference private[spark] class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( - prev: WeakReference[RDD[T]], + prev: RDD[T], f: T => TraversableOnce[U]) - extends RDD[U](prev.get) { + extends RDD[U](prev) { override def splits = firstParent[T].splits override def compute(split: Split) = firstParent[T].iterator(split).flatMap(f) diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala index 0f0b6ab0ff..df6f61c69d 100644 --- a/core/src/main/scala/spark/rdd/GlommedRDD.scala +++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala @@ -1,13 +1,11 @@ package spark.rdd -import spark.OneToOneDependency import spark.RDD import spark.Split -import java.lang.ref.WeakReference private[spark] -class GlommedRDD[T: ClassManifest](prev: WeakReference[RDD[T]]) - extends RDD[Array[T]](prev.get) { +class GlommedRDD[T: ClassManifest](prev: RDD[T]) + extends RDD[Array[T]](prev) { override def splits = firstParent[T].splits override def compute(split: Split) = Array(firstParent[T].iterator(split).toArray).iterator } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala index b04f56cfcc..23b9fb023b 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala @@ -1,16 +1,14 @@ package spark.rdd -import spark.OneToOneDependency import spark.RDD import spark.Split -import java.lang.ref.WeakReference private[spark] class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( - prev: WeakReference[RDD[T]], + prev: RDD[T], f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false) - extends RDD[U](prev.get) { + extends RDD[U](prev) { override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala index 7a4b6ffb03..41955c1d7a 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala @@ -1,9 +1,7 @@ package spark.rdd -import spark.OneToOneDependency import spark.RDD import spark.Split -import java.lang.ref.WeakReference /** * A variant of the MapPartitionsRDD that passes the split index into the @@ -12,9 +10,9 @@ import java.lang.ref.WeakReference */ private[spark] class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest]( - prev: WeakReference[RDD[T]], + prev: RDD[T], f: (Int, Iterator[T]) => Iterator[U]) - extends RDD[U](prev.get) { + extends RDD[U](prev) { override def splits = firstParent[T].splits override def compute(split: Split) = f(split.index, firstParent[T].iterator(split)) diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala index 8fa1872e0a..6f8cb21fd3 100644 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/spark/rdd/MappedRDD.scala @@ -1,15 +1,13 @@ package spark.rdd -import spark.OneToOneDependency import spark.RDD import spark.Split -import java.lang.ref.WeakReference private[spark] class MappedRDD[U: ClassManifest, T: ClassManifest]( - prev: WeakReference[RDD[T]], + prev: RDD[T], f: T => U) - extends RDD[U](prev.get) { + extends RDD[U](prev) { override def splits = firstParent[T].splits override def compute(split: Split) = firstParent[T].iterator(split).map(f) diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index d9293a9d1a..d2047375ea 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -8,11 +8,9 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import scala.io.Source -import spark.OneToOneDependency import spark.RDD import spark.SparkEnv import spark.Split -import java.lang.ref.WeakReference /** @@ -20,16 +18,16 @@ import java.lang.ref.WeakReference * (printing them one per line) and returns the output as a collection of strings. */ class PipedRDD[T: ClassManifest]( - prev: WeakReference[RDD[T]], + prev: RDD[T], command: Seq[String], envVars: Map[String, String]) - extends RDD[String](prev.get) { + extends RDD[String](prev) { - def this(prev: WeakReference[RDD[T]], command: Seq[String]) = this(prev, command, Map()) + def this(prev: RDD[T], command: Seq[String]) = this(prev, command, Map()) // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(prev: WeakReference[RDD[T]], command: String) = this(prev, PipedRDD.tokenize(command)) + def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command)) override def splits = firstParent[T].splits diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index f273f257f8..c622e14a66 100644 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -7,7 +7,6 @@ import cern.jet.random.engine.DRand import spark.RDD import spark.OneToOneDependency import spark.Split -import java.lang.ref.WeakReference private[spark] class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable { @@ -15,14 +14,14 @@ class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Seriali } class SampledRDD[T: ClassManifest]( - prev: WeakReference[RDD[T]], + prev: RDD[T], withReplacement: Boolean, frac: Double, seed: Int) - extends RDD[T](prev.get) { + extends RDD[T](prev) { @transient - val splits_ = { + var splits_ : Array[Split] = { val rg = new Random(seed) firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt)) } @@ -51,4 +50,9 @@ class SampledRDD[T: ClassManifest]( firstParent[T].iterator(split.prev).filter(x => (rand.nextDouble <= frac)) } } + + override def changeDependencies(newRDD: RDD[_]) { + dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) + splits_ = newRDD.splits + } } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 31774585f4..a9dd3f35ed 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -1,11 +1,7 @@ package spark.rdd -import spark.Partitioner -import spark.RDD -import spark.ShuffleDependency -import spark.SparkEnv -import spark.Split -import java.lang.ref.WeakReference +import spark._ +import scala.Some private[spark] class ShuffledRDDSplit(val idx: Int) extends Split { override val index = idx @@ -14,15 +10,15 @@ private[spark] class ShuffledRDDSplit(val idx: Int) extends Split { /** * The resulting RDD from a shuffle (e.g. repartitioning of data). - * @param parent the parent RDD. + * @param prev the parent RDD. * @param part the partitioner used to partition the RDD * @tparam K the key class. * @tparam V the value class. */ class ShuffledRDD[K, V]( - @transient prev: WeakReference[RDD[(K, V)]], + prev: RDD[(K, V)], part: Partitioner) - extends RDD[(K, V)](prev.get.context, List(new ShuffleDependency(prev.get, part))) { + extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part))) { override val partitioner = Some(part) @@ -37,7 +33,7 @@ class ShuffledRDD[K, V]( } override def changeDependencies(newRDD: RDD[_]) { - dependencies_ = Nil + dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) splits_ = newRDD.splits } } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index 30eb8483b6..a5948dd1f1 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -3,18 +3,28 @@ package spark.rdd import scala.collection.mutable.ArrayBuffer import spark._ -import java.lang.ref.WeakReference +import java.io.{ObjectOutputStream, IOException} private[spark] class UnionSplit[T: ClassManifest]( - idx: Int, + idx: Int, rdd: RDD[T], - split: Split) + splitIndex: Int, + var split: Split = null) extends Split with Serializable { def iterator() = rdd.iterator(split) def preferredLocations() = rdd.preferredLocations(split) override val index: Int = idx + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + rdd.synchronized { + // Update the reference to parent split at the time of task serialization + split = rdd.splits(splitIndex) + oos.defaultWriteObject() + } + } } class UnionRDD[T: ClassManifest]( @@ -27,7 +37,7 @@ class UnionRDD[T: ClassManifest]( val array = new Array[Split](rdds.map(_.splits.size).sum) var pos = 0 for (rdd <- rdds; split <- rdd.splits) { - array(pos) = new UnionSplit(pos, rdd, split) + array(pos) = new UnionSplit(pos, rdd, split.index) pos += 1 } array @@ -52,7 +62,6 @@ class UnionRDD[T: ClassManifest]( override def preferredLocations(s: Split): Seq[String] = s.asInstanceOf[UnionSplit[T]].preferredLocations() - override def changeDependencies(newRDD: RDD[_]) { deps_ = List(new OneToOneDependency(newRDD)) splits_ = newRDD.splits diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 2cafef444c..51bd59e2b1 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -2,17 +2,16 @@ package spark import org.scalatest.{BeforeAndAfter, FunSuite} import java.io.File -import rdd.{BlockRDD, CoalescedRDD, MapPartitionsWithSplitRDD} +import spark.rdd._ import spark.SparkContext._ import storage.StorageLevel -import java.util.concurrent.Semaphore -import collection.mutable.ArrayBuffer class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { initLogging() var sc: SparkContext = _ var checkpointDir: File = _ + val partitioner = new HashPartitioner(2) before { checkpointDir = File.createTempFile("temp", "") @@ -40,7 +39,6 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { parCollection.checkpoint() assert(parCollection.dependencies === Nil) val result = parCollection.collect() - sleep(parCollection) // slightly extra time as loading classes for the first can take some time assert(sc.objectFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) assert(parCollection.collect() === result) @@ -53,7 +51,6 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val blockRDD = new BlockRDD[String](sc, Array(blockId)) blockRDD.checkpoint() val result = blockRDD.collect() - sleep(blockRDD) assert(sc.objectFile[String](blockRDD.getCheckpointFile.get).collect() === result) assert(blockRDD.dependencies != Nil) assert(blockRDD.collect() === result) @@ -68,79 +65,247 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { testCheckpointing(_.mapPartitions(_.map(_.toString))) testCheckpointing(r => new MapPartitionsWithSplitRDD(r, (i: Int, iter: Iterator[Int]) => iter.map(_.toString) )) - testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString), 1000) - testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x), 1000) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) testCheckpointing(_.pipe(Seq("cat"))) } test("ShuffledRDD") { - testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _)) + // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD + testCheckpointing(rdd => { + new ShuffledRDD(rdd.map(x => (x % 2, 1)), partitioner) + }) } test("UnionRDD") { - testCheckpointing(_.union(sc.makeRDD(5 to 6, 4))) + def otherRDD = sc.makeRDD(1 to 10, 4) + testCheckpointing(_.union(otherRDD), false, true) + testParentCheckpointing(_.union(otherRDD), false, true) } test("CartesianRDD") { - testCheckpointing(_.cartesian(sc.makeRDD(5 to 6, 4)), 1000) + def otherRDD = sc.makeRDD(1 to 10, 4) + testCheckpointing(_.cartesian(otherRDD)) + testParentCheckpointing(_.cartesian(otherRDD), true, false) } test("CoalescedRDD") { testCheckpointing(new CoalescedRDD(_, 2)) + + // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed + // Current implementation of CoalescedRDDSplit has transient reference to parent RDD, + // so does not serialize the RDD (not need to check its size). + testParentCheckpointing(new CoalescedRDD(_, 2), true, false) + + // Test that the CoalescedRDDSplit updates parent splits (CoalescedRDDSplit.parents) after + // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits. + // Note that this test is very specific to the current implementation of CoalescedRDDSplits + val ones = sc.makeRDD(1 to 100, 10).map(x => x) + ones.checkpoint // checkpoint that MappedRDD + val coalesced = new CoalescedRDD(ones, 2) + val splitBeforeCheckpoint = + serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit]) + coalesced.count() // do the checkpointing + val splitAfterCheckpoint = + serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit]) + assert( + splitAfterCheckpoint.parents.head != splitBeforeCheckpoint.parents.head, + "CoalescedRDDSplit.parents not updated after parent RDD checkpointed" + ) } test("CoGroupedRDD") { - val rdd2 = sc.makeRDD(5 to 6, 4).map(x => (x % 2, 1)) - testCheckpointing(rdd1 => rdd1.map(x => (x % 2, 1)).cogroup(rdd2)) - testCheckpointing(rdd1 => rdd1.map(x => (x % 2, x)).join(rdd2)) + // Test serialized size + // RDD with long lineage of one-to-one dependencies through cogroup transformations + val longLineageRDD1 = generateLongLineageRDDForCoGroupedRDD() + testCheckpointing(rdd1 => { + CheckpointSuite.cogroup(longLineageRDD1, rdd1.map(x => (x % 2, 1)), partitioner) + }, false, true) - // Special test to make sure that the CoGroupSplit of CoGroupedRDD do not - // hold on to the splits of its parent RDDs, as the splits of parent RDDs - // may change while checkpointing. Rather the splits of parent RDDs must - // be fetched at the time of serialization to ensure the latest splits to - // be sent along with the task. + val longLineageRDD2 = generateLongLineageRDDForCoGroupedRDD() + testParentCheckpointing(rdd1 => { + CheckpointSuite.cogroup(longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner) + }, false, true) + } - val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _) + /** + * Test checkpointing of the final RDD generated by the given operation. By default, + * this method tests whether the size of serialized RDD has reduced after checkpointing or not. + * It can also test whether the size of serialized RDD splits has reduced after checkpointing or + * not, but this is not done by default as usually the splits do not refer to any RDD and + * therefore never store the lineage. + */ + def testCheckpointing[U: ClassManifest]( + op: (RDD[Int]) => RDD[U], + testRDDSize: Boolean = true, + testRDDSplitSize: Boolean = false + ) { + // Generate the final RDD using given RDD operation + val baseRDD = generateLongLineageRDD + val operatedRDD = op(baseRDD) + val parentRDD = operatedRDD.dependencies.headOption.orNull + val rddType = operatedRDD.getClass.getSimpleName - val ones = sc.parallelize(1 to 100, 1).map(x => (x,1)) - val reduced = ones.reduceByKey(_ + _) - val seqOfCogrouped = new ArrayBuffer[RDD[(Int, Int)]]() - seqOfCogrouped += reduced.cogroup(ones).mapValues[Int](add) - for(i <- 1 to 10) { - seqOfCogrouped += seqOfCogrouped.last.cogroup(ones).mapValues(add) - } - val finalCogrouped = seqOfCogrouped.last - val intermediateCogrouped = seqOfCogrouped(5) - - val bytesBeforeCheckpoint = Utils.serialize(finalCogrouped.splits) - intermediateCogrouped.checkpoint() - finalCogrouped.count() - sleep(intermediateCogrouped) - val bytesAfterCheckpoint = Utils.serialize(finalCogrouped.splits) - println("Before = " + bytesBeforeCheckpoint.size + ", after = " + bytesAfterCheckpoint.size) - assert(bytesAfterCheckpoint.size < bytesBeforeCheckpoint.size, - "CoGroupedSplits still holds on to the splits of its parent RDDs") - } - - def testCheckpointing[U: ClassManifest](op: (RDD[Int]) => RDD[U], sleepTime: Long = 500) { - val parCollection = sc.makeRDD(1 to 4, 4) - val operatedRDD = op(parCollection) + // Find serialized sizes before and after the checkpoint + val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) operatedRDD.checkpoint() - val parentRDD = operatedRDD.dependencies.head.rdd val result = operatedRDD.collect() - sleep(operatedRDD) - //println(parentRDD + ", " + operatedRDD.dependencies.head.rdd ) + val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + + // Test whether the checkpoint file has been created assert(sc.objectFile[U](operatedRDD.getCheckpointFile.get).collect() === result) + + // Test whether dependencies have been changed from its earlier parent RDD assert(operatedRDD.dependencies.head.rdd != parentRDD) + + // Test whether the splits have been changed to the new Hadoop splits + assert(operatedRDD.splits.toList === operatedRDD.checkpointData.cpRDDSplits.toList) + + // Test whether the data in the checkpointed RDD is same as original assert(operatedRDD.collect() === result) + + // Test whether serialized size of the RDD has reduced. If the RDD + // does not have any dependency to another RDD (e.g., ParallelCollection, + // ShuffleRDD with ShuffleDependency), it may not reduce in size after checkpointing. + if (testRDDSize) { + println("Size of " + rddType + + "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") + assert( + rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, + "Size of " + rddType + " did not reduce after checkpointing " + + "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" + ) + } + + // Test whether serialized size of the splits has reduced. If the splits + // do not have any non-transient reference to another RDD or another RDD's splits, it + // does not refer to a lineage and therefore may not reduce in size after checkpointing. + // However, if the original splits before checkpointing do refer to a parent RDD, the splits + // must be forgotten after checkpointing (to remove all reference to parent RDDs) and + // replaced with the HadoopSplits of the checkpointed RDD. + if (testRDDSplitSize) { + println("Size of " + rddType + " splits " + + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]") + assert( + splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint, + "Size of " + rddType + " splits did not reduce after checkpointing " + + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]" + ) + } } - def sleep(rdd: RDD[_]) { - val startTime = System.currentTimeMillis() - val maxWaitTime = 5000 - while(rdd.isCheckpointed == false && System.currentTimeMillis() < startTime + maxWaitTime) { - Thread.sleep(50) + /** + * Test whether checkpointing of the parent of the generated RDD also + * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent + * RDDs splits. So even if the parent RDD is checkpointed and its splits changed, + * this RDD will remember the splits and therefore potentially the whole lineage. + */ + def testParentCheckpointing[U: ClassManifest]( + op: (RDD[Int]) => RDD[U], + testRDDSize: Boolean, + testRDDSplitSize: Boolean + ) { + // Generate the final RDD using given RDD operation + val baseRDD = generateLongLineageRDD + val operatedRDD = op(baseRDD) + val parentRDD = operatedRDD.dependencies.head.rdd + val rddType = operatedRDD.getClass.getSimpleName + val parentRDDType = parentRDD.getClass.getSimpleName + + // Find serialized sizes before and after the checkpoint + val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) + parentRDD.checkpoint() // checkpoint the parent RDD, not the generated one + val result = operatedRDD.collect() + val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + + // Test whether the data in the checkpointed RDD is same as original + assert(operatedRDD.collect() === result) + + // Test whether serialized size of the RDD has reduced because of its parent being + // checkpointed. If this RDD or its parent RDD do not have any dependency + // to another RDD (e.g., ParallelCollection, ShuffleRDD with ShuffleDependency), it may + // not reduce in size after checkpointing. + if (testRDDSize) { + assert( + rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, + "Size of " + rddType + " did not reduce after parent checkpointing parent " + parentRDDType + + "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" + ) + } + + // Test whether serialized size of the splits has reduced because of its parent being + // checkpointed. If the splits do not have any non-transient reference to another RDD + // or another RDD's splits, it does not refer to a lineage and therefore may not reduce + // in size after checkpointing. However, if the splits do refer to the *splits* of a parent + // RDD, then these splits must update reference to the parent RDD splits as the parent RDD's + // splits must have changed after checkpointing. + if (testRDDSplitSize) { + assert( + splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint, + "Size of " + rddType + " splits did not reduce after checkpointing parent " + parentRDDType + + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]" + ) } - assert(rdd.isCheckpointed === true, "Waiting for checkpoint to complete took more than " + maxWaitTime + " ms") + + } + + /** + * Generate an RDD with a long lineage of one-to-one dependencies. + */ + def generateLongLineageRDD(): RDD[Int] = { + var rdd = sc.makeRDD(1 to 100, 4) + for (i <- 1 to 20) { + rdd = rdd.map(x => x) + } + rdd + } + + /** + * Generate an RDD with a long lineage specifically for CoGroupedRDD. + * A CoGroupedRDD can have a long lineage only one of its parents have a long lineage + * and narrow dependency with this RDD. This method generate such an RDD by a sequence + * of cogroups and mapValues which creates a long lineage of narrow dependencies. + */ + def generateLongLineageRDDForCoGroupedRDD() = { + val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _) + + def ones: RDD[(Int, Int)] = sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _) + + var cogrouped: RDD[(Int, (Seq[Int], Seq[Int]))] = ones.cogroup(ones) + for(i <- 1 to 10) { + cogrouped = cogrouped.mapValues(add).cogroup(ones) + } + cogrouped.mapValues(add) + } + + /** + * Get serialized sizes of the RDD and its splits + */ + def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { + (Utils.serialize(rdd).size, Utils.serialize(rdd.splits).size) + } + + /** + * Serialize and deserialize an object. This is useful to verify the objects + * contents after deserialization (e.g., the contents of an RDD split after + * it is sent to a slave along with a task) + */ + def serializeDeserialize[T](obj: T): T = { + val bytes = Utils.serialize(obj) + Utils.deserialize[T](bytes) } } + + +object CheckpointSuite { + // This is a custom cogroup function that does not use mapValues like + // the PairRDDFunctions.cogroup() + def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = { + println("First = " + first + ", second = " + second) + new CoGroupedRDD[K]( + Seq(first.asInstanceOf[RDD[(_, _)]], second.asInstanceOf[RDD[(_, _)]]), + part + ).asInstanceOf[RDD[(K, Seq[Seq[V]])]] + } + +} \ No newline at end of file -- cgit v1.2.3 From c36ca10241991d46f2f1513b2c0c5e369d8b34f9 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sat, 3 Nov 2012 15:19:41 -0700 Subject: Adding locality aware parallelize --- core/src/main/scala/spark/ParallelCollection.scala | 11 +++++++++-- core/src/main/scala/spark/SparkContext.scala | 10 +++++++++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index 9725017b61..4bd9e1bd54 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -2,6 +2,7 @@ package spark import scala.collection.immutable.NumericRange import scala.collection.mutable.ArrayBuffer +import scala.collection.Map private[spark] class ParallelCollectionSplit[T: ClassManifest]( val rddId: Long, @@ -24,7 +25,8 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest]( private[spark] class ParallelCollection[T: ClassManifest]( @transient sc : SparkContext, @transient data: Seq[T], - numSlices: Int) + numSlices: Int, + locationPrefs : Map[Int,Seq[String]]) extends RDD[T](sc, Nil) { // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split @@ -40,7 +42,12 @@ private[spark] class ParallelCollection[T: ClassManifest]( override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator - override def preferredLocations(s: Split): Seq[String] = Nil + override def preferredLocations(s: Split): Seq[String] = { + locationPrefs.get(splits_.indexOf(s)) match { + case Some(s) => s + case _ => Nil + } + } } private object ParallelCollection { diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index d7b46bee38..7ae1aea993 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -194,7 +194,7 @@ class SparkContext( /** Distribute a local Scala collection to form an RDD. */ def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { - new ParallelCollection[T](this, seq, numSlices) + new ParallelCollection[T](this, seq, numSlices, Map[Int, Seq[String]]()) } /** Distribute a local Scala collection to form an RDD. */ @@ -202,6 +202,14 @@ class SparkContext( parallelize(seq, numSlices) } + /** Distribute a local Scala collection to form an RDD, with one or more + * location preferences for each object. Create a new partition for each + * collection item. */ + def makeLocalityConstrainedRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = { + val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap + new ParallelCollection[T](this, seq.map(_._1), seq.size, indexToPrefs) + } + /** * Read a text file from HDFS, a local file system (available on all nodes), or any * Hadoop-supported file system URI, and return it as an RDD of Strings. -- cgit v1.2.3 From 3ff9710265d4bb518b89461cfb0fcc771e61a726 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 7 Dec 2012 15:01:15 -0800 Subject: Adding Flume InputDStream --- project/SparkBuild.scala | 3 +- .../scala/spark/streaming/FlumeInputDStream.scala | 130 +++++++++++++++++++++ .../spark/streaming/NetworkInputDStream.scala | 6 +- .../spark/streaming/NetworkInputTracker.scala | 13 ++- .../scala/spark/streaming/RawInputDStream.scala | 2 + .../scala/spark/streaming/SocketInputDStream.scala | 2 + .../scala/spark/streaming/StreamingContext.scala | 11 ++ .../spark/streaming/examples/FlumeEventCount.scala | 29 +++++ .../scala/spark/streaming/InputStreamsSuite.scala | 59 +++++++++- 9 files changed, 248 insertions(+), 7 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 6ef2ac477a..05f3c59681 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -91,7 +91,8 @@ object SparkBuild extends Build { "org.eclipse.jetty" % "jetty-server" % "7.5.3.v20111011", "org.scalatest" %% "scalatest" % "1.6.1" % "test", "org.scalacheck" %% "scalacheck" % "1.9" % "test", - "com.novocode" % "junit-interface" % "0.8" % "test" + "com.novocode" % "junit-interface" % "0.8" % "test", + "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" ), parallelExecution := false, /* Workaround for issue #206 (fixed after SBT 0.11.0) */ diff --git a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala new file mode 100644 index 0000000000..9c403278c3 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala @@ -0,0 +1,130 @@ +package spark.streaming + +import java.io.{ObjectInput, ObjectOutput, Externalizable} +import spark.storage.StorageLevel +import org.apache.flume.source.avro.AvroSourceProtocol +import org.apache.flume.source.avro.AvroFlumeEvent +import org.apache.flume.source.avro.Status +import org.apache.avro.ipc.specific.SpecificResponder +import org.apache.avro.ipc.NettyServer +import java.net.InetSocketAddress +import collection.JavaConversions._ +import spark.Utils +import java.nio.ByteBuffer + +class FlumeInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + host: String, + port: Int, + storageLevel: StorageLevel +) extends NetworkInputDStream[SparkFlumeEvent](ssc_) { + + override def createReceiver(): NetworkReceiver[SparkFlumeEvent] = { + new FlumeReceiver(id, host, port, storageLevel) + } +} + +/** + * A wrapper class for AvroFlumeEvent's with a custom serialization format. + * + * This is necessary because AvroFlumeEvent uses inner data structures + * which are not serializable. + */ +class SparkFlumeEvent() extends Externalizable { + var event : AvroFlumeEvent = new AvroFlumeEvent() + + /* De-serialize from bytes. */ + def readExternal(in: ObjectInput) { + val bodyLength = in.readInt() + val bodyBuff = new Array[Byte](bodyLength) + in.read(bodyBuff) + + val numHeaders = in.readInt() + val headers = new java.util.HashMap[CharSequence, CharSequence] + + for (i <- 0 until numHeaders) { + val keyLength = in.readInt() + val keyBuff = new Array[Byte](keyLength) + in.read(keyBuff) + val key : String = Utils.deserialize(keyBuff) + + val valLength = in.readInt() + val valBuff = new Array[Byte](valLength) + in.read(valBuff) + val value : String = Utils.deserialize(valBuff) + + headers.put(key, value) + } + + event.setBody(ByteBuffer.wrap(bodyBuff)) + event.setHeaders(headers) + } + + /* Serialize to bytes. */ + def writeExternal(out: ObjectOutput) { + val body = event.getBody.array() + out.writeInt(body.length) + out.write(body) + + val numHeaders = event.getHeaders.size() + out.writeInt(numHeaders) + for ((k, v) <- event.getHeaders) { + val keyBuff = Utils.serialize(k.toString) + out.writeInt(keyBuff.length) + out.write(keyBuff) + val valBuff = Utils.serialize(v.toString) + out.writeInt(valBuff.length) + out.write(valBuff) + } + } +} + +object SparkFlumeEvent { + def fromAvroFlumeEvent(in : AvroFlumeEvent) : SparkFlumeEvent = { + val event = new SparkFlumeEvent + event.event = in + event + } +} + +/** A simple server that implements Flume's Avro protocol. */ +class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { + override def append(event : AvroFlumeEvent) : Status = { + receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event) + Status.OK + } + + override def appendBatch(events : java.util.List[AvroFlumeEvent]) : Status = { + events.foreach (event => + receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event)) + Status.OK + } +} + +/** A NetworkReceiver which listens for events using the + * Flume Avro interface.*/ +class FlumeReceiver( + streamId: Int, + host: String, + port: Int, + storageLevel: StorageLevel + ) extends NetworkReceiver[SparkFlumeEvent](streamId) { + + lazy val dataHandler = new DataHandler(this, storageLevel) + + protected override def onStart() { + val responder = new SpecificResponder( + classOf[AvroSourceProtocol], new FlumeEventServer(this)); + val server = new NettyServer(responder, new InetSocketAddress(host, port)); + dataHandler.start() + server.start() + logInfo("Flume receiver started") + } + + protected override def onStop() { + dataHandler.stop() + logInfo("Flume receiver stopped") + } + + override def getLocationConstraint = Some(host) +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala index d3f37b8b0e..052fc8bb74 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala @@ -8,7 +8,6 @@ import spark.streaming.util.{RecurringTimer, SystemClock} import spark.storage.StorageLevel import java.nio.ByteBuffer -import java.util.concurrent.ArrayBlockingQueue import akka.actor.{Props, Actor} import akka.pattern.ask @@ -63,6 +62,9 @@ abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Seri /** This method will be called to stop receiving data. */ protected def onStop() + /** This method conveys a placement constraint (hostname) for this receiver. */ + def getLocationConstraint() : Option[String] = None + /** * This method starts the receiver. First is accesses all the lazy members to * materialize them. Then it calls the user-defined onStart() method to start @@ -151,6 +153,4 @@ abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Seri tracker ! DeregisterReceiver(streamId, msg) } } - } - diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index 73ba877085..56661c2615 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -98,7 +98,18 @@ class NetworkInputTracker( def startReceivers() { val receivers = networkInputStreams.map(_.createReceiver()) - val tempRDD = ssc.sc.makeRDD(receivers, receivers.size) + + // We only honor constraints if all receivers have them + val hasLocationConstraints = receivers.map(_.getLocationConstraint().isDefined).reduce(_ && _) + + val tempRDD = + if (hasLocationConstraints) { + val receiversWithConstraints = receivers.map(r => (r, Seq(r.getLocationConstraint().toString))) + ssc.sc.makeLocalityConstrainedRDD[NetworkReceiver[_]](receiversWithConstraints) + } + else { + ssc.sc.makeRDD(receivers, receivers.size) + } val startReceiver = (iterator: Iterator[NetworkReceiver[_]]) => { if (!iterator.hasNext) { diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala index d5db8e787d..fd51ed47a5 100644 --- a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala @@ -31,6 +31,8 @@ class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: S var blockPushingThread: Thread = null + override def getLocationConstraint = None + def onStart() { // Open a socket to the target address and keep reading from it logInfo("Connecting to " + host + ":" + port) diff --git a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala index ff99d50b76..ebbb17a39a 100644 --- a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala @@ -34,6 +34,8 @@ class SocketReceiver[T: ClassManifest]( lazy protected val dataHandler = new DataHandler(this, storageLevel) + override def getLocationConstraint = None + protected def onStart() { logInfo("Connecting to " + host + ":" + port) val socket = new Socket(host, port) diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 8153dd4567..ce47bcb2da 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -15,6 +15,7 @@ import org.apache.hadoop.io.LongWritable import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat +import org.apache.flume.source.avro.AvroFlumeEvent import org.apache.hadoop.fs.Path import java.util.UUID import spark.util.MetadataCleaner @@ -166,6 +167,16 @@ class StreamingContext private ( inputStream } + def flumeStream ( + hostname: String, + port: Int, + storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2): DStream[SparkFlumeEvent] = { + val inputStream = new FlumeInputDStream(this, hostname, port, storageLevel) + graph.addInputStream(inputStream) + inputStream + } + + def rawNetworkStream[T: ClassManifest]( hostname: String, port: Int, diff --git a/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala b/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala new file mode 100644 index 0000000000..d76c92fdd5 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala @@ -0,0 +1,29 @@ +package spark.streaming.examples + +import spark.util.IntParam +import spark.storage.StorageLevel +import spark.streaming._ + +object FlumeEventCount { + def main(args: Array[String]) { + if (args.length != 4) { + System.err.println( + "Usage: FlumeEventCount ") + System.exit(1) + } + + val Array(master, host, IntParam(port), IntParam(batchMillis)) = args + + // Create the context and set the batch size + val ssc = new StreamingContext(master, "FlumeEventCount", + Milliseconds(batchMillis)) + + // Create a flume stream + val stream = ssc.flumeStream(host,port,StorageLevel.MEMORY_ONLY) + + // Print out the count of events received from this server in each batch + stream.count().map(cnt => "Received " + cnt + " flume events." ).print() + + ssc.start() + } +} diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index e98c096725..ed9a659092 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -1,6 +1,6 @@ package spark.streaming -import java.net.{SocketException, Socket, ServerSocket} +import java.net.{InetSocketAddress, SocketException, Socket, ServerSocket} import java.io.{File, BufferedWriter, OutputStreamWriter} import java.util.concurrent.{TimeUnit, ArrayBlockingQueue} import collection.mutable.{SynchronizedBuffer, ArrayBuffer} @@ -10,7 +10,14 @@ import spark.Logging import scala.util.Random import org.apache.commons.io.FileUtils import org.scalatest.BeforeAndAfter - +import org.apache.flume.source.avro.AvroSourceProtocol +import org.apache.flume.source.avro.AvroFlumeEvent +import org.apache.flume.source.avro.Status +import org.apache.avro.ipc.{specific, NettyTransceiver} +import org.apache.avro.ipc.specific.SpecificRequestor +import java.nio.ByteBuffer +import collection.JavaConversions._ +import java.nio.charset.Charset class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { @@ -123,6 +130,54 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { ssc.stop() } + test("flume input stream") { + // Set up the streaming context and input streams + val ssc = new StreamingContext(master, framework, batchDuration) + val flumeStream = ssc.flumeStream("localhost", 33333, StorageLevel.MEMORY_AND_DISK) + val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] + with SynchronizedBuffer[Seq[SparkFlumeEvent]] + val outputStream = new TestOutputStream(flumeStream, outputBuffer) + ssc.registerOutputStream(outputStream) + ssc.start() + + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val input = Seq(1, 2, 3, 4, 5) + + val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", 33333)); + val client = SpecificRequestor.getClient( + classOf[AvroSourceProtocol], transceiver); + + for (i <- 0 until input.size) { + val event = new AvroFlumeEvent + event.setBody(ByteBuffer.wrap(input(i).toString.getBytes())) + event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) + client.append(event) + Thread.sleep(500) + clock.addToTime(batchDuration.milliseconds) + } + + val startTime = System.currentTimeMillis() + while (outputBuffer.size < input.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + logInfo("output.size = " + outputBuffer.size + ", input.size = " + input.size) + Thread.sleep(100) + } + Thread.sleep(1000) + val timeTaken = System.currentTimeMillis() - startTime + assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") + logInfo("Stopping context") + ssc.stop() + + val decoder = Charset.forName("UTF-8").newDecoder() + + assert(outputBuffer.size === input.length) + for (i <- 0 until outputBuffer.size) { + assert(outputBuffer(i).size === 1) + val str = decoder.decode(outputBuffer(i).head.event.getBody) + assert(str.toString === input(i).toString) + assert(outputBuffer(i).head.event.getHeaders.get("test") === "header") + } + } + test("file input stream") { // Create a temporary directory -- cgit v1.2.3 From 3e796bdd57297134ed40b20d7692cd9c8cd6efba Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 7 Dec 2012 19:16:35 -0800 Subject: Changes in response to TD's review. --- core/src/main/scala/spark/SparkContext.scala | 6 +++--- .../scala/spark/streaming/FlumeInputDStream.scala | 2 +- .../spark/streaming/NetworkInputDStream.scala | 4 ++-- .../spark/streaming/NetworkInputTracker.scala | 10 ++++----- .../scala/spark/streaming/RawInputDStream.scala | 2 +- .../scala/spark/streaming/SocketInputDStream.scala | 2 +- .../spark/streaming/examples/FlumeEventCount.scala | 24 +++++++++++++++++----- 7 files changed, 32 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 7ae1aea993..3ccdbfe10e 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -203,9 +203,9 @@ class SparkContext( } /** Distribute a local Scala collection to form an RDD, with one or more - * location preferences for each object. Create a new partition for each - * collection item. */ - def makeLocalityConstrainedRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = { + * location preferences (hostnames of Spark nodes) for each object. + * Create a new partition for each collection item. */ + def makeRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = { val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap new ParallelCollection[T](this, seq.map(_._1), seq.size, indexToPrefs) } diff --git a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala index 9c403278c3..2959ce4540 100644 --- a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala @@ -126,5 +126,5 @@ class FlumeReceiver( logInfo("Flume receiver stopped") } - override def getLocationConstraint = Some(host) + override def getLocationPreference = Some(host) } \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala index 052fc8bb74..4e4e9fc942 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala @@ -62,8 +62,8 @@ abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Seri /** This method will be called to stop receiving data. */ protected def onStop() - /** This method conveys a placement constraint (hostname) for this receiver. */ - def getLocationConstraint() : Option[String] = None + /** This method conveys a placement preference (hostname) for this receiver. */ + def getLocationPreference() : Option[String] = None /** * This method starts the receiver. First is accesses all the lazy members to diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index 56661c2615..b421f795ee 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -99,13 +99,13 @@ class NetworkInputTracker( def startReceivers() { val receivers = networkInputStreams.map(_.createReceiver()) - // We only honor constraints if all receivers have them - val hasLocationConstraints = receivers.map(_.getLocationConstraint().isDefined).reduce(_ && _) + // Right now, we only honor preferences if all receivers have them + val hasLocationPreferences = receivers.map(_.getLocationPreference().isDefined).reduce(_ && _) val tempRDD = - if (hasLocationConstraints) { - val receiversWithConstraints = receivers.map(r => (r, Seq(r.getLocationConstraint().toString))) - ssc.sc.makeLocalityConstrainedRDD[NetworkReceiver[_]](receiversWithConstraints) + if (hasLocationPreferences) { + val receiversWithPreferences = receivers.map(r => (r, Seq(r.getLocationPreference().toString))) + ssc.sc.makeRDD[NetworkReceiver[_]](receiversWithPreferences) } else { ssc.sc.makeRDD(receivers, receivers.size) diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala index fd51ed47a5..6acaa9aab1 100644 --- a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala @@ -31,7 +31,7 @@ class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: S var blockPushingThread: Thread = null - override def getLocationConstraint = None + override def getLocationPreference = None def onStart() { // Open a socket to the target address and keep reading from it diff --git a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala index ebbb17a39a..a9e37c0ff0 100644 --- a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala @@ -34,7 +34,7 @@ class SocketReceiver[T: ClassManifest]( lazy protected val dataHandler = new DataHandler(this, storageLevel) - override def getLocationConstraint = None + override def getLocationPreference = None protected def onStart() { logInfo("Connecting to " + host + ":" + port) diff --git a/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala b/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala index d76c92fdd5..e60ce483a3 100644 --- a/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala +++ b/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala @@ -4,19 +4,33 @@ import spark.util.IntParam import spark.storage.StorageLevel import spark.streaming._ +/** + * Produce a streaming count of events received from Flume. + * + * This should be used in conjunction with an AvroSink in Flume. It will start + * an Avro server on at the request host:port address and listen for requests. + * Your Flume AvroSink should be pointed to this address. + * + * Usage: FlumeEventCount + * + * is a Spark master URL + * is the host the Flume receiver will be started on - a receiver + * creates a server and listens for flume events. + * is the port the Flume receiver will listen on. + */ object FlumeEventCount { def main(args: Array[String]) { - if (args.length != 4) { + if (args.length != 3) { System.err.println( - "Usage: FlumeEventCount ") + "Usage: FlumeEventCount ") System.exit(1) } - val Array(master, host, IntParam(port), IntParam(batchMillis)) = args + val Array(master, host, IntParam(port)) = args + val batchInterval = Milliseconds(2000) // Create the context and set the batch size - val ssc = new StreamingContext(master, "FlumeEventCount", - Milliseconds(batchMillis)) + val ssc = new StreamingContext(master, "FlumeEventCount", batchInterval) // Create a flume stream val stream = ssc.flumeStream(host,port,StorageLevel.MEMORY_ONLY) -- cgit v1.2.3 From e42721601898ff199ca1c6cfeae159ad3ef691e3 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 8 Dec 2012 12:46:59 -0800 Subject: Removed unnecessary testcases. --- core/src/test/scala/spark/CheckpointSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 8622ce92aa..41d84cb01c 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -121,7 +121,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(bytesAfterCheckpoint.size < bytesBeforeCheckpoint.size, "CoGroupedSplits still holds on to the splits of its parent RDDs") } - + /* /** * This test forces two ResultTasks of the same job to be launched before and after * the checkpointing of job's RDD is completed. @@ -149,7 +149,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { } testThreading(op1, op2) } - + */ def testCheckpointing[U: ClassManifest](op: (RDD[Int]) => RDD[U], sleepTime: Long = 500) { val parCollection = sc.makeRDD(1 to 4, 4) @@ -163,7 +163,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(operatedRDD.dependencies.head.rdd != parentRDD) assert(operatedRDD.collect() === result) } - + /* def testThreading[U: ClassManifest, V: ClassManifest](op1: (RDD[Int]) => RDD[U], op2: (RDD[U]) => RDD[V]) { val parCollection = sc.makeRDD(1 to 2, 2) @@ -201,7 +201,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val correctResult = secondRDD.collect() assert(result === correctResult) } - + */ def sleep(rdd: RDD[_]) { val startTime = System.currentTimeMillis() val maxWaitTime = 5000 -- cgit v1.2.3 From 746afc2e6513d5f32f261ec0dbf2823f78a5e960 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 10 Dec 2012 23:36:37 -0800 Subject: Bunch of bug fixes related to checkpointing in RDDs. RDDCheckpointData object is used to lock all serialization and dependency changes for checkpointing. ResultTask converted to Externalizable and serialized RDD is cached like ShuffleMapTask. --- core/src/main/scala/spark/ParallelCollection.scala | 10 +- core/src/main/scala/spark/RDD.scala | 10 +- core/src/main/scala/spark/RDDCheckpointData.scala | 76 +++++++-- core/src/main/scala/spark/SparkContext.scala | 5 +- core/src/main/scala/spark/rdd/BlockRDD.scala | 9 +- core/src/main/scala/spark/rdd/CartesianRDD.scala | 21 ++- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 18 +- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 8 +- core/src/main/scala/spark/rdd/HadoopRDD.scala | 4 + core/src/main/scala/spark/rdd/UnionRDD.scala | 15 +- .../main/scala/spark/scheduler/ResultTask.scala | 95 ++++++++++- .../scala/spark/scheduler/ShuffleMapTask.scala | 21 ++- core/src/test/scala/spark/CheckpointSuite.scala | 187 ++++++++++++++++++--- 13 files changed, 389 insertions(+), 90 deletions(-) diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index 9725017b61..9d12af6912 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -28,10 +28,11 @@ private[spark] class ParallelCollection[T: ClassManifest]( extends RDD[T](sc, Nil) { // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split - // instead. UPDATE: With the new changes to enable checkpointing, this an be done. + // instead. + // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal. @transient - val splits_ = { + var splits_ : Array[Split] = { val slices = ParallelCollection.slice(data, numSlices).toArray slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray } @@ -41,6 +42,11 @@ private[spark] class ParallelCollection[T: ClassManifest]( override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator override def preferredLocations(s: Split): Seq[String] = Nil + + override def changeDependencies(newRDD: RDD[_]) { + dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) + splits_ = newRDD.splits + } } private object ParallelCollection { diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index e9bd131e61..efa03d5185 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -163,7 +163,7 @@ abstract class RDD[T: ClassManifest]( final def iterator(split: Split): Iterator[T] = { if (isCheckpointed) { // ASSUMPTION: Checkpoint Hadoop RDD will have same number of splits as original - checkpointData.iterator(split.index) + checkpointData.iterator(split) } else if (storageLevel != StorageLevel.NONE) { SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel) } else { @@ -556,16 +556,12 @@ abstract class RDD[T: ClassManifest]( @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { - synchronized { - oos.defaultWriteObject() - } + oos.defaultWriteObject() } @throws(classOf[IOException]) private def readObject(ois: ObjectInputStream) { - synchronized { - ois.defaultReadObject() - } + ois.defaultReadObject() } } diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index eb4482acee..ff2ed4cdfc 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -1,12 +1,20 @@ package spark import org.apache.hadoop.fs.Path +import rdd.CoalescedRDD +import scheduler.{ResultTask, ShuffleMapTask} - +/** + * This class contains all the information of the regarding RDD checkpointing. + */ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) -extends Serializable { +extends Logging with Serializable { + /** + * This class manages the state transition of an RDD through checkpointing + * [ Not checkpointed --> marked for checkpointing --> checkpointing in progress --> checkpointed ] + */ class CheckpointState extends Serializable { var state = 0 @@ -20,24 +28,30 @@ extends Serializable { } val cpState = new CheckpointState() - var cpFile: Option[String] = None - var cpRDD: Option[RDD[T]] = None - var cpRDDSplits: Seq[Split] = Nil + @transient var cpFile: Option[String] = None + @transient var cpRDD: Option[RDD[T]] = None + @transient var cpRDDSplits: Seq[Split] = Nil + // Mark the RDD for checkpointing def markForCheckpoint() = { - rdd.synchronized { cpState.mark() } + RDDCheckpointData.synchronized { cpState.mark() } } + // Is the RDD already checkpointed def isCheckpointed() = { - rdd.synchronized { cpState.isCheckpointed } + RDDCheckpointData.synchronized { cpState.isCheckpointed } } + // Get the file to which this RDD was checkpointed to as a Option def getCheckpointFile() = { - rdd.synchronized { cpFile } + RDDCheckpointData.synchronized { cpFile } } + // Do the checkpointing of the RDD. Called after the first job using that RDD is over. def doCheckpoint() { - rdd.synchronized { + // If it is marked for checkpointing AND checkpointing is not already in progress, + // then set it to be in progress, else return + RDDCheckpointData.synchronized { if (cpState.isMarked && !cpState.isInProgress) { cpState.start() } else { @@ -45,24 +59,56 @@ extends Serializable { } } + // Save to file, and reload it as an RDD val file = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString rdd.saveAsObjectFile(file) - val newRDD = rdd.context.objectFile[T](file, rdd.splits.size) - rdd.synchronized { - rdd.changeDependencies(newRDD) + val newRDD = { + val hadoopRDD = rdd.context.objectFile[T](file, rdd.splits.size) + + val oldSplits = rdd.splits.size + val newSplits = hadoopRDD.splits.size + + logDebug("RDD splits = " + oldSplits + " --> " + newSplits) + if (newSplits < oldSplits) { + throw new Exception("# splits after checkpointing is less than before " + + "[" + oldSplits + " --> " + newSplits) + } else if (newSplits > oldSplits) { + new CoalescedRDD(hadoopRDD, rdd.splits.size) + } else { + hadoopRDD + } + } + logDebug("New RDD has " + newRDD.splits.size + " splits") + + // Change the dependencies and splits of the RDD + RDDCheckpointData.synchronized { cpFile = Some(file) cpRDD = Some(newRDD) cpRDDSplits = newRDD.splits + rdd.changeDependencies(newRDD) cpState.finish() + RDDCheckpointData.checkpointCompleted() + logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) } } + // Get preferred location of a split after checkpointing def preferredLocations(split: Split) = { - cpRDD.get.preferredLocations(split) + RDDCheckpointData.synchronized { + cpRDD.get.preferredLocations(split) + } } - def iterator(splitIndex: Int): Iterator[T] = { - cpRDD.get.iterator(cpRDDSplits(splitIndex)) + // Get iterator. This is called at the worker nodes. + def iterator(split: Split): Iterator[T] = { + rdd.firstParent[T].iterator(split) + } +} + +private[spark] object RDDCheckpointData { + def checkpointCompleted() { + ShuffleMapTask.clearCache() + ResultTask.clearCache() } } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index d7b46bee38..654b1c2eb7 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -40,9 +40,7 @@ import spark.partial.PartialResult import spark.rdd.HadoopRDD import spark.rdd.NewHadoopRDD import spark.rdd.UnionRDD -import spark.scheduler.ShuffleMapTask -import spark.scheduler.DAGScheduler -import spark.scheduler.TaskScheduler +import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler} import spark.scheduler.local.LocalScheduler import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} @@ -486,6 +484,7 @@ class SparkContext( clearJars() SparkEnv.set(null) ShuffleMapTask.clearCache() + ResultTask.clearCache() logInfo("Successfully stopped SparkContext") } diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index 590f9eb738..0c8cdd10dd 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -2,7 +2,7 @@ package spark.rdd import scala.collection.mutable.HashMap -import spark.Dependency +import spark.OneToOneDependency import spark.RDD import spark.SparkContext import spark.SparkEnv @@ -17,7 +17,7 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St extends RDD[T](sc, Nil) { @transient - val splits_ = (0 until blockIds.size).map(i => { + var splits_ : Array[Split] = (0 until blockIds.size).map(i => { new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split] }).toArray @@ -43,5 +43,10 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St override def preferredLocations(split: Split) = locations_(split.asInstanceOf[BlockRDDSplit].blockId) + + override def changeDependencies(newRDD: RDD[_]) { + dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) + splits_ = newRDD.splits + } } diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 1d753a5168..9975e79b08 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -1,10 +1,27 @@ package spark.rdd import spark._ +import java.io.{ObjectOutputStream, IOException} private[spark] -class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable { +class CartesianSplit( + idx: Int, + @transient rdd1: RDD[_], + @transient rdd2: RDD[_], + s1Index: Int, + s2Index: Int + ) extends Split { + var s1 = rdd1.splits(s1Index) + var s2 = rdd2.splits(s2Index) override val index: Int = idx + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + // Update the reference to parent split at the time of task serialization + s1 = rdd1.splits(s1Index) + s2 = rdd2.splits(s2Index) + oos.defaultWriteObject() + } } private[spark] @@ -23,7 +40,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( val array = new Array[Split](rdd1.splits.size * rdd2.splits.size) for (s1 <- rdd1.splits; s2 <- rdd2.splits) { val idx = s1.index * numSplitsInRdd2 + s2.index - array(idx) = new CartesianSplit(idx, s1, s2) + array(idx) = new CartesianSplit(idx, rdd1, rdd2, s1.index, s2.index) } array } diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 57d472666b..e4e70b13ba 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -20,11 +20,9 @@ private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], splitIndex: Int, va @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { - rdd.synchronized { - // Update the reference to parent split at the time of task serialization - split = rdd.splits(splitIndex) - oos.defaultWriteObject() - } + // Update the reference to parent split at the time of task serialization + split = rdd.splits(splitIndex) + oos.defaultWriteObject() } } private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep @@ -42,7 +40,8 @@ private[spark] class CoGroupAggregator { (b1, b2) => b1 ++ b2 }) with Serializable -class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) +class +CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) with Logging { val aggr = new CoGroupAggregator @@ -63,7 +62,9 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) deps.toList } - override def dependencies = deps_ + // Pre-checkpoint dependencies deps_ should be transient (deps_) + // but post-checkpoint dependencies must not be transient (dependencies_) + override def dependencies = if (isCheckpointed) dependencies_ else deps_ @transient var splits_ : Array[Split] = { @@ -114,7 +115,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) } override def changeDependencies(newRDD: RDD[_]) { - deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) + deps_ = null + dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) splits_ = newRDD.splits rdds = null } diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 0b4499e2eb..088958942e 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -12,11 +12,9 @@ private[spark] case class CoalescedRDDSplit( @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { - rdd.synchronized { - // Update the reference to parent split at the time of task serialization - parents = parentsIndices.map(rdd.splits(_)) - oos.defaultWriteObject() - } + // Update the reference to parent split at the time of task serialization + parents = parentsIndices.map(rdd.splits(_)) + oos.defaultWriteObject() } } diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index a12531ea89..af54f23ebc 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -115,4 +115,8 @@ class HadoopRDD[K, V]( val hadoopSplit = split.asInstanceOf[HadoopSplit] hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") } + + override def checkpoint() { + // Do nothing. Hadoop RDD cannot be checkpointed. + } } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index a5948dd1f1..808729f18d 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -19,11 +19,9 @@ private[spark] class UnionSplit[T: ClassManifest]( @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { - rdd.synchronized { - // Update the reference to parent split at the time of task serialization - split = rdd.splits(splitIndex) - oos.defaultWriteObject() - } + // Update the reference to parent split at the time of task serialization + split = rdd.splits(splitIndex) + oos.defaultWriteObject() } } @@ -55,7 +53,9 @@ class UnionRDD[T: ClassManifest]( deps.toList } - override def dependencies = deps_ + // Pre-checkpoint dependencies deps_ should be transient (deps_) + // but post-checkpoint dependencies must not be transient (dependencies_) + override def dependencies = if (isCheckpointed) dependencies_ else deps_ override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator() @@ -63,7 +63,8 @@ class UnionRDD[T: ClassManifest]( s.asInstanceOf[UnionSplit[T]].preferredLocations() override def changeDependencies(newRDD: RDD[_]) { - deps_ = List(new OneToOneDependency(newRDD)) + deps_ = null + dependencies_ = List(new OneToOneDependency(newRDD)) splits_ = newRDD.splits rdds = null } diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala index 2ebd4075a2..bcb9e4956b 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -1,17 +1,73 @@ package spark.scheduler import spark._ +import java.io._ +import util.{MetadataCleaner, TimeStampedHashMap} +import java.util.zip.{GZIPInputStream, GZIPOutputStream} + +private[spark] object ResultTask { + + // A simple map between the stage id to the serialized byte array of a task. + // Served as a cache for task serialization because serialization can be + // expensive on the master node if it needs to launch thousands of tasks. + val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] + + val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.cleanup) + + def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = { + synchronized { + val old = serializedInfoCache.get(stageId).orNull + if (old != null) { + return old + } else { + val out = new ByteArrayOutputStream + val ser = SparkEnv.get.closureSerializer.newInstance + val objOut = ser.serializeStream(new GZIPOutputStream(out)) + objOut.writeObject(rdd) + objOut.writeObject(func) + objOut.close() + val bytes = out.toByteArray + serializedInfoCache.put(stageId, bytes) + return bytes + } + } + } + + def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = { + synchronized { + val loader = Thread.currentThread.getContextClassLoader + val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) + val ser = SparkEnv.get.closureSerializer.newInstance + val objIn = ser.deserializeStream(in) + val rdd = objIn.readObject().asInstanceOf[RDD[_]] + val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _] + return (rdd, func) + } + } + + def clearCache() { + synchronized { + serializedInfoCache.clear() + } + } +} + private[spark] class ResultTask[T, U]( stageId: Int, - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - val partition: Int, + var rdd: RDD[T], + var func: (TaskContext, Iterator[T]) => U, + var partition: Int, @transient locs: Seq[String], val outputId: Int) - extends Task[U](stageId) { - - val split = rdd.splits(partition) + extends Task[U](stageId) with Externalizable { + + def this() = this(0, null, null, 0, null, 0) + var split = if (rdd == null) { + null + } else { + rdd.splits(partition) + } override def run(attemptId: Long): U = { val context = new TaskContext(stageId, partition, attemptId) @@ -21,4 +77,31 @@ private[spark] class ResultTask[T, U]( override def preferredLocations: Seq[String] = locs override def toString = "ResultTask(" + stageId + ", " + partition + ")" + + override def writeExternal(out: ObjectOutput) { + RDDCheckpointData.synchronized { + split = rdd.splits(partition) + out.writeInt(stageId) + val bytes = ResultTask.serializeInfo( + stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _]) + out.writeInt(bytes.length) + out.write(bytes) + out.writeInt(partition) + out.writeInt(outputId) + out.writeObject(split) + } + } + + override def readExternal(in: ObjectInput) { + val stageId = in.readInt() + val numBytes = in.readInt() + val bytes = new Array[Byte](numBytes) + in.readFully(bytes) + val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes) + rdd = rdd_.asInstanceOf[RDD[T]] + func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U] + partition = in.readInt() + val outputId = in.readInt() + split = in.readObject().asInstanceOf[Split] + } } diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 683f5ebec3..5d28c40778 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -84,19 +84,22 @@ private[spark] class ShuffleMapTask( def this() = this(0, null, null, 0, null) var split = if (rdd == null) { - null - } else { + null + } else { rdd.splits(partition) } override def writeExternal(out: ObjectOutput) { - out.writeInt(stageId) - val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) - out.writeInt(bytes.length) - out.write(bytes) - out.writeInt(partition) - out.writeLong(generation) - out.writeObject(split) + RDDCheckpointData.synchronized { + split = rdd.splits(partition) + out.writeInt(stageId) + val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) + out.writeInt(bytes.length) + out.write(bytes) + out.writeInt(partition) + out.writeLong(generation) + out.writeObject(split) + } } override def readExternal(in: ObjectInput) { diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 51bd59e2b1..7b323e089c 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -34,13 +34,30 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { } } + test("RDDs with one-to-one dependencies") { + testCheckpointing(_.map(x => x.toString)) + testCheckpointing(_.flatMap(x => 1 to x)) + testCheckpointing(_.filter(_ % 2 == 0)) + testCheckpointing(_.sample(false, 0.5, 0)) + testCheckpointing(_.glom()) + testCheckpointing(_.mapPartitions(_.map(_.toString))) + testCheckpointing(r => new MapPartitionsWithSplitRDD(r, + (i: Int, iter: Iterator[Int]) => iter.map(_.toString) )) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) + testCheckpointing(_.pipe(Seq("cat"))) + } + test("ParallelCollection") { - val parCollection = sc.makeRDD(1 to 4) + val parCollection = sc.makeRDD(1 to 4, 2) + val numSplits = parCollection.splits.size parCollection.checkpoint() assert(parCollection.dependencies === Nil) val result = parCollection.collect() assert(sc.objectFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) + assert(parCollection.splits.length === numSplits) + assert(parCollection.splits.toList === parCollection.checkpointData.cpRDDSplits.toList) assert(parCollection.collect() === result) } @@ -49,44 +66,58 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val blockManager = SparkEnv.get.blockManager blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY) val blockRDD = new BlockRDD[String](sc, Array(blockId)) + val numSplits = blockRDD.splits.size blockRDD.checkpoint() val result = blockRDD.collect() assert(sc.objectFile[String](blockRDD.getCheckpointFile.get).collect() === result) assert(blockRDD.dependencies != Nil) + assert(blockRDD.splits.length === numSplits) + assert(blockRDD.splits.toList === blockRDD.checkpointData.cpRDDSplits.toList) assert(blockRDD.collect() === result) } - test("RDDs with one-to-one dependencies") { - testCheckpointing(_.map(x => x.toString)) - testCheckpointing(_.flatMap(x => 1 to x)) - testCheckpointing(_.filter(_ % 2 == 0)) - testCheckpointing(_.sample(false, 0.5, 0)) - testCheckpointing(_.glom()) - testCheckpointing(_.mapPartitions(_.map(_.toString))) - testCheckpointing(r => new MapPartitionsWithSplitRDD(r, - (i: Int, iter: Iterator[Int]) => iter.map(_.toString) )) - testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) - testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) - testCheckpointing(_.pipe(Seq("cat"))) - } - test("ShuffledRDD") { - // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD testCheckpointing(rdd => { + // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD new ShuffledRDD(rdd.map(x => (x % 2, 1)), partitioner) }) } test("UnionRDD") { def otherRDD = sc.makeRDD(1 to 10, 4) + + // Test whether the size of UnionRDDSplits reduce in size after parent RDD is checkpointed. + // Current implementation of UnionRDD has transient reference to parent RDDs, + // so only the splits will reduce in serialized size, not the RDD. testCheckpointing(_.union(otherRDD), false, true) testParentCheckpointing(_.union(otherRDD), false, true) } test("CartesianRDD") { - def otherRDD = sc.makeRDD(1 to 10, 4) - testCheckpointing(_.cartesian(otherRDD)) - testParentCheckpointing(_.cartesian(otherRDD), true, false) + def otherRDD = sc.makeRDD(1 to 10, 1) + testCheckpointing(new CartesianRDD(sc, _, otherRDD)) + + // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed + // Current implementation of CoalescedRDDSplit has transient reference to parent RDD, + // so only the RDD will reduce in serialized size, not the splits. + testParentCheckpointing(new CartesianRDD(sc, _, otherRDD), true, false) + + // Test that the CartesianRDD updates parent splits (CartesianRDD.s1/s2) after + // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits. + // Note that this test is very specific to the current implementation of CartesianRDD. + val ones = sc.makeRDD(1 to 100, 10).map(x => x) + ones.checkpoint // checkpoint that MappedRDD + val cartesian = new CartesianRDD(sc, ones, ones) + val splitBeforeCheckpoint = + serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit]) + cartesian.count() // do the checkpointing + val splitAfterCheckpoint = + serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit]) + assert( + (splitAfterCheckpoint.s1 != splitBeforeCheckpoint.s1) && + (splitAfterCheckpoint.s2 != splitBeforeCheckpoint.s2), + "CartesianRDD.parents not updated after parent RDD checkpointed" + ) } test("CoalescedRDD") { @@ -94,7 +125,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed // Current implementation of CoalescedRDDSplit has transient reference to parent RDD, - // so does not serialize the RDD (not need to check its size). + // so only the RDD will reduce in serialized size, not the splits. testParentCheckpointing(new CoalescedRDD(_, 2), true, false) // Test that the CoalescedRDDSplit updates parent splits (CoalescedRDDSplit.parents) after @@ -145,13 +176,14 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val operatedRDD = op(baseRDD) val parentRDD = operatedRDD.dependencies.headOption.orNull val rddType = operatedRDD.getClass.getSimpleName + val numSplits = operatedRDD.splits.length // Find serialized sizes before and after the checkpoint val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) operatedRDD.checkpoint() val result = operatedRDD.collect() val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) - + // Test whether the checkpoint file has been created assert(sc.objectFile[U](operatedRDD.getCheckpointFile.get).collect() === result) @@ -160,6 +192,9 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { // Test whether the splits have been changed to the new Hadoop splits assert(operatedRDD.splits.toList === operatedRDD.checkpointData.cpRDDSplits.toList) + + // Test whether the number of splits is same as before + assert(operatedRDD.splits.length === numSplits) // Test whether the data in the checkpointed RDD is same as original assert(operatedRDD.collect() === result) @@ -168,7 +203,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { // does not have any dependency to another RDD (e.g., ParallelCollection, // ShuffleRDD with ShuffleDependency), it may not reduce in size after checkpointing. if (testRDDSize) { - println("Size of " + rddType + + logInfo("Size of " + rddType + "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") assert( rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, @@ -184,7 +219,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { // must be forgotten after checkpointing (to remove all reference to parent RDDs) and // replaced with the HadoopSplits of the checkpointed RDD. if (testRDDSplitSize) { - println("Size of " + rddType + " splits " + logInfo("Size of " + rddType + " splits " + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]") assert( splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint, @@ -294,14 +329,118 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val bytes = Utils.serialize(obj) Utils.deserialize[T](bytes) } + /* + test("Consistency check for ResultTask") { + // Time -----------------------> + // Core 1: |<- count in thread 1, task 1 ->| |<-- checkpoint, task 1 ---->| |<- count in thread 2, task 2 ->| + // Core 2: |<- count in thread 1, task 2 ->| |<--- checkpoint, task 2 ---------->| |<- count in thread 2, task 1 ->| + // | + // checkpoint completed + sc.stop(); sc = null + System.clearProperty("spark.master.port") + + val dir = File.createTempFile("temp_", "") + dir.delete() + val ctxt = new SparkContext("local[2]", "ResultTask") + ctxt.setCheckpointDir(dir.toString) + + try { + val rdd = ctxt.makeRDD(1 to 2, 2).map(x => { + val state = CheckpointSuite.incrementState() + println("State = " + state) + if (state <= 3) { + // If executing the two tasks for the job comouting rdd.count + // of thread 1, or the first task for the recomputation due + // to checkpointing (saveing to HDFS), then do nothing + } else if (state == 4) { + // If executing the second task for the recomputation due to + // checkpointing. then prolong this task, to allow rdd.count + // of thread 2 to start before checkpoint of this RDD is completed + + Thread.sleep(1000) + println("State = " + state + " wake up") + } else { + // Else executing the tasks from thread 2 + Thread.sleep(1000) + println("State = " + state + " wake up") + } + + (x, 1) + }) + rdd.checkpoint() + val env = SparkEnv.get + + val thread1 = new Thread() { + override def run() { + try { + SparkEnv.set(env) + rdd.count() + } catch { + case e: Exception => CheckpointSuite.failed("Exception in thread 1", e) + } + } + } + thread1.start() + + val thread2 = new Thread() { + override def run() { + try { + SparkEnv.set(env) + CheckpointSuite.waitTillState(3) + println("\n\n\n\n") + rdd.count() + } catch { + case e: Exception => CheckpointSuite.failed("Exception in thread 2", e) + } + } + } + thread2.start() + + thread1.join() + thread2.join() + } finally { + dir.delete() + } + + assert(!CheckpointSuite.failed, CheckpointSuite.failureMessage) + + ctxt.stop() + + } + */ } object CheckpointSuite { + /* + var state = 0 + var failed = false + var failureMessage = "" + + def incrementState(): Int = { + this.synchronized { state += 1; this.notifyAll(); state } + } + + def getState(): Int = { + this.synchronized( state ) + } + + def waitTillState(s: Int) { + while(state < s) { + this.synchronized { this.wait() } + } + } + + def failed(msg: String, ex: Exception) { + failed = true + failureMessage += msg + "\n" + ex + "\n\n" + } + */ + // This is a custom cogroup function that does not use mapValues like // the PairRDDFunctions.cogroup() def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = { - println("First = " + first + ", second = " + second) + //println("First = " + first + ", second = " + second) new CoGroupedRDD[K]( Seq(first.asInstanceOf[RDD[(_, _)]], second.asInstanceOf[RDD[(_, _)]]), part -- cgit v1.2.3 From 2a87d816a24c62215d682e3a7af65489c0d6e708 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 11 Dec 2012 01:44:43 -0800 Subject: Added clear property to JavaAPISuite to remove port binding errors. --- core/src/test/scala/spark/JavaAPISuite.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 5875506179..6bd9836a93 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -44,6 +44,8 @@ public class JavaAPISuite implements Serializable { public void tearDown() { sc.stop(); sc = null; + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port"); } static class ReverseIntComparator implements Comparator, Serializable { -- cgit v1.2.3 From fa28f25619d6712e5f920f498ec03085ea208b4d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 11 Dec 2012 13:59:43 -0800 Subject: Fixed bug in UnionRDD and CoGroupedRDD --- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 9 +- core/src/main/scala/spark/rdd/UnionRDD.scala | 12 +-- core/src/test/scala/spark/CheckpointSuite.scala | 104 ----------------------- 3 files changed, 10 insertions(+), 115 deletions(-) diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index e4e70b13ba..bc6d16ee8b 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -15,8 +15,11 @@ import spark.Split import java.io.{ObjectOutputStream, IOException} private[spark] sealed trait CoGroupSplitDep extends Serializable -private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], splitIndex: Int, var split: Split = null) - extends CoGroupSplitDep { +private[spark] case class NarrowCoGroupSplitDep( + rdd: RDD[_], + splitIndex: Int, + var split: Split + ) extends CoGroupSplitDep { @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { @@ -75,7 +78,7 @@ CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) case s: ShuffleDependency[_, _] => new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep case _ => - new NarrowCoGroupSplitDep(r, i): CoGroupSplitDep + new NarrowCoGroupSplitDep(r, i, r.splits(i)): CoGroupSplitDep } }.toList) } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index 808729f18d..a84867492b 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -5,14 +5,10 @@ import scala.collection.mutable.ArrayBuffer import spark._ import java.io.{ObjectOutputStream, IOException} -private[spark] class UnionSplit[T: ClassManifest]( - idx: Int, - rdd: RDD[T], - splitIndex: Int, - var split: Split = null) - extends Split - with Serializable { - +private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int) + extends Split { + var split: Split = rdd.splits(splitIndex) + def iterator() = rdd.iterator(split) def preferredLocations() = rdd.preferredLocations(split) override val index: Int = idx diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 7b323e089c..909c55c91c 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -329,114 +329,10 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val bytes = Utils.serialize(obj) Utils.deserialize[T](bytes) } - /* - test("Consistency check for ResultTask") { - // Time -----------------------> - // Core 1: |<- count in thread 1, task 1 ->| |<-- checkpoint, task 1 ---->| |<- count in thread 2, task 2 ->| - // Core 2: |<- count in thread 1, task 2 ->| |<--- checkpoint, task 2 ---------->| |<- count in thread 2, task 1 ->| - // | - // checkpoint completed - sc.stop(); sc = null - System.clearProperty("spark.master.port") - - val dir = File.createTempFile("temp_", "") - dir.delete() - val ctxt = new SparkContext("local[2]", "ResultTask") - ctxt.setCheckpointDir(dir.toString) - - try { - val rdd = ctxt.makeRDD(1 to 2, 2).map(x => { - val state = CheckpointSuite.incrementState() - println("State = " + state) - if (state <= 3) { - // If executing the two tasks for the job comouting rdd.count - // of thread 1, or the first task for the recomputation due - // to checkpointing (saveing to HDFS), then do nothing - } else if (state == 4) { - // If executing the second task for the recomputation due to - // checkpointing. then prolong this task, to allow rdd.count - // of thread 2 to start before checkpoint of this RDD is completed - - Thread.sleep(1000) - println("State = " + state + " wake up") - } else { - // Else executing the tasks from thread 2 - Thread.sleep(1000) - println("State = " + state + " wake up") - } - - (x, 1) - }) - rdd.checkpoint() - val env = SparkEnv.get - - val thread1 = new Thread() { - override def run() { - try { - SparkEnv.set(env) - rdd.count() - } catch { - case e: Exception => CheckpointSuite.failed("Exception in thread 1", e) - } - } - } - thread1.start() - - val thread2 = new Thread() { - override def run() { - try { - SparkEnv.set(env) - CheckpointSuite.waitTillState(3) - println("\n\n\n\n") - rdd.count() - } catch { - case e: Exception => CheckpointSuite.failed("Exception in thread 2", e) - } - } - } - thread2.start() - - thread1.join() - thread2.join() - } finally { - dir.delete() - } - - assert(!CheckpointSuite.failed, CheckpointSuite.failureMessage) - - ctxt.stop() - - } - */ } object CheckpointSuite { - /* - var state = 0 - var failed = false - var failureMessage = "" - - def incrementState(): Int = { - this.synchronized { state += 1; this.notifyAll(); state } - } - - def getState(): Int = { - this.synchronized( state ) - } - - def waitTillState(s: Int) { - while(state < s) { - this.synchronized { this.wait() } - } - } - - def failed(msg: String, ex: Exception) { - failed = true - failureMessage += msg + "\n" + ex + "\n\n" - } - */ - // This is a custom cogroup function that does not use mapValues like // the PairRDDFunctions.cogroup() def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = { -- cgit v1.2.3 From 8e74fac215e8b9cda7e35111c5116e3669c6eb97 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 11 Dec 2012 15:36:12 -0800 Subject: Made checkpoint data in RDDs optional to further reduce serialized size. --- core/src/main/scala/spark/RDD.scala | 19 +++++++++++-------- core/src/main/scala/spark/SparkContext.scala | 11 +++++++++++ core/src/test/scala/spark/CheckpointSuite.scala | 12 ++++++------ .../src/main/scala/spark/streaming/DStream.scala | 4 +--- 4 files changed, 29 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index efa03d5185..6c04769c82 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -112,7 +112,7 @@ abstract class RDD[T: ClassManifest]( // Variables relating to persistence private var storageLevel: StorageLevel = StorageLevel.NONE - protected[spark] val checkpointData = new RDDCheckpointData(this) + protected[spark] var checkpointData: Option[RDDCheckpointData[T]] = None /** Returns the first parent RDD */ protected[spark] def firstParent[U: ClassManifest] = { @@ -149,7 +149,7 @@ abstract class RDD[T: ClassManifest]( def getPreferredLocations(split: Split) = { if (isCheckpointed) { - checkpointData.preferredLocations(split) + checkpointData.get.preferredLocations(split) } else { preferredLocations(split) } @@ -163,7 +163,7 @@ abstract class RDD[T: ClassManifest]( final def iterator(split: Split): Iterator[T] = { if (isCheckpointed) { // ASSUMPTION: Checkpoint Hadoop RDD will have same number of splits as original - checkpointData.iterator(split) + checkpointData.get.iterator(split) } else if (storageLevel != StorageLevel.NONE) { SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel) } else { @@ -516,21 +516,24 @@ abstract class RDD[T: ClassManifest]( * require recomputation. */ def checkpoint() { - checkpointData.markForCheckpoint() + if (checkpointData.isEmpty) { + checkpointData = Some(new RDDCheckpointData(this)) + checkpointData.get.markForCheckpoint() + } } /** * Return whether this RDD has been checkpointed or not */ def isCheckpointed(): Boolean = { - checkpointData.isCheckpointed() + if (checkpointData.isDefined) checkpointData.get.isCheckpointed() else false } /** * Gets the name of the file to which this RDD was checkpointed */ def getCheckpointFile(): Option[String] = { - checkpointData.getCheckpointFile() + if (checkpointData.isDefined) checkpointData.get.getCheckpointFile() else None } /** @@ -539,12 +542,12 @@ abstract class RDD[T: ClassManifest]( * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs. */ protected[spark] def doCheckpoint() { - checkpointData.doCheckpoint() + if (checkpointData.isDefined) checkpointData.get.doCheckpoint() dependencies.foreach(_.rdd.doCheckpoint()) } /** - * Changes the dependencies of this RDD from its original parents to the new [[spark.rdd.HadoopRDD]] + * Changes the dependencies of this RDD from its original parents to the new RDD * (`newRDD`) created from the checkpoint file. This method must ensure that all references * to the original parent RDDs must be removed to enable the parent RDDs to be garbage * collected. Subclasses of RDD may override this method for implementing their own changing diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 654b1c2eb7..71ed4ef058 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -366,6 +366,17 @@ class SparkContext( .flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes)) } + + protected[spark] def checkpointFile[T: ClassManifest]( + path: String, + minSplits: Int = defaultMinSplits + ): RDD[T] = { + val rdd = objectFile[T](path, minSplits) + rdd.checkpointData = Some(new RDDCheckpointData(rdd)) + rdd.checkpointData.get.cpFile = Some(path) + rdd + } + /** Build the union of a list of RDDs. */ def union[T: ClassManifest](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds) diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 909c55c91c..0bffedb8db 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -57,7 +57,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(sc.objectFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) assert(parCollection.splits.length === numSplits) - assert(parCollection.splits.toList === parCollection.checkpointData.cpRDDSplits.toList) + assert(parCollection.splits.toList === parCollection.checkpointData.get.cpRDDSplits.toList) assert(parCollection.collect() === result) } @@ -72,7 +72,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(sc.objectFile[String](blockRDD.getCheckpointFile.get).collect() === result) assert(blockRDD.dependencies != Nil) assert(blockRDD.splits.length === numSplits) - assert(blockRDD.splits.toList === blockRDD.checkpointData.cpRDDSplits.toList) + assert(blockRDD.splits.toList === blockRDD.checkpointData.get.cpRDDSplits.toList) assert(blockRDD.collect() === result) } @@ -84,7 +84,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { } test("UnionRDD") { - def otherRDD = sc.makeRDD(1 to 10, 4) + def otherRDD = sc.makeRDD(1 to 10, 1) // Test whether the size of UnionRDDSplits reduce in size after parent RDD is checkpointed. // Current implementation of UnionRDD has transient reference to parent RDDs, @@ -191,7 +191,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(operatedRDD.dependencies.head.rdd != parentRDD) // Test whether the splits have been changed to the new Hadoop splits - assert(operatedRDD.splits.toList === operatedRDD.checkpointData.cpRDDSplits.toList) + assert(operatedRDD.splits.toList === operatedRDD.checkpointData.get.cpRDDSplits.toList) // Test whether the number of splits is same as before assert(operatedRDD.splits.length === numSplits) @@ -289,8 +289,8 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { */ def generateLongLineageRDD(): RDD[Int] = { var rdd = sc.makeRDD(1 to 100, 4) - for (i <- 1 to 20) { - rdd = rdd.map(x => x) + for (i <- 1 to 50) { + rdd = rdd.map(x => x + 1) } rdd } diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index d290c5927e..69fefa21a0 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -372,9 +372,7 @@ extends Serializable with Logging { checkpointData.foreach { case(time, data) => { logInfo("Restoring checkpointed RDD for time " + time + " from file '" + data.toString + "'") - val rdd = ssc.sc.objectFile[T](data.toString) - // Set the checkpoint file name to identify this RDD as a checkpointed RDD by updateCheckpointData() - rdd.checkpointData.cpFile = Some(data.toString) + val rdd = ssc.sc.checkpointFile[T](data.toString) generatedRDDs += ((time, rdd)) } } -- cgit v1.2.3 From 1b7a0451ed7df78838ca7ea09dfa5ba0e236acfe Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 13 Dec 2012 00:04:42 -0800 Subject: Added the ability in block manager to remove blocks. --- core/src/main/scala/spark/SparkEnv.scala | 11 +- .../main/scala/spark/storage/BlockManager.scala | 83 ++++----- .../main/scala/spark/storage/BlockManagerId.scala | 29 +++ .../scala/spark/storage/BlockManagerMaster.scala | 199 +++++++++------------ .../scala/spark/storage/BlockManagerMessages.scala | 102 +++++++++++ .../spark/storage/BlockManagerSlaveActor.scala | 16 ++ .../main/scala/spark/storage/ThreadingTest.scala | 13 +- .../main/scala/spark/util/GenerationIdUtil.scala | 19 ++ .../scala/spark/storage/BlockManagerSuite.scala | 59 ++++-- 9 files changed, 361 insertions(+), 170 deletions(-) create mode 100644 core/src/main/scala/spark/storage/BlockManagerId.scala create mode 100644 core/src/main/scala/spark/storage/BlockManagerMessages.scala create mode 100644 core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala create mode 100644 core/src/main/scala/spark/util/GenerationIdUtil.scala diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 272d7cdad3..41441720a7 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -86,10 +86,13 @@ object SparkEnv extends Logging { } val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer") - - val blockManagerMaster = new BlockManagerMaster(actorSystem, isMaster, isLocal) + + val masterIp: String = System.getProperty("spark.master.host", "localhost") + val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt + val blockManagerMaster = new BlockManagerMaster( + actorSystem, isMaster, isLocal, masterIp, masterPort) val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer) - + val connectionManager = blockManager.connectionManager val broadcastManager = new BroadcastManager(isMaster) @@ -104,7 +107,7 @@ object SparkEnv extends Logging { val shuffleFetcher = instantiateClass[ShuffleFetcher]( "spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher") - + val httpFileServer = new HttpFileServer() httpFileServer.initialize() System.setProperty("spark.fileserver.uri", httpFileServer.serverUri) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index df295b1820..b2c9e2cc40 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -1,59 +1,39 @@ package spark.storage -import akka.actor.{ActorSystem, Cancellable} +import java.io.{InputStream, OutputStream} +import java.nio.{ByteBuffer, MappedByteBuffer} +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue} + +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.collection.JavaConversions._ + +import akka.actor.{ActorSystem, Cancellable, Props} import akka.dispatch.{Await, Future} import akka.util.Duration import akka.util.duration._ -import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream - -import java.io.{InputStream, OutputStream, Externalizable, ObjectInput, ObjectOutput} -import java.nio.{MappedByteBuffer, ByteBuffer} -import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue} +import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} -import scala.collection.JavaConversions._ +import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream import spark.{CacheTracker, Logging, SizeEstimator, SparkEnv, SparkException, Utils} import spark.network._ import spark.serializer.Serializer -import spark.util.ByteBufferInputStream -import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} -import sun.nio.ch.DirectBuffer - - -private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable { - def this() = this(null, 0) // For deserialization only - - def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) +import spark.util.{ByteBufferInputStream, GenerationIdUtil} - override def writeExternal(out: ObjectOutput) { - out.writeUTF(ip) - out.writeInt(port) - } - - override def readExternal(in: ObjectInput) { - ip = in.readUTF() - port = in.readInt() - } - - override def toString = "BlockManagerId(" + ip + ", " + port + ")" - - override def hashCode = ip.hashCode * 41 + port +import sun.nio.ch.DirectBuffer - override def equals(that: Any) = that match { - case id: BlockManagerId => port == id.port && ip == id.ip - case _ => false - } -} private[spark] case class BlockException(blockId: String, message: String, ex: Exception = null) extends Exception(message) private[spark] -class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, - val serializer: Serializer, maxMemory: Long) +class BlockManager( + actorSystem: ActorSystem, + val master: BlockManagerMaster, + val serializer: Serializer, + maxMemory: Long) extends Logging { class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { @@ -110,6 +90,9 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, val host = System.getProperty("spark.hostname", Utils.localHostName()) + val slaveActor = master.actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)), + name = "BlockManagerActor" + GenerationIdUtil.BLOCK_MANAGER.next) + @volatile private var shuttingDown = false private def heartBeat() { @@ -134,8 +117,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, * BlockManagerWorker actor. */ private def initialize() { - master.mustRegisterBlockManager( - RegisterBlockManager(blockManagerId, maxMemory)) + master.mustRegisterBlockManager(blockManagerId, maxMemory, slaveActor) BlockManagerWorker.startBlockManagerWorker(this) if (!BlockManager.getDisableHeartBeatsForTesting) { heartBeatTask = actorSystem.scheduler.schedule(0.seconds, heartBeatFrequency.milliseconds) { @@ -171,8 +153,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, def reregister() { // TODO: We might need to rate limit reregistering. logInfo("BlockManager reregistering with master") - master.mustRegisterBlockManager( - RegisterBlockManager(blockManagerId, maxMemory)) + master.mustRegisterBlockManager(blockManagerId, maxMemory, slaveActor) reportAllBlocks() } @@ -865,6 +846,25 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, } } + /** + * Remove a block from both memory and disk. This one doesn't report to the master + * because it expects the master to initiate the original block removal command, and + * then the master can update the block tracking itself. + */ + def removeBlock(blockId: String) { + logInfo("Removing block " + blockId) + val info = blockInfo.get(blockId) + if (info != null) info.synchronized { + // Removals are idempotent in disk store and memory store. At worst, we get a warning. + memoryStore.remove(blockId) + diskStore.remove(blockId) + blockInfo.remove(blockId) + } else { + // The block has already been removed; do nothing. + logWarning("Block " + blockId + " does not exist.") + } + } + def shouldCompress(blockId: String): Boolean = { if (blockId.startsWith("shuffle_")) { compressShuffle @@ -914,6 +914,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, heartBeatTask.cancel() } connectionManager.stop() + master.actorSystem.stop(slaveActor) blockInfo.clear() memoryStore.clear() diskStore.clear() diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala new file mode 100644 index 0000000000..03cd141805 --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -0,0 +1,29 @@ +package spark.storage + +import java.io.{Externalizable, ObjectInput, ObjectOutput} + + +private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable { + def this() = this(null, 0) // For deserialization only + + def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) + + override def writeExternal(out: ObjectOutput) { + out.writeUTF(ip) + out.writeInt(port) + } + + override def readExternal(in: ObjectInput) { + ip = in.readUTF() + port = in.readInt() + } + + override def toString = "BlockManagerId(" + ip + ", " + port + ")" + + override def hashCode = ip.hashCode * 41 + port + + override def equals(that: Any) = that match { + case id: BlockManagerId => port == id.port && ip == id.ip + case _ => false + } +} \ No newline at end of file diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 0a4e68f437..64cdb86f8d 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -17,95 +17,24 @@ import spark.{Logging, SparkException, Utils} private[spark] -sealed trait ToBlockManagerMaster +case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) -private[spark] -case class RegisterBlockManager( - blockManagerId: BlockManagerId, - maxMemSize: Long) - extends ToBlockManagerMaster - -private[spark] -case class HeartBeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - -private[spark] -class BlockUpdate( - var blockManagerId: BlockManagerId, - var blockId: String, - var storageLevel: StorageLevel, - var memSize: Long, - var diskSize: Long) - extends ToBlockManagerMaster - with Externalizable { - - def this() = this(null, null, null, 0, 0) // For deserialization only - - override def writeExternal(out: ObjectOutput) { - blockManagerId.writeExternal(out) - out.writeUTF(blockId) - storageLevel.writeExternal(out) - out.writeInt(memSize.toInt) - out.writeInt(diskSize.toInt) - } - - override def readExternal(in: ObjectInput) { - blockManagerId = new BlockManagerId() - blockManagerId.readExternal(in) - blockId = in.readUTF() - storageLevel = new StorageLevel() - storageLevel.readExternal(in) - memSize = in.readInt() - diskSize = in.readInt() - } -} - -private[spark] -object BlockUpdate { - def apply(blockManagerId: BlockManagerId, - blockId: String, - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long): BlockUpdate = { - new BlockUpdate(blockManagerId, blockId, storageLevel, memSize, diskSize) - } - - // For pattern-matching - def unapply(h: BlockUpdate): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { - Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize)) - } -} - -private[spark] -case class GetLocations(blockId: String) extends ToBlockManagerMaster - -private[spark] -case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster - -private[spark] -case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster - -private[spark] -case class RemoveHost(host: String) extends ToBlockManagerMaster - -private[spark] -case object StopBlockManagerMaster extends ToBlockManagerMaster - -private[spark] -case object GetMemoryStatus extends ToBlockManagerMaster +// TODO(rxin): Move BlockManagerMasterActor to its own file. private[spark] -case object ExpireDeadHosts extends ToBlockManagerMaster - - -private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { +class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { class BlockManagerInfo( val blockManagerId: BlockManagerId, timeMs: Long, - val maxMem: Long) { - private var _lastSeenMs = timeMs - private var _remainingMem = maxMem - private val _blocks = new JHashMap[String, StorageLevel] + val maxMem: Long, + val slaveActor: ActorRef) { + + private var _lastSeenMs: Long = timeMs + private var _remainingMem: Long = maxMem + + // Mapping from block id to its status. + private val _blocks = new JHashMap[String, BlockStatus] logInfo("Registering block manager %s:%d with %s RAM".format( blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem))) @@ -121,7 +50,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor if (_blocks.containsKey(blockId)) { // The block exists on the slave already. - val originalLevel: StorageLevel = _blocks.get(blockId) + val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel if (originalLevel.useMemory) { _remainingMem += memSize @@ -130,7 +59,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor if (storageLevel.isValid) { // isValid means it is either stored in-memory or on-disk. - _blocks.put(blockId, storageLevel) + _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize)) if (storageLevel.useMemory) { _remainingMem -= memSize logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format( @@ -143,15 +72,15 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor } } else if (_blocks.containsKey(blockId)) { // If isValid is not true, drop the block. - val originalLevel: StorageLevel = _blocks.get(blockId) + val blockStatus: BlockStatus = _blocks.get(blockId) _blocks.remove(blockId) - if (originalLevel.useMemory) { - _remainingMem += memSize + if (blockStatus.storageLevel.useMemory) { + _remainingMem += blockStatus.memSize logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format( blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), Utils.memoryBytesToString(_remainingMem))) } - if (originalLevel.useDisk) { + if (blockStatus.storageLevel.useDisk) { logInfo("Removed %s on %s:%d on disk (size: %s)".format( blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) } @@ -162,7 +91,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor def lastSeenMs: Long = _lastSeenMs - def blocks: JHashMap[String, StorageLevel] = _blocks + def blocks: JHashMap[String, BlockStatus] = _blocks override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem @@ -171,8 +100,13 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor } } + // Mapping from block manager id to the block manager's information. private val blockManagerInfo = new HashMap[BlockManagerId, BlockManagerInfo] + + // Mapping from host name to block manager id. private val blockManagerIdByHost = new HashMap[String, BlockManagerId] + + // Mapping from block id to the set of block managers that have the block. private val blockInfo = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]] initLogging() @@ -245,8 +179,8 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor } def receive = { - case RegisterBlockManager(blockManagerId, maxMemSize) => - register(blockManagerId, maxMemSize) + case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => + register(blockManagerId, maxMemSize, slaveActor) case BlockUpdate(blockManagerId, blockId, storageLevel, deserializedSize, size) => blockUpdate(blockManagerId, blockId, storageLevel, deserializedSize, size) @@ -264,6 +198,9 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor case GetMemoryStatus => getMemoryStatus + case RemoveBlock(blockId) => + removeBlock(blockId) + case RemoveHost(host) => removeHost(host) sender ! true @@ -286,6 +223,27 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor logInfo("Got unknown message: " + other) } + // Remove a block from the slaves that have it. This can only be used to remove + // blocks that the master knows about. + private def removeBlock(blockId: String) { + val block = blockInfo.get(blockId) + if (block != null) { + block._2.foreach { blockManagerId: BlockManagerId => + val blockManager = blockManagerInfo.get(blockManagerId) + if (blockManager.isDefined) { + // Remove the block from the slave's BlockManager. + // Doesn't actually wait for a confirmation and the message might get lost. + // If message loss becomes frequent, we should add retry logic here. + blockManager.get.slaveActor ! RemoveBlock(blockId) + // Remove the block from the master's BlockManagerInfo. + blockManager.get.updateBlockInfo(blockId, StorageLevel.NONE, 0, 0) + } + } + blockInfo.remove(blockId) + } + sender ! true + } + // Return a map from the block manager id to max memory and remaining memory. private def getMemoryStatus() { val res = blockManagerInfo.map { case(blockManagerId, info) => @@ -294,7 +252,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor sender ! res } - private def register(blockManagerId: BlockManagerId, maxMemSize: Long) { + private def register(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { val startTimeMs = System.currentTimeMillis() val tmp = " " + blockManagerId + " " logDebug("Got in register 0" + tmp + Utils.getUsedTimeMs(startTimeMs)) @@ -309,7 +267,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor logInfo("Got Register Msg from master node, don't register it") } else { blockManagerInfo += (blockManagerId -> new BlockManagerInfo( - blockManagerId, System.currentTimeMillis(), maxMemSize)) + blockManagerId, System.currentTimeMillis(), maxMemSize, slaveActor)) } blockManagerIdByHost += (blockManagerId.ip -> blockManagerId) logDebug("Got in register 1" + tmp + Utils.getUsedTimeMs(startTimeMs)) @@ -442,25 +400,29 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor } } -private[spark] class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: Boolean) + +private[spark] class BlockManagerMaster( + val actorSystem: ActorSystem, + isMaster: Boolean, + isLocal: Boolean, + masterIp: String, + masterPort: Int) extends Logging { - val AKKA_ACTOR_NAME: String = "BlockMasterManager" + val MASTER_AKKA_ACTOR_NAME = "BlockMasterManager" + val SLAVE_AKKA_ACTOR_NAME = "BlockSlaveManager" val REQUEST_RETRY_INTERVAL_MS = 100 - val DEFAULT_MASTER_IP: String = System.getProperty("spark.master.host", "localhost") - val DEFAULT_MASTER_PORT: Int = System.getProperty("spark.master.port", "7077").toInt val DEFAULT_MANAGER_IP: String = Utils.localHostName() val timeout = 10.seconds var masterActor: ActorRef = null if (isMaster) { - masterActor = actorSystem.actorOf( - Props(new BlockManagerMasterActor(isLocal)), name = AKKA_ACTOR_NAME) + masterActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)), + name = MASTER_AKKA_ACTOR_NAME) logInfo("Registered BlockManagerMaster Actor") } else { - val url = "akka://spark@%s:%s/user/%s".format( - DEFAULT_MASTER_IP, DEFAULT_MASTER_PORT, AKKA_ACTOR_NAME) + val url = "akka://spark@%s:%s/user/%s".format(masterIp, masterPort, MASTER_AKKA_ACTOR_NAME) logInfo("Connecting to BlockManagerMaster: " + url) masterActor = actorSystem.actorFor(url) } @@ -497,7 +459,9 @@ private[spark] class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Bool logInfo("Removed " + host + " successfully in notifyADeadHost") } - def mustRegisterBlockManager(msg: RegisterBlockManager) { + def mustRegisterBlockManager( + blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + val msg = RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) logInfo("Trying to register BlockManager") while (! syncRegisterBlockManager(msg)) { logWarning("Failed to register " + msg) @@ -506,7 +470,7 @@ private[spark] class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Bool logInfo("Done registering BlockManager") } - def syncRegisterBlockManager(msg: RegisterBlockManager): Boolean = { + private def syncRegisterBlockManager(msg: RegisterBlockManager): Boolean = { //val masterActor = RemoteActor.select(node, name) val startTimeMs = System.currentTimeMillis() val tmp = " msg " + msg + " " @@ -533,7 +497,7 @@ private[spark] class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Bool return res.get } - def syncHeartBeat(msg: HeartBeat): Option[Boolean] = { + private def syncHeartBeat(msg: HeartBeat): Option[Boolean] = { try { val answer = askMaster(msg).asInstanceOf[Boolean] return Some(answer) @@ -553,7 +517,7 @@ private[spark] class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Bool return res.get } - def syncBlockUpdate(msg: BlockUpdate): Option[Boolean] = { + private def syncBlockUpdate(msg: BlockUpdate): Option[Boolean] = { val startTimeMs = System.currentTimeMillis() val tmp = " msg " + msg + " " logDebug("Got in syncBlockUpdate " + tmp + " 0 " + Utils.getUsedTimeMs(startTimeMs)) @@ -580,7 +544,7 @@ private[spark] class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Bool return res } - def syncGetLocations(msg: GetLocations): Seq[BlockManagerId] = { + private def syncGetLocations(msg: GetLocations): Seq[BlockManagerId] = { val startTimeMs = System.currentTimeMillis() val tmp = " msg " + msg + " " logDebug("Got in syncGetLocations 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) @@ -603,7 +567,7 @@ private[spark] class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Bool } def mustGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds): - Seq[Seq[BlockManagerId]] = { + Seq[Seq[BlockManagerId]] = { var res: Seq[Seq[BlockManagerId]] = syncGetLocationsMultipleBlockIds(msg) while (res == null) { logWarning("Failed to GetLocationsMultipleBlockIds " + msg) @@ -613,7 +577,7 @@ private[spark] class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Bool return res } - def syncGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds): + private def syncGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds): Seq[Seq[BlockManagerId]] = { val startTimeMs = System.currentTimeMillis val tmp = " msg " + msg + " " @@ -644,11 +608,10 @@ private[spark] class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Bool Thread.sleep(REQUEST_RETRY_INTERVAL_MS) res = syncGetPeers(msg) } - - return res + res } - def syncGetPeers(msg: GetPeers): Seq[BlockManagerId] = { + private def syncGetPeers(msg: GetPeers): Seq[BlockManagerId] = { val startTimeMs = System.currentTimeMillis val tmp = " msg " + msg + " " logDebug("Got in syncGetPeers 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) @@ -670,6 +633,20 @@ private[spark] class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Bool } } + /** + * Remove a block from the slaves that have it. This can only be used to remove + * blocks that the master knows about. + */ + def removeBlock(blockId: String) { + askMaster(RemoveBlock(blockId)) + } + + /** + * Return the memory status for each block manager, in the form of a map from + * the block manager's id to two long values. The first value is the maximum + * amount of memory allocated for the block manager, while the second is the + * amount of remaining memory. + */ def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { askMaster(GetMemoryStatus).asInstanceOf[Map[BlockManagerId, (Long, Long)]] } diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala new file mode 100644 index 0000000000..5bca170f95 --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala @@ -0,0 +1,102 @@ +package spark.storage + +import java.io.{Externalizable, ObjectInput, ObjectOutput} + +import akka.actor.ActorRef + + +////////////////////////////////////////////////////////////////////////////////// +// Messages from the master to slaves. +////////////////////////////////////////////////////////////////////////////////// +private[spark] +sealed trait ToBlockManagerSlave + +// Remove a block from the slaves that have it. This can only be used to remove +// blocks that the master knows about. +private[spark] +case class RemoveBlock(blockId: String) extends ToBlockManagerSlave + + +////////////////////////////////////////////////////////////////////////////////// +// Messages from slaves to the master. +////////////////////////////////////////////////////////////////////////////////// +private[spark] +sealed trait ToBlockManagerMaster + +private[spark] +case class RegisterBlockManager( + blockManagerId: BlockManagerId, + maxMemSize: Long, + sender: ActorRef) + extends ToBlockManagerMaster + +private[spark] +case class HeartBeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster + +private[spark] +class BlockUpdate( + var blockManagerId: BlockManagerId, + var blockId: String, + var storageLevel: StorageLevel, + var memSize: Long, + var diskSize: Long) + extends ToBlockManagerMaster + with Externalizable { + + def this() = this(null, null, null, 0, 0) // For deserialization only + + override def writeExternal(out: ObjectOutput) { + blockManagerId.writeExternal(out) + out.writeUTF(blockId) + storageLevel.writeExternal(out) + out.writeInt(memSize.toInt) + out.writeInt(diskSize.toInt) + } + + override def readExternal(in: ObjectInput) { + blockManagerId = new BlockManagerId() + blockManagerId.readExternal(in) + blockId = in.readUTF() + storageLevel = new StorageLevel() + storageLevel.readExternal(in) + memSize = in.readInt() + diskSize = in.readInt() + } +} + +private[spark] +object BlockUpdate { + def apply(blockManagerId: BlockManagerId, + blockId: String, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long): BlockUpdate = { + new BlockUpdate(blockManagerId, blockId, storageLevel, memSize, diskSize) + } + + // For pattern-matching + def unapply(h: BlockUpdate): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { + Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize)) + } +} + +private[spark] +case class GetLocations(blockId: String) extends ToBlockManagerMaster + +private[spark] +case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster + +private[spark] +case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster + +private[spark] +case class RemoveHost(host: String) extends ToBlockManagerMaster + +private[spark] +case object StopBlockManagerMaster extends ToBlockManagerMaster + +private[spark] +case object GetMemoryStatus extends ToBlockManagerMaster + +private[spark] +case object ExpireDeadHosts extends ToBlockManagerMaster diff --git a/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala new file mode 100644 index 0000000000..f570cdc52d --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala @@ -0,0 +1,16 @@ +package spark.storage + +import akka.actor.Actor + +import spark.{Logging, SparkException, Utils} + + +/** + * An actor to take commands from the master to execute options. For example, + * this is used to remove blocks from the slave's BlockManager. + */ +class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor { + override def receive = { + case RemoveBlock(blockId) => blockManager.removeBlock(blockId) + } +} diff --git a/core/src/main/scala/spark/storage/ThreadingTest.scala b/core/src/main/scala/spark/storage/ThreadingTest.scala index 5bb5a29cc4..689f07b969 100644 --- a/core/src/main/scala/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/spark/storage/ThreadingTest.scala @@ -58,8 +58,10 @@ private[spark] object ThreadingTest { val startTime = System.currentTimeMillis() manager.get(blockId) match { case Some(retrievedBlock) => - assert(retrievedBlock.toList.asInstanceOf[List[Int]] == block.toList, "Block " + blockId + " did not match") - println("Got block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms") + assert(retrievedBlock.toList.asInstanceOf[List[Int]] == block.toList, + "Block " + blockId + " did not match") + println("Got block " + blockId + " in " + + (System.currentTimeMillis - startTime) + " ms") case None => assert(false, "Block " + blockId + " could not be retrieved") } @@ -73,7 +75,9 @@ private[spark] object ThreadingTest { System.setProperty("spark.kryoserializer.buffer.mb", "1") val actorSystem = ActorSystem("test") val serializer = new KryoSerializer - val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true) + val masterIp: String = System.getProperty("spark.master.host", "localhost") + val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt + val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, masterIp, masterPort) val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer, 1024 * 1024) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) @@ -86,6 +90,7 @@ private[spark] object ThreadingTest { actorSystem.shutdown() actorSystem.awaitTermination() println("Everything stopped.") - println("It will take sometime for the JVM to clean all temporary files and shutdown. Sit tight.") + println( + "It will take sometime for the JVM to clean all temporary files and shutdown. Sit tight.") } } diff --git a/core/src/main/scala/spark/util/GenerationIdUtil.scala b/core/src/main/scala/spark/util/GenerationIdUtil.scala new file mode 100644 index 0000000000..8a17b700b0 --- /dev/null +++ b/core/src/main/scala/spark/util/GenerationIdUtil.scala @@ -0,0 +1,19 @@ +package spark.util + +import java.util.concurrent.atomic.AtomicInteger + +private[spark] +object GenerationIdUtil { + + val BLOCK_MANAGER = new IdGenerator + + /** + * A util used to get a unique generation ID. This is a wrapper around + * Java's AtomicInteger. + */ + class IdGenerator { + private var id = new AtomicInteger + + def next: Int = id.incrementAndGet + } +} diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index ad2253596d..4dc3b7ec05 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -20,15 +20,15 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT var oldArch: String = null var oldOops: String = null var oldHeartBeat: String = null - - // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test + + // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test val serializer = new KryoSerializer before { actorSystem = ActorSystem("test") - master = new BlockManagerMaster(actorSystem, true, true) + master = new BlockManagerMaster(actorSystem, true, true, "localhost", 7077) - // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case + // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case oldArch = System.setProperty("os.arch", "amd64") oldOops = System.setProperty("spark.test.useCompressedOops", "true") oldHeartBeat = System.setProperty("spark.storage.disableBlockManagerHeartBeat", "true") @@ -74,7 +74,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY, false) - // Checking whether blocks are in memory + // Checking whether blocks are in memory assert(store.getSingle("a1") != None, "a1 was not in store") assert(store.getSingle("a2") != None, "a2 was not in store") assert(store.getSingle("a3") != None, "a3 was not in store") @@ -83,7 +83,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1") assert(master.mustGetLocations(GetLocations("a2")).size > 0, "master was not told about a2") assert(master.mustGetLocations(GetLocations("a3")).size === 0, "master was told about a3") - + // Drop a1 and a2 from memory; this should be reported back to the master store.dropFromMemory("a1", null) store.dropFromMemory("a2", null) @@ -93,6 +93,45 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(master.mustGetLocations(GetLocations("a2")).size === 0, "master did not remove a2") } + test("removing block") { + store = new BlockManager(actorSystem, master, serializer, 2000) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + + // Putting a1, a2 and a3 in memory and telling master only about a1 and a2 + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) + store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY, false) + + // Checking whether blocks are in memory and memory size + var memStatus = master.getMemoryStatus.head._2 + assert(memStatus._1 == 2000L, "total memory " + memStatus._1 + " should equal 2000") + assert(memStatus._2 <= 1200L, "remaining memory " + memStatus._2 + " should <= 1200") + assert(store.getSingle("a1") != None, "a1 was not in store") + assert(store.getSingle("a2") != None, "a2 was not in store") + assert(store.getSingle("a3") != None, "a3 was not in store") + + // Checking whether master knows about the blocks or not + assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1") + assert(master.mustGetLocations(GetLocations("a2")).size > 0, "master was not told about a2") + assert(master.mustGetLocations(GetLocations("a3")).size === 0, "master was told about a3") + + // Remove a1 and a2 and a3. Should be no-op for a3. + master.removeBlock("a1") + master.removeBlock("a2") + master.removeBlock("a3") + assert(store.getSingle("a1") === None, "a1 not removed from store") + assert(store.getSingle("a2") === None, "a2 not removed from store") + assert(master.mustGetLocations(GetLocations("a1")).size === 0, "master did not remove a1") + assert(master.mustGetLocations(GetLocations("a2")).size === 0, "master did not remove a2") + assert(store.getSingle("a3") != None, "a3 was not in store") + assert(master.mustGetLocations(GetLocations("a3")).size === 0, "master was told about a3") + memStatus = master.getMemoryStatus.head._2 + assert(memStatus._1 == 2000L, "total memory " + memStatus._1 + " should equal 2000") + assert(memStatus._2 == 2000L, "remaining memory " + memStatus._1 + " should equal 2000") + } + test("reregistration on heart beat") { val heartBeat = PrivateMethod[Unit]('heartBeat) store = new BlockManager(actorSystem, master, serializer, 2000) @@ -122,7 +161,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT master.notifyADeadHost(store.blockManagerId.ip) assert(master.mustGetLocations(GetLocations("a1")).size == 0, "a1 was not removed from master") - + store.putSingle("a2", a1, StorageLevel.MEMORY_ONLY) assert(master.mustGetLocations(GetLocations("a1")).size > 0, @@ -145,11 +184,11 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(master.mustGetLocations(GetLocations("a1")).size == 0, "a1 was not removed from master") store invokePrivate heartBeat() - + assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1") store2 invokePrivate heartBeat() - + assert(master.mustGetLocations(GetLocations("a1")).size == 0, "a2 was not removed from master") } @@ -171,7 +210,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(store.getSingle("a2") != None, "a2 was not in store") assert(store.getSingle("a3") === None, "a3 was in store") } - + test("in-memory LRU storage with serialization") { store = new BlockManager(actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) -- cgit v1.2.3 From 7c9e3d1c2105b694bedcfe10e554dbadd2760eb5 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 13 Dec 2012 15:12:44 -0800 Subject: Return success or failure in BlockStore.remove(). --- core/src/main/scala/spark/storage/BlockManager.scala | 13 ++++++++++--- core/src/main/scala/spark/storage/BlockStore.scala | 7 ++++++- core/src/main/scala/spark/storage/DiskStore.scala | 5 ++++- core/src/main/scala/spark/storage/MemoryStore.scala | 5 +++-- 4 files changed, 23 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index b2c9e2cc40..9a60a8dd62 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -832,7 +832,10 @@ class BlockManager( diskStore.putBytes(blockId, bytes, level) } } - memoryStore.remove(blockId) + val blockWasRemoved = memoryStore.remove(blockId) + if (!blockWasRemoved) { + logWarning("Block " + blockId + " could not be dropped from memory as it does not exist") + } if (info.tellMaster) { reportBlockStatus(blockId) } @@ -856,8 +859,12 @@ class BlockManager( val info = blockInfo.get(blockId) if (info != null) info.synchronized { // Removals are idempotent in disk store and memory store. At worst, we get a warning. - memoryStore.remove(blockId) - diskStore.remove(blockId) + val removedFromMemory = memoryStore.remove(blockId) + val removedFromDisk = diskStore.remove(blockId) + if (!removedFromMemory && !removedFromDisk) { + logWarning("Block " + blockId + " could not be removed as it was not found in either " + + "the disk or memory store") + } blockInfo.remove(blockId) } else { // The block has already been removed; do nothing. diff --git a/core/src/main/scala/spark/storage/BlockStore.scala b/core/src/main/scala/spark/storage/BlockStore.scala index 096bf8bdd9..8188d3595e 100644 --- a/core/src/main/scala/spark/storage/BlockStore.scala +++ b/core/src/main/scala/spark/storage/BlockStore.scala @@ -31,7 +31,12 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging { def getValues(blockId: String): Option[Iterator[Any]] - def remove(blockId: String) + /** + * Remove a block, if it exists. + * @param blockId the block to remove. + * @return True if the block was found and removed, False otherwise. + */ + def remove(blockId: String): Boolean def contains(blockId: String): Boolean diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index 8ba64e4b76..8d08871d73 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -90,10 +90,13 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes)) } - override def remove(blockId: String) { + override def remove(blockId: String): Boolean = { val file = getFile(blockId) if (file.exists()) { file.delete() + true + } else { + false } } diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala index 02098b82fe..00e32f753c 100644 --- a/core/src/main/scala/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/spark/storage/MemoryStore.scala @@ -90,7 +90,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - override def remove(blockId: String) { + override def remove(blockId: String): Boolean = { entries.synchronized { val entry = entries.get(blockId) if (entry != null) { @@ -98,8 +98,9 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) currentMemory -= entry.size logInfo("Block %s of size %d dropped from memory (free %d)".format( blockId, entry.size, freeMemory)) + true } else { - logWarning("Block " + blockId + " could not be removed as it does not exist") + false } } } -- cgit v1.2.3 From 97434f49b8c029e9b78c91ec5f58557cd1b5c943 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 13 Dec 2012 22:32:19 -0800 Subject: Merged TD's block manager refactoring. --- .../main/scala/spark/storage/BlockManager.scala | 66 +- .../main/scala/spark/storage/BlockManagerId.scala | 23 +- .../scala/spark/storage/BlockManagerMaster.scala | 702 ++++----------------- .../spark/storage/BlockManagerMasterActor.scala | 406 ++++++++++++ .../scala/spark/storage/BlockManagerMessages.scala | 10 +- .../main/scala/spark/storage/StorageLevel.scala | 32 +- .../main/scala/spark/util/MetadataCleaner.scala | 35 + .../main/scala/spark/util/TimeStampedHashMap.scala | 87 +++ .../scala/spark/storage/BlockManagerSuite.scala | 91 ++- 9 files changed, 805 insertions(+), 647 deletions(-) create mode 100644 core/src/main/scala/spark/storage/BlockManagerMasterActor.scala create mode 100644 core/src/main/scala/spark/util/MetadataCleaner.scala create mode 100644 core/src/main/scala/spark/util/TimeStampedHashMap.scala diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index b2c9e2cc40..2f41633440 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -19,7 +19,7 @@ import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream import spark.{CacheTracker, Logging, SizeEstimator, SparkEnv, SparkException, Utils} import spark.network._ import spark.serializer.Serializer -import spark.util.{ByteBufferInputStream, GenerationIdUtil} +import spark.util.{ByteBufferInputStream, GenerationIdUtil, MetadataCleaner, TimeStampedHashMap} import sun.nio.ch.DirectBuffer @@ -59,7 +59,7 @@ class BlockManager( } } - private val blockInfo = new ConcurrentHashMap[String, BlockInfo](1000) + private val blockInfo = new TimeStampedHashMap[String, BlockInfo]() private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) private[storage] val diskStore: BlockStore = @@ -96,13 +96,14 @@ class BlockManager( @volatile private var shuttingDown = false private def heartBeat() { - if (!master.mustHeartBeat(HeartBeat(blockManagerId))) { + if (!master.sendHeartBeat(blockManagerId)) { reregister() } } var heartBeatTask: Cancellable = null + val metadataCleaner = new MetadataCleaner("BlockManager", this.dropOldBlocks) initialize() /** @@ -117,7 +118,7 @@ class BlockManager( * BlockManagerWorker actor. */ private def initialize() { - master.mustRegisterBlockManager(blockManagerId, maxMemory, slaveActor) + master.registerBlockManager(blockManagerId, maxMemory, slaveActor) BlockManagerWorker.startBlockManagerWorker(this) if (!BlockManager.getDisableHeartBeatsForTesting) { heartBeatTask = actorSystem.scheduler.schedule(0.seconds, heartBeatFrequency.milliseconds) { @@ -153,17 +154,14 @@ class BlockManager( def reregister() { // TODO: We might need to rate limit reregistering. logInfo("BlockManager reregistering with master") - master.mustRegisterBlockManager(blockManagerId, maxMemory, slaveActor) + master.registerBlockManager(blockManagerId, maxMemory, slaveActor) reportAllBlocks() } /** * Get storage level of local block. If no info exists for the block, then returns null. */ - def getLevel(blockId: String): StorageLevel = { - val info = blockInfo.get(blockId) - if (info != null) info.level else null - } + def getLevel(blockId: String): StorageLevel = blockInfo.get(blockId).map(_.level).orNull /** * Tell the master about the current storage status of a block. This will send a block update @@ -186,9 +184,9 @@ class BlockManager( */ private def tryToReportBlockStatus(blockId: String): Boolean = { val (curLevel, inMemSize, onDiskSize, tellMaster) = blockInfo.get(blockId) match { - case null => + case None => (StorageLevel.NONE, 0L, 0L, false) - case info => + case Some(info) => info.synchronized { info.level match { case null => @@ -207,7 +205,7 @@ class BlockManager( } if (tellMaster) { - master.mustBlockUpdate(BlockUpdate(blockManagerId, blockId, curLevel, inMemSize, onDiskSize)) + master.updateBlockInfo(blockManagerId, blockId, curLevel, inMemSize, onDiskSize) } else { true } @@ -219,7 +217,7 @@ class BlockManager( */ def getLocations(blockId: String): Seq[String] = { val startTimeMs = System.currentTimeMillis - var managers = master.mustGetLocations(GetLocations(blockId)) + var managers = master.getLocations(blockId) val locations = managers.map(_.ip) logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs)) return locations @@ -230,8 +228,7 @@ class BlockManager( */ def getLocations(blockIds: Array[String]): Array[Seq[String]] = { val startTimeMs = System.currentTimeMillis - val locations = master.mustGetLocationsMultipleBlockIds( - GetLocationsMultipleBlockIds(blockIds)).map(_.map(_.ip).toSeq).toArray + val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray logDebug("Get multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) return locations } @@ -253,7 +250,7 @@ class BlockManager( } } - val info = blockInfo.get(blockId) + val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { info.waitForReady() // In case the block is still being put() by another thread @@ -338,7 +335,7 @@ class BlockManager( } } - val info = blockInfo.get(blockId) + val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { info.waitForReady() // In case the block is still being put() by another thread @@ -394,7 +391,7 @@ class BlockManager( } logDebug("Getting remote block " + blockId) // Get locations of block - val locations = master.mustGetLocations(GetLocations(blockId)) + val locations = master.getLocations(blockId) // Get block from remote locations for (loc <- locations) { @@ -596,7 +593,7 @@ class BlockManager( throw new IllegalArgumentException("Storage level is null or invalid") } - val oldBlock = blockInfo.get(blockId) + val oldBlock = blockInfo.get(blockId).orNull if (oldBlock != null) { logWarning("Block " + blockId + " already exists on this machine; not re-adding it") oldBlock.waitForReady() @@ -697,7 +694,7 @@ class BlockManager( throw new IllegalArgumentException("Storage level is null or invalid") } - if (blockInfo.containsKey(blockId)) { + if (blockInfo.contains(blockId)) { logWarning("Block " + blockId + " already exists on this machine; not re-adding it") return } @@ -772,7 +769,7 @@ class BlockManager( val tLevel: StorageLevel = new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) if (cachedPeers == null) { - cachedPeers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) + cachedPeers = master.getPeers(blockManagerId, level.replication - 1) } for (peer: BlockManagerId <- cachedPeers) { val start = System.nanoTime @@ -819,7 +816,7 @@ class BlockManager( */ def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) { logInfo("Dropping block " + blockId + " from memory") - val info = blockInfo.get(blockId) + val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { val level = info.level @@ -853,7 +850,7 @@ class BlockManager( */ def removeBlock(blockId: String) { logInfo("Removing block " + blockId) - val info = blockInfo.get(blockId) + val info = blockInfo.get(blockId).orNull if (info != null) info.synchronized { // Removals are idempotent in disk store and memory store. At worst, we get a warning. memoryStore.remove(blockId) @@ -865,6 +862,29 @@ class BlockManager( } } + def dropOldBlocks(cleanupTime: Long) { + logInfo("Dropping blocks older than " + cleanupTime) + val iterator = blockInfo.internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2) + if (time < cleanupTime) { + info.synchronized { + val level = info.level + if (level.useMemory) { + memoryStore.remove(id) + } + if (level.useDisk) { + diskStore.remove(id) + } + iterator.remove() + logInfo("Dropped block " + id) + } + reportBlockStatus(id) + } + } + } + def shouldCompress(blockId: String): Boolean = { if (blockId.startsWith("shuffle_")) { compressShuffle diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala index 03cd141805..488679f049 100644 --- a/core/src/main/scala/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -1,6 +1,7 @@ package spark.storage -import java.io.{Externalizable, ObjectInput, ObjectOutput} +import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} +import java.util.concurrent.ConcurrentHashMap private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable { @@ -18,6 +19,9 @@ private[spark] class BlockManagerId(var ip: String, var port: Int) extends Exter port = in.readInt() } + @throws(classOf[IOException]) + private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) + override def toString = "BlockManagerId(" + ip + ", " + port + ")" override def hashCode = ip.hashCode * 41 + port @@ -26,4 +30,19 @@ private[spark] class BlockManagerId(var ip: String, var port: Int) extends Exter case id: BlockManagerId => port == id.port && ip == id.ip case _ => false } -} \ No newline at end of file +} + + +private[spark] object BlockManagerId { + + val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() + + def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { + if (blockManagerIdCache.containsKey(id)) { + blockManagerIdCache.get(id) + } else { + blockManagerIdCache.put(id, id) + id + } + } +} diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 64cdb86f8d..cf11393a03 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -1,406 +1,17 @@ package spark.storage -import java.io._ -import java.util.{HashMap => JHashMap} - -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.mutable.ArrayBuffer import scala.util.Random -import akka.actor._ -import akka.dispatch._ +import akka.actor.{Actor, ActorRef, ActorSystem, Props} +import akka.dispatch.Await import akka.pattern.ask -import akka.remote._ import akka.util.{Duration, Timeout} import akka.util.duration._ import spark.{Logging, SparkException, Utils} -private[spark] -case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) - - -// TODO(rxin): Move BlockManagerMasterActor to its own file. -private[spark] -class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { - - class BlockManagerInfo( - val blockManagerId: BlockManagerId, - timeMs: Long, - val maxMem: Long, - val slaveActor: ActorRef) { - - private var _lastSeenMs: Long = timeMs - private var _remainingMem: Long = maxMem - - // Mapping from block id to its status. - private val _blocks = new JHashMap[String, BlockStatus] - - logInfo("Registering block manager %s:%d with %s RAM".format( - blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem))) - - def updateLastSeenMs() { - _lastSeenMs = System.currentTimeMillis() - } - - def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, diskSize: Long) - : Unit = synchronized { - - updateLastSeenMs() - - if (_blocks.containsKey(blockId)) { - // The block exists on the slave already. - val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel - - if (originalLevel.useMemory) { - _remainingMem += memSize - } - } - - if (storageLevel.isValid) { - // isValid means it is either stored in-memory or on-disk. - _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize)) - if (storageLevel.useMemory) { - _remainingMem -= memSize - logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), - Utils.memoryBytesToString(_remainingMem))) - } - if (storageLevel.useDisk) { - logInfo("Added %s on disk on %s:%d (size: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) - } - } else if (_blocks.containsKey(blockId)) { - // If isValid is not true, drop the block. - val blockStatus: BlockStatus = _blocks.get(blockId) - _blocks.remove(blockId) - if (blockStatus.storageLevel.useMemory) { - _remainingMem += blockStatus.memSize - logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), - Utils.memoryBytesToString(_remainingMem))) - } - if (blockStatus.storageLevel.useDisk) { - logInfo("Removed %s on %s:%d on disk (size: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) - } - } - } - - def remainingMem: Long = _remainingMem - - def lastSeenMs: Long = _lastSeenMs - - def blocks: JHashMap[String, BlockStatus] = _blocks - - override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem - - def clear() { - _blocks.clear() - } - } - - // Mapping from block manager id to the block manager's information. - private val blockManagerInfo = new HashMap[BlockManagerId, BlockManagerInfo] - - // Mapping from host name to block manager id. - private val blockManagerIdByHost = new HashMap[String, BlockManagerId] - - // Mapping from block id to the set of block managers that have the block. - private val blockInfo = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]] - - initLogging() - - val slaveTimeout = System.getProperty("spark.storage.blockManagerSlaveTimeoutMs", - "" + (BlockManager.getHeartBeatFrequencyFromSystemProperties * 3)).toLong - - val checkTimeoutInterval = System.getProperty("spark.storage.blockManagerTimeoutIntervalMs", - "5000").toLong - - var timeoutCheckingTask: Cancellable = null - - override def preStart() { - if (!BlockManager.getDisableHeartBeatsForTesting) { - timeoutCheckingTask = context.system.scheduler.schedule( - 0.seconds, checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) - } - super.preStart() - } - - def removeBlockManager(blockManagerId: BlockManagerId) { - val info = blockManagerInfo(blockManagerId) - blockManagerIdByHost.remove(blockManagerId.ip) - blockManagerInfo.remove(blockManagerId) - var iterator = info.blocks.keySet.iterator - while (iterator.hasNext) { - val blockId = iterator.next - val locations = blockInfo.get(blockId)._2 - locations -= blockManagerId - if (locations.size == 0) { - blockInfo.remove(locations) - } - } - } - - def expireDeadHosts() { - logDebug("Checking for hosts with no recent heart beats in BlockManagerMaster.") - val now = System.currentTimeMillis() - val minSeenTime = now - slaveTimeout - val toRemove = new HashSet[BlockManagerId] - for (info <- blockManagerInfo.values) { - if (info.lastSeenMs < minSeenTime) { - logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats") - toRemove += info.blockManagerId - } - } - // TODO: Remove corresponding block infos - toRemove.foreach(removeBlockManager) - } - - def removeHost(host: String) { - logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.") - logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq) - blockManagerIdByHost.get(host).foreach(removeBlockManager) - logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq) - sender ! true - } - - def heartBeat(blockManagerId: BlockManagerId) { - if (!blockManagerInfo.contains(blockManagerId)) { - if (blockManagerId.ip == Utils.localHostName() && !isLocal) { - sender ! true - } else { - sender ! false - } - } else { - blockManagerInfo(blockManagerId).updateLastSeenMs() - sender ! true - } - } - - def receive = { - case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => - register(blockManagerId, maxMemSize, slaveActor) - - case BlockUpdate(blockManagerId, blockId, storageLevel, deserializedSize, size) => - blockUpdate(blockManagerId, blockId, storageLevel, deserializedSize, size) - - case GetLocations(blockId) => - getLocations(blockId) - - case GetLocationsMultipleBlockIds(blockIds) => - getLocationsMultipleBlockIds(blockIds) - - case GetPeers(blockManagerId, size) => - getPeersDeterministic(blockManagerId, size) - /*getPeers(blockManagerId, size)*/ - - case GetMemoryStatus => - getMemoryStatus - - case RemoveBlock(blockId) => - removeBlock(blockId) - - case RemoveHost(host) => - removeHost(host) - sender ! true - - case StopBlockManagerMaster => - logInfo("Stopping BlockManagerMaster") - sender ! true - if (timeoutCheckingTask != null) { - timeoutCheckingTask.cancel - } - context.stop(self) - - case ExpireDeadHosts => - expireDeadHosts() - - case HeartBeat(blockManagerId) => - heartBeat(blockManagerId) - - case other => - logInfo("Got unknown message: " + other) - } - - // Remove a block from the slaves that have it. This can only be used to remove - // blocks that the master knows about. - private def removeBlock(blockId: String) { - val block = blockInfo.get(blockId) - if (block != null) { - block._2.foreach { blockManagerId: BlockManagerId => - val blockManager = blockManagerInfo.get(blockManagerId) - if (blockManager.isDefined) { - // Remove the block from the slave's BlockManager. - // Doesn't actually wait for a confirmation and the message might get lost. - // If message loss becomes frequent, we should add retry logic here. - blockManager.get.slaveActor ! RemoveBlock(blockId) - // Remove the block from the master's BlockManagerInfo. - blockManager.get.updateBlockInfo(blockId, StorageLevel.NONE, 0, 0) - } - } - blockInfo.remove(blockId) - } - sender ! true - } - - // Return a map from the block manager id to max memory and remaining memory. - private def getMemoryStatus() { - val res = blockManagerInfo.map { case(blockManagerId, info) => - (blockManagerId, (info.maxMem, info.remainingMem)) - }.toMap - sender ! res - } - - private def register(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { - val startTimeMs = System.currentTimeMillis() - val tmp = " " + blockManagerId + " " - logDebug("Got in register 0" + tmp + Utils.getUsedTimeMs(startTimeMs)) - if (blockManagerIdByHost.contains(blockManagerId.ip) && - blockManagerIdByHost(blockManagerId.ip) != blockManagerId) { - val oldId = blockManagerIdByHost(blockManagerId.ip) - logInfo("Got second registration for host " + blockManagerId + - "; removing old slave " + oldId) - removeBlockManager(oldId) - } - if (blockManagerId.ip == Utils.localHostName() && !isLocal) { - logInfo("Got Register Msg from master node, don't register it") - } else { - blockManagerInfo += (blockManagerId -> new BlockManagerInfo( - blockManagerId, System.currentTimeMillis(), maxMemSize, slaveActor)) - } - blockManagerIdByHost += (blockManagerId.ip -> blockManagerId) - logDebug("Got in register 1" + tmp + Utils.getUsedTimeMs(startTimeMs)) - sender ! true - } - - private def blockUpdate( - blockManagerId: BlockManagerId, - blockId: String, - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long) { - - val startTimeMs = System.currentTimeMillis() - val tmp = " " + blockManagerId + " " + blockId + " " - - if (!blockManagerInfo.contains(blockManagerId)) { - if (blockManagerId.ip == Utils.localHostName() && !isLocal) { - // We intentionally do not register the master (except in local mode), - // so we should not indicate failure. - sender ! true - } else { - sender ! false - } - return - } - - if (blockId == null) { - blockManagerInfo(blockManagerId).updateLastSeenMs() - logDebug("Got in block update 1" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs)) - sender ! true - return - } - - blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize) - - var locations: HashSet[BlockManagerId] = null - if (blockInfo.containsKey(blockId)) { - locations = blockInfo.get(blockId)._2 - } else { - locations = new HashSet[BlockManagerId] - blockInfo.put(blockId, (storageLevel.replication, locations)) - } - - if (storageLevel.isValid) { - locations += blockManagerId - } else { - locations.remove(blockManagerId) - } - - if (locations.size == 0) { - blockInfo.remove(blockId) - } - sender ! true - } - - private def getLocations(blockId: String) { - val startTimeMs = System.currentTimeMillis() - val tmp = " " + blockId + " " - logDebug("Got in getLocations 0" + tmp + Utils.getUsedTimeMs(startTimeMs)) - if (blockInfo.containsKey(blockId)) { - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - res.appendAll(blockInfo.get(blockId)._2) - logDebug("Got in getLocations 1" + tmp + " as "+ res.toSeq + " at " - + Utils.getUsedTimeMs(startTimeMs)) - sender ! res.toSeq - } else { - logDebug("Got in getLocations 2" + tmp + Utils.getUsedTimeMs(startTimeMs)) - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - sender ! res - } - } - - private def getLocationsMultipleBlockIds(blockIds: Array[String]) { - def getLocations(blockId: String): Seq[BlockManagerId] = { - val tmp = blockId - logDebug("Got in getLocationsMultipleBlockIds Sub 0 " + tmp) - if (blockInfo.containsKey(blockId)) { - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - res.appendAll(blockInfo.get(blockId)._2) - logDebug("Got in getLocationsMultipleBlockIds Sub 1 " + tmp + " " + res.toSeq) - return res.toSeq - } else { - logDebug("Got in getLocationsMultipleBlockIds Sub 2 " + tmp) - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - return res.toSeq - } - } - - logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq) - var res: ArrayBuffer[Seq[BlockManagerId]] = new ArrayBuffer[Seq[BlockManagerId]] - for (blockId <- blockIds) { - res.append(getLocations(blockId)) - } - logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq + " : " + res.toSeq) - sender ! res.toSeq - } - - private def getPeers(blockManagerId: BlockManagerId, size: Int) { - var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - res.appendAll(peers) - res -= blockManagerId - val rand = new Random(System.currentTimeMillis()) - while (res.length > size) { - res.remove(rand.nextInt(res.length)) - } - sender ! res.toSeq - } - - private def getPeersDeterministic(blockManagerId: BlockManagerId, size: Int) { - var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - - val peersWithIndices = peers.zipWithIndex - val selfIndex = peersWithIndices.find(_._1 == blockManagerId).map(_._2).getOrElse(-1) - if (selfIndex == -1) { - throw new Exception("Self index for " + blockManagerId + " not found") - } - - var index = selfIndex - while (res.size < size) { - index += 1 - if (index == selfIndex) { - throw new Exception("More peer expected than available") - } - res += peers(index % peers.size) - } - sender ! res.toSeq - } -} - - private[spark] class BlockManagerMaster( val actorSystem: ActorSystem, isMaster: Boolean, @@ -409,245 +20,164 @@ private[spark] class BlockManagerMaster( masterPort: Int) extends Logging { + val AKKA_RETRY_ATTEMPS: Int = System.getProperty("spark.akka.num.retries", "5").toInt + val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "100").toInt + val MASTER_AKKA_ACTOR_NAME = "BlockMasterManager" val SLAVE_AKKA_ACTOR_NAME = "BlockSlaveManager" - val REQUEST_RETRY_INTERVAL_MS = 100 val DEFAULT_MANAGER_IP: String = Utils.localHostName() val timeout = 10.seconds - var masterActor: ActorRef = null - - if (isMaster) { - masterActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)), - name = MASTER_AKKA_ACTOR_NAME) - logInfo("Registered BlockManagerMaster Actor") - } else { - val url = "akka://spark@%s:%s/user/%s".format(masterIp, masterPort, MASTER_AKKA_ACTOR_NAME) - logInfo("Connecting to BlockManagerMaster: " + url) - masterActor = actorSystem.actorFor(url) - } - - def stop() { - if (masterActor != null) { - communicate(StopBlockManagerMaster) - masterActor = null - logInfo("BlockManagerMaster stopped") - } - } - - // Send a message to the master actor and get its result within a default timeout, or - // throw a SparkException if this fails. - def askMaster(message: Any): Any = { - try { - val future = masterActor.ask(message)(timeout) - return Await.result(future, timeout) - } catch { - case e: Exception => - throw new SparkException("Error communicating with BlockManagerMaster", e) + var masterActor: ActorRef = { + if (isMaster) { + val masterActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)), + name = MASTER_AKKA_ACTOR_NAME) + logInfo("Registered BlockManagerMaster Actor") + masterActor + } else { + val url = "akka://spark@%s:%s/user/%s".format(masterIp, masterPort, MASTER_AKKA_ACTOR_NAME) + logInfo("Connecting to BlockManagerMaster: " + url) + actorSystem.actorFor(url) } } - // Send a one-way message to the master actor, to which we expect it to reply with true. - def communicate(message: Any) { - if (askMaster(message) != true) { - throw new SparkException("Error reply received from BlockManagerMaster") - } - } + /** Remove a dead host from the master actor. This is only called on the master side. */ def notifyADeadHost(host: String) { - communicate(RemoveHost(host)) + tell(RemoveHost(host)) logInfo("Removed " + host + " successfully in notifyADeadHost") } - def mustRegisterBlockManager( + /** + * Send the master actor a heart beat from the slave. Returns true if everything works out, + * false if the master does not know about the given block manager, which means the block + * manager should re-register. + */ + def sendHeartBeat(blockManagerId: BlockManagerId): Boolean = { + askMasterWithRetry[Boolean](HeartBeat(blockManagerId)) + } + + /** Register the BlockManager's id with the master. */ + def registerBlockManager( blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { - val msg = RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) logInfo("Trying to register BlockManager") - while (! syncRegisterBlockManager(msg)) { - logWarning("Failed to register " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - } - logInfo("Done registering BlockManager") + tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor)) + logInfo("Registered BlockManager") } - private def syncRegisterBlockManager(msg: RegisterBlockManager): Boolean = { - //val masterActor = RemoteActor.select(node, name) - val startTimeMs = System.currentTimeMillis() - val tmp = " msg " + msg + " " - logDebug("Got in syncRegisterBlockManager 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - try { - communicate(msg) - logInfo("BlockManager registered successfully @ syncRegisterBlockManager") - logDebug("Got in syncRegisterBlockManager 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return true - } catch { - case e: Exception => - logError("Failed in syncRegisterBlockManager", e) - return false - } + def updateBlockInfo( + blockManagerId: BlockManagerId, + blockId: String, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long): Boolean = { + val res = askMasterWithRetry[Boolean]( + UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)) + logInfo("Updated info of block " + blockId) + res } - def mustHeartBeat(msg: HeartBeat): Boolean = { - var res = syncHeartBeat(msg) - while (!res.isDefined) { - logWarning("Failed to send heart beat " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - } - return res.get + /** Get locations of the blockId from the master */ + def getLocations(blockId: String): Seq[BlockManagerId] = { + askMasterWithRetry[Seq[BlockManagerId]](GetLocations(blockId)) } - private def syncHeartBeat(msg: HeartBeat): Option[Boolean] = { - try { - val answer = askMaster(msg).asInstanceOf[Boolean] - return Some(answer) - } catch { - case e: Exception => - logError("Failed in syncHeartBeat", e) - return None - } + /** Get locations of multiple blockIds from the master */ + def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { + askMasterWithRetry[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } - def mustBlockUpdate(msg: BlockUpdate): Boolean = { - var res = syncBlockUpdate(msg) - while (!res.isDefined) { - logWarning("Failed to send block update " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) + /** Get ids of other nodes in the cluster from the master */ + def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = { + val result = askMasterWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers)) + if (result.length != numPeers) { + throw new SparkException( + "Error getting peers, only got " + result.size + " instead of " + numPeers) } - return res.get + result } - private def syncBlockUpdate(msg: BlockUpdate): Option[Boolean] = { - val startTimeMs = System.currentTimeMillis() - val tmp = " msg " + msg + " " - logDebug("Got in syncBlockUpdate " + tmp + " 0 " + Utils.getUsedTimeMs(startTimeMs)) - - try { - val answer = askMaster(msg).asInstanceOf[Boolean] - logDebug("Block update sent successfully") - logDebug("Got in synbBlockUpdate " + tmp + " 1 " + Utils.getUsedTimeMs(startTimeMs)) - return Some(answer) - } catch { - case e: Exception => - logError("Failed in syncBlockUpdate", e) - return None - } + /** + * Remove a block from the slaves that have it. This can only be used to remove + * blocks that the master knows about. + */ + def removeBlock(blockId: String) { + askMaster(RemoveBlock(blockId)) } - def mustGetLocations(msg: GetLocations): Seq[BlockManagerId] = { - var res = syncGetLocations(msg) - while (res == null) { - logInfo("Failed to get locations " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - res = syncGetLocations(msg) - } - return res + /** + * Return the memory status for each block manager, in the form of a map from + * the block manager's id to two long values. The first value is the maximum + * amount of memory allocated for the block manager, while the second is the + * amount of remaining memory. + */ + def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { + askMasterWithRetry[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } - private def syncGetLocations(msg: GetLocations): Seq[BlockManagerId] = { - val startTimeMs = System.currentTimeMillis() - val tmp = " msg " + msg + " " - logDebug("Got in syncGetLocations 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - try { - val answer = askMaster(msg).asInstanceOf[ArrayBuffer[BlockManagerId]] - if (answer != null) { - logDebug("GetLocations successful") - logDebug("Got in syncGetLocations 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return answer - } else { - logError("Master replied null in response to GetLocations") - return null - } - } catch { - case e: Exception => - logError("GetLocations failed", e) - return null + /** Stop the master actor, called only on the Spark master node */ + def stop() { + if (masterActor != null) { + tell(StopBlockManagerMaster) + masterActor = null + logInfo("BlockManagerMaster stopped") } } - def mustGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds): - Seq[Seq[BlockManagerId]] = { - var res: Seq[Seq[BlockManagerId]] = syncGetLocationsMultipleBlockIds(msg) - while (res == null) { - logWarning("Failed to GetLocationsMultipleBlockIds " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - res = syncGetLocationsMultipleBlockIds(msg) + /** Send a one-way message to the master actor, to which we expect it to reply with true. */ + private def tell(message: Any) { + if (!askMasterWithRetry[Boolean](message)) { + throw new SparkException("BlockManagerMasterActor returned false, expected true.") } - return res } - private def syncGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds): - Seq[Seq[BlockManagerId]] = { - val startTimeMs = System.currentTimeMillis - val tmp = " msg " + msg + " " - logDebug("Got in syncGetLocationsMultipleBlockIds 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - + /** + * Send a message to the master actor and get its result within a default timeout, or + * throw a SparkException if this fails. There is no retry logic here so if the Akka + * message is lost, the master actor won't get the command. + */ + private def askMaster[T](message: Any): Any = { try { - val answer = askMaster(msg).asInstanceOf[Seq[Seq[BlockManagerId]]] - if (answer != null) { - logDebug("GetLocationsMultipleBlockIds successful") - logDebug("Got in syncGetLocationsMultipleBlockIds 1 " + tmp + - Utils.getUsedTimeMs(startTimeMs)) - return answer - } else { - logError("Master replied null in response to GetLocationsMultipleBlockIds") - return null - } + val future = masterActor.ask(message)(timeout) + return Await.result(future, timeout).asInstanceOf[T] } catch { case e: Exception => - logError("GetLocationsMultipleBlockIds failed", e) - return null + throw new SparkException("Error communicating with BlockManagerMaster", e) } } - def mustGetPeers(msg: GetPeers): Seq[BlockManagerId] = { - var res = syncGetPeers(msg) - while ((res == null) || (res.length != msg.size)) { - logInfo("Failed to get peers " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - res = syncGetPeers(msg) - } - res - } - - private def syncGetPeers(msg: GetPeers): Seq[BlockManagerId] = { - val startTimeMs = System.currentTimeMillis - val tmp = " msg " + msg + " " - logDebug("Got in syncGetPeers 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - try { - val answer = askMaster(msg).asInstanceOf[Seq[BlockManagerId]] - if (answer != null) { - logDebug("GetPeers successful") - logDebug("Got in syncGetPeers 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return answer - } else { - logError("Master replied null in response to GetPeers") - return null + /** + * Send a message to the master actor and get its result within a default timeout, or + * throw a SparkException if this fails. + */ + private def askMasterWithRetry[T](message: Any): T = { + // TODO: Consider removing multiple attempts + if (masterActor == null) { + throw new SparkException("Error sending message to BlockManager as masterActor is null " + + "[message = " + message + "]") + } + var attempts = 0 + var lastException: Exception = null + while (attempts < AKKA_RETRY_ATTEMPS) { + attempts += 1 + try { + val future = masterActor.ask(message)(timeout) + val result = Await.result(future, timeout) + if (result == null) { + throw new Exception("BlockManagerMaster returned null") + } + return result.asInstanceOf[T] + } catch { + case ie: InterruptedException => throw ie + case e: Exception => + lastException = e + logWarning("Error sending message to BlockManagerMaster in " + attempts + " attempts", e) } - } catch { - case e: Exception => - logError("GetPeers failed", e) - return null + Thread.sleep(AKKA_RETRY_INTERVAL_MS) } - } - /** - * Remove a block from the slaves that have it. This can only be used to remove - * blocks that the master knows about. - */ - def removeBlock(blockId: String) { - askMaster(RemoveBlock(blockId)) + throw new SparkException( + "Error sending message to BlockManagerMaster [message = " + message + "]", lastException) } - /** - * Return the memory status for each block manager, in the form of a map from - * the block manager's id to two long values. The first value is the maximum - * amount of memory allocated for the block manager, while the second is the - * amount of remaining memory. - */ - def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { - askMaster(GetMemoryStatus).asInstanceOf[Map[BlockManagerId, (Long, Long)]] - } } diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala new file mode 100644 index 0000000000..0d84e559cb --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -0,0 +1,406 @@ +package spark.storage + +import java.util.{HashMap => JHashMap} + +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.JavaConversions._ +import scala.util.Random + +import akka.actor.{Actor, ActorRef, Cancellable} +import akka.util.{Duration, Timeout} +import akka.util.duration._ + +import spark.{Logging, Utils} + +/** + * BlockManagerMasterActor is an actor on the master node to track statuses of + * all slaves' block managers. + */ + +private[spark] +object BlockManagerMasterActor { + + case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) + + class BlockManagerInfo( + val blockManagerId: BlockManagerId, + timeMs: Long, + val maxMem: Long, + val slaveActor: ActorRef) + extends Logging { + + private var _lastSeenMs: Long = timeMs + private var _remainingMem: Long = maxMem + + // Mapping from block id to its status. + private val _blocks = new JHashMap[String, BlockStatus] + + logInfo("Registering block manager %s:%d with %s RAM".format( + blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem))) + + def updateLastSeenMs() { + _lastSeenMs = System.currentTimeMillis() + } + + def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, diskSize: Long) + : Unit = synchronized { + + updateLastSeenMs() + + if (_blocks.containsKey(blockId)) { + // The block exists on the slave already. + val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel + + if (originalLevel.useMemory) { + _remainingMem += memSize + } + } + + if (storageLevel.isValid) { + // isValid means it is either stored in-memory or on-disk. + _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize)) + if (storageLevel.useMemory) { + _remainingMem -= memSize + logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format( + blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), + Utils.memoryBytesToString(_remainingMem))) + } + if (storageLevel.useDisk) { + logInfo("Added %s on disk on %s:%d (size: %s)".format( + blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) + } + } else if (_blocks.containsKey(blockId)) { + // If isValid is not true, drop the block. + val blockStatus: BlockStatus = _blocks.get(blockId) + _blocks.remove(blockId) + if (blockStatus.storageLevel.useMemory) { + _remainingMem += blockStatus.memSize + logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format( + blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), + Utils.memoryBytesToString(_remainingMem))) + } + if (blockStatus.storageLevel.useDisk) { + logInfo("Removed %s on %s:%d on disk (size: %s)".format( + blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) + } + } + } + + def remainingMem: Long = _remainingMem + + def lastSeenMs: Long = _lastSeenMs + + def blocks: JHashMap[String, BlockStatus] = _blocks + + override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem + + def clear() { + _blocks.clear() + } + } +} + + +private[spark] +class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { + + // Mapping from block manager id to the block manager's information. + private val blockManagerInfo = + new HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo] + + // Mapping from host name to block manager id. + private val blockManagerIdByHost = new HashMap[String, BlockManagerId] + + // Mapping from block id to the set of block managers that have the block. + private val blockInfo = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]] + + initLogging() + + val slaveTimeout = System.getProperty("spark.storage.blockManagerSlaveTimeoutMs", + "" + (BlockManager.getHeartBeatFrequencyFromSystemProperties * 3)).toLong + + val checkTimeoutInterval = System.getProperty("spark.storage.blockManagerTimeoutIntervalMs", + "5000").toLong + + var timeoutCheckingTask: Cancellable = null + + override def preStart() { + if (!BlockManager.getDisableHeartBeatsForTesting) { + timeoutCheckingTask = context.system.scheduler.schedule( + 0.seconds, checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) + } + super.preStart() + } + + def removeBlockManager(blockManagerId: BlockManagerId) { + val info = blockManagerInfo(blockManagerId) + blockManagerIdByHost.remove(blockManagerId.ip) + blockManagerInfo.remove(blockManagerId) + var iterator = info.blocks.keySet.iterator + while (iterator.hasNext) { + val blockId = iterator.next + val locations = blockInfo.get(blockId)._2 + locations -= blockManagerId + if (locations.size == 0) { + blockInfo.remove(locations) + } + } + } + + def expireDeadHosts() { + logDebug("Checking for hosts with no recent heart beats in BlockManagerMaster.") + val now = System.currentTimeMillis() + val minSeenTime = now - slaveTimeout + val toRemove = new HashSet[BlockManagerId] + for (info <- blockManagerInfo.values) { + if (info.lastSeenMs < minSeenTime) { + logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats") + toRemove += info.blockManagerId + } + } + // TODO: Remove corresponding block infos + toRemove.foreach(removeBlockManager) + } + + def removeHost(host: String) { + logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.") + logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq) + blockManagerIdByHost.get(host).foreach(removeBlockManager) + logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq) + sender ! true + } + + def heartBeat(blockManagerId: BlockManagerId) { + if (!blockManagerInfo.contains(blockManagerId)) { + if (blockManagerId.ip == Utils.localHostName() && !isLocal) { + sender ! true + } else { + sender ! false + } + } else { + blockManagerInfo(blockManagerId).updateLastSeenMs() + sender ! true + } + } + + def receive = { + case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => + register(blockManagerId, maxMemSize, slaveActor) + + case UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => + blockUpdate(blockManagerId, blockId, storageLevel, deserializedSize, size) + + case GetLocations(blockId) => + getLocations(blockId) + + case GetLocationsMultipleBlockIds(blockIds) => + getLocationsMultipleBlockIds(blockIds) + + case GetPeers(blockManagerId, size) => + getPeersDeterministic(blockManagerId, size) + /*getPeers(blockManagerId, size)*/ + + case GetMemoryStatus => + getMemoryStatus + + case RemoveBlock(blockId) => + removeBlock(blockId) + + case RemoveHost(host) => + removeHost(host) + sender ! true + + case StopBlockManagerMaster => + logInfo("Stopping BlockManagerMaster") + sender ! true + if (timeoutCheckingTask != null) { + timeoutCheckingTask.cancel + } + context.stop(self) + + case ExpireDeadHosts => + expireDeadHosts() + + case HeartBeat(blockManagerId) => + heartBeat(blockManagerId) + + case other => + logInfo("Got unknown message: " + other) + } + + // Remove a block from the slaves that have it. This can only be used to remove + // blocks that the master knows about. + private def removeBlock(blockId: String) { + val block = blockInfo.get(blockId) + if (block != null) { + block._2.foreach { blockManagerId: BlockManagerId => + val blockManager = blockManagerInfo.get(blockManagerId) + if (blockManager.isDefined) { + // Remove the block from the slave's BlockManager. + // Doesn't actually wait for a confirmation and the message might get lost. + // If message loss becomes frequent, we should add retry logic here. + blockManager.get.slaveActor ! RemoveBlock(blockId) + // Remove the block from the master's BlockManagerInfo. + blockManager.get.updateBlockInfo(blockId, StorageLevel.NONE, 0, 0) + } + } + blockInfo.remove(blockId) + } + sender ! true + } + + // Return a map from the block manager id to max memory and remaining memory. + private def getMemoryStatus() { + val res = blockManagerInfo.map { case(blockManagerId, info) => + (blockManagerId, (info.maxMem, info.remainingMem)) + }.toMap + sender ! res + } + + private def register(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + val startTimeMs = System.currentTimeMillis() + val tmp = " " + blockManagerId + " " + logDebug("Got in register 0" + tmp + Utils.getUsedTimeMs(startTimeMs)) + if (blockManagerIdByHost.contains(blockManagerId.ip) && + blockManagerIdByHost(blockManagerId.ip) != blockManagerId) { + val oldId = blockManagerIdByHost(blockManagerId.ip) + logInfo("Got second registration for host " + blockManagerId + + "; removing old slave " + oldId) + removeBlockManager(oldId) + } + if (blockManagerId.ip == Utils.localHostName() && !isLocal) { + logInfo("Got Register Msg from master node, don't register it") + } else { + blockManagerInfo += (blockManagerId -> new BlockManagerMasterActor.BlockManagerInfo( + blockManagerId, System.currentTimeMillis(), maxMemSize, slaveActor)) + } + blockManagerIdByHost += (blockManagerId.ip -> blockManagerId) + logDebug("Got in register 1" + tmp + Utils.getUsedTimeMs(startTimeMs)) + sender ! true + } + + private def blockUpdate( + blockManagerId: BlockManagerId, + blockId: String, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long) { + + val startTimeMs = System.currentTimeMillis() + val tmp = " " + blockManagerId + " " + blockId + " " + + if (!blockManagerInfo.contains(blockManagerId)) { + if (blockManagerId.ip == Utils.localHostName() && !isLocal) { + // We intentionally do not register the master (except in local mode), + // so we should not indicate failure. + sender ! true + } else { + sender ! false + } + return + } + + if (blockId == null) { + blockManagerInfo(blockManagerId).updateLastSeenMs() + logDebug("Got in block update 1" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs)) + sender ! true + return + } + + blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize) + + var locations: HashSet[BlockManagerId] = null + if (blockInfo.containsKey(blockId)) { + locations = blockInfo.get(blockId)._2 + } else { + locations = new HashSet[BlockManagerId] + blockInfo.put(blockId, (storageLevel.replication, locations)) + } + + if (storageLevel.isValid) { + locations += blockManagerId + } else { + locations.remove(blockManagerId) + } + + if (locations.size == 0) { + blockInfo.remove(blockId) + } + sender ! true + } + + private def getLocations(blockId: String) { + val startTimeMs = System.currentTimeMillis() + val tmp = " " + blockId + " " + logDebug("Got in getLocations 0" + tmp + Utils.getUsedTimeMs(startTimeMs)) + if (blockInfo.containsKey(blockId)) { + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + res.appendAll(blockInfo.get(blockId)._2) + logDebug("Got in getLocations 1" + tmp + " as "+ res.toSeq + " at " + + Utils.getUsedTimeMs(startTimeMs)) + sender ! res.toSeq + } else { + logDebug("Got in getLocations 2" + tmp + Utils.getUsedTimeMs(startTimeMs)) + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + sender ! res + } + } + + private def getLocationsMultipleBlockIds(blockIds: Array[String]) { + def getLocations(blockId: String): Seq[BlockManagerId] = { + val tmp = blockId + logDebug("Got in getLocationsMultipleBlockIds Sub 0 " + tmp) + if (blockInfo.containsKey(blockId)) { + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + res.appendAll(blockInfo.get(blockId)._2) + logDebug("Got in getLocationsMultipleBlockIds Sub 1 " + tmp + " " + res.toSeq) + return res.toSeq + } else { + logDebug("Got in getLocationsMultipleBlockIds Sub 2 " + tmp) + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + return res.toSeq + } + } + + logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq) + var res: ArrayBuffer[Seq[BlockManagerId]] = new ArrayBuffer[Seq[BlockManagerId]] + for (blockId <- blockIds) { + res.append(getLocations(blockId)) + } + logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq + " : " + res.toSeq) + sender ! res.toSeq + } + + private def getPeers(blockManagerId: BlockManagerId, size: Int) { + var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + res.appendAll(peers) + res -= blockManagerId + val rand = new Random(System.currentTimeMillis()) + while (res.length > size) { + res.remove(rand.nextInt(res.length)) + } + sender ! res.toSeq + } + + private def getPeersDeterministic(blockManagerId: BlockManagerId, size: Int) { + var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + + val peersWithIndices = peers.zipWithIndex + val selfIndex = peersWithIndices.find(_._1 == blockManagerId).map(_._2).getOrElse(-1) + if (selfIndex == -1) { + throw new Exception("Self index for " + blockManagerId + " not found") + } + + var index = selfIndex + while (res.size < size) { + index += 1 + if (index == selfIndex) { + throw new Exception("More peer expected than available") + } + res += peers(index % peers.size) + } + sender ! res.toSeq + } +} diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala index 5bca170f95..d73a9b790f 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala @@ -34,7 +34,7 @@ private[spark] case class HeartBeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster private[spark] -class BlockUpdate( +class UpdateBlockInfo( var blockManagerId: BlockManagerId, var blockId: String, var storageLevel: StorageLevel, @@ -65,17 +65,17 @@ class BlockUpdate( } private[spark] -object BlockUpdate { +object UpdateBlockInfo { def apply(blockManagerId: BlockManagerId, blockId: String, storageLevel: StorageLevel, memSize: Long, - diskSize: Long): BlockUpdate = { - new BlockUpdate(blockManagerId, blockId, storageLevel, memSize, diskSize) + diskSize: Long): UpdateBlockInfo = { + new UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize) } // For pattern-matching - def unapply(h: BlockUpdate): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { + def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize)) } } diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index c497f03e0c..e3544e5aae 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -1,6 +1,6 @@ package spark.storage -import java.io.{Externalizable, ObjectInput, ObjectOutput} +import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} /** * Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory, @@ -10,14 +10,16 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} * commonly useful storage levels. */ class StorageLevel( - var useDisk: Boolean, + var useDisk: Boolean, var useMemory: Boolean, var deserialized: Boolean, var replication: Int = 1) extends Externalizable { // TODO: Also add fields for caching priority, dataset ID, and flushing. - + + assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes") + def this(flags: Int, replication: Int) { this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) } @@ -29,14 +31,14 @@ class StorageLevel( override def equals(other: Any): Boolean = other match { case s: StorageLevel => - s.useDisk == useDisk && + s.useDisk == useDisk && s.useMemory == useMemory && s.deserialized == deserialized && - s.replication == replication + s.replication == replication case _ => false } - + def isValid = ((useMemory || useDisk) && (replication > 0)) def toInt: Int = { @@ -66,10 +68,16 @@ class StorageLevel( replication = in.readByte() } + @throws(classOf[IOException]) + private def readResolve(): Object = StorageLevel.getCachedStorageLevel(this) + override def toString: String = "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication) + + override def hashCode(): Int = toInt * 41 + replication } + object StorageLevel { val NONE = new StorageLevel(false, false, false) val DISK_ONLY = new StorageLevel(true, false, false) @@ -82,4 +90,16 @@ object StorageLevel { val MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2) val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false) val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2) + + private[spark] + val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() + + private[spark] def getCachedStorageLevel(level: StorageLevel): StorageLevel = { + if (storageLevelCache.containsKey(level)) { + storageLevelCache.get(level) + } else { + storageLevelCache.put(level, level) + level + } + } } diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala new file mode 100644 index 0000000000..19e67acd0c --- /dev/null +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -0,0 +1,35 @@ +package spark.util + +import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors} +import java.util.{TimerTask, Timer} +import spark.Logging + +class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { + + val delaySeconds = (System.getProperty("spark.cleanup.delay", "-100").toDouble * 60).toInt + val periodSeconds = math.max(10, delaySeconds / 10) + val timer = new Timer(name + " cleanup timer", true) + + val task = new TimerTask { + def run() { + try { + if (delaySeconds > 0) { + cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) + logInfo("Ran metadata cleaner for " + name) + } + } catch { + case e: Exception => logError("Error running cleanup task for " + name, e) + } + } + } + if (periodSeconds > 0) { + logInfo( + "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds and " + + "period of " + periodSeconds + " secs") + timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) + } + + def cancel() { + timer.cancel() + } +} diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala new file mode 100644 index 0000000000..070ee19ac0 --- /dev/null +++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala @@ -0,0 +1,87 @@ +package spark.util + +import java.util.concurrent.ConcurrentHashMap +import scala.collection.JavaConversions._ +import scala.collection.mutable.{HashMap, Map} + +/** + * This is a custom implementation of scala.collection.mutable.Map which stores the insertion + * time stamp along with each key-value pair. Key-value pairs that are older than a particular + * threshold time can them be removed using the cleanup method. This is intended to be a drop-in + * replacement of scala.collection.mutable.HashMap. + */ +class TimeStampedHashMap[A, B] extends Map[A, B]() { + val internalMap = new ConcurrentHashMap[A, (B, Long)]() + + def get(key: A): Option[B] = { + val value = internalMap.get(key) + if (value != null) Some(value._1) else None + } + + def iterator: Iterator[(A, B)] = { + val jIterator = internalMap.entrySet().iterator() + jIterator.map(kv => (kv.getKey, kv.getValue._1)) + } + + override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = { + val newMap = new TimeStampedHashMap[A, B1] + newMap.internalMap.putAll(this.internalMap) + newMap.internalMap.put(kv._1, (kv._2, currentTime)) + newMap + } + + override def - (key: A): Map[A, B] = { + internalMap.remove(key) + this + } + + override def += (kv: (A, B)): this.type = { + internalMap.put(kv._1, (kv._2, currentTime)) + this + } + + override def -= (key: A): this.type = { + internalMap.remove(key) + this + } + + override def update(key: A, value: B) { + this += ((key, value)) + } + + override def apply(key: A): B = { + val value = internalMap.get(key) + if (value == null) throw new NoSuchElementException() + value._1 + } + + override def filter(p: ((A, B)) => Boolean): Map[A, B] = { + internalMap.map(kv => (kv._1, kv._2._1)).filter(p) + } + + override def empty: Map[A, B] = new TimeStampedHashMap[A, B]() + + override def size(): Int = internalMap.size() + + override def foreach[U](f: ((A, B)) => U): Unit = { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + val kv = (entry.getKey, entry.getValue._1) + f(kv) + } + } + + def cleanup(threshTime: Long) { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + if (entry.getValue._2 < threshTime) { + iterator.remove() + } + } + } + + private def currentTime: Long = System.currentTimeMillis() + +} diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index 4dc3b7ec05..e50ce1430f 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -22,6 +22,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT var oldHeartBeat: String = null // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test + System.setProperty("spark.kryoserializer.buffer.mb", "1") val serializer = new KryoSerializer before { @@ -63,7 +64,33 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } } - test("manager-master interaction") { + test("StorageLevel object caching") { + val level1 = new StorageLevel(false, false, false, 3) + val level2 = new StorageLevel(false, false, false, 3) + val bytes1 = spark.Utils.serialize(level1) + val level1_ = spark.Utils.deserialize[StorageLevel](bytes1) + val bytes2 = spark.Utils.serialize(level2) + val level2_ = spark.Utils.deserialize[StorageLevel](bytes2) + assert(level1_ === level1, "Deserialized level1 not same as original level1") + assert(level2_ === level2, "Deserialized level2 not same as original level1") + assert(level1_ === level2_, "Deserialized level1 not same as deserialized level2") + assert(level2_.eq(level1_), "Deserialized level2 not the same object as deserialized level1") + } + + test("BlockManagerId object caching") { + val id1 = new StorageLevel(false, false, false, 3) + val id2 = new StorageLevel(false, false, false, 3) + val bytes1 = spark.Utils.serialize(id1) + val id1_ = spark.Utils.deserialize[StorageLevel](bytes1) + val bytes2 = spark.Utils.serialize(id2) + val id2_ = spark.Utils.deserialize[StorageLevel](bytes2) + assert(id1_ === id1, "Deserialized id1 not same as original id1") + assert(id2_ === id2, "Deserialized id2 not same as original id1") + assert(id1_ === id2_, "Deserialized id1 not same as deserialized id2") + assert(id2_.eq(id1_), "Deserialized id2 not the same object as deserialized level1") + } + + test("master + 1 manager interaction") { store = new BlockManager(actorSystem, master, serializer, 2000) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) @@ -80,17 +107,33 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(store.getSingle("a3") != None, "a3 was not in store") // Checking whether master knows about the blocks or not - assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1") - assert(master.mustGetLocations(GetLocations("a2")).size > 0, "master was not told about a2") - assert(master.mustGetLocations(GetLocations("a3")).size === 0, "master was told about a3") + assert(master.getLocations("a1").size > 0, "master was not told about a1") + assert(master.getLocations("a2").size > 0, "master was not told about a2") + assert(master.getLocations("a3").size === 0, "master was told about a3") // Drop a1 and a2 from memory; this should be reported back to the master store.dropFromMemory("a1", null) store.dropFromMemory("a2", null) assert(store.getSingle("a1") === None, "a1 not removed from store") assert(store.getSingle("a2") === None, "a2 not removed from store") - assert(master.mustGetLocations(GetLocations("a1")).size === 0, "master did not remove a1") - assert(master.mustGetLocations(GetLocations("a2")).size === 0, "master did not remove a2") + assert(master.getLocations("a1").size === 0, "master did not remove a1") + assert(master.getLocations("a2").size === 0, "master did not remove a2") + } + + test("master + 2 managers interaction") { + store = new BlockManager(actorSystem, master, serializer, 2000) + val otherStore = new BlockManager(actorSystem, master, new KryoSerializer, 2000) + + val peers = master.getPeers(store.blockManagerId, 1) + assert(peers.size === 1, "master did not return the other manager as a peer") + assert(peers.head === otherStore.blockManagerId, "peer returned by master is not the other manager") + + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_2) + otherStore.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_2) + assert(master.getLocations("a1").size === 2, "master did not report 2 locations for a1") + assert(master.getLocations("a2").size === 2, "master did not report 2 locations for a2") } test("removing block") { @@ -113,9 +156,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(store.getSingle("a3") != None, "a3 was not in store") // Checking whether master knows about the blocks or not - assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1") - assert(master.mustGetLocations(GetLocations("a2")).size > 0, "master was not told about a2") - assert(master.mustGetLocations(GetLocations("a3")).size === 0, "master was told about a3") + assert(master.getLocations("a1").size > 0, "master was not told about a1") + assert(master.getLocations("a2").size > 0, "master was not told about a2") + assert(master.getLocations("a3").size === 0, "master was told about a3") // Remove a1 and a2 and a3. Should be no-op for a3. master.removeBlock("a1") @@ -123,10 +166,10 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT master.removeBlock("a3") assert(store.getSingle("a1") === None, "a1 not removed from store") assert(store.getSingle("a2") === None, "a2 not removed from store") - assert(master.mustGetLocations(GetLocations("a1")).size === 0, "master did not remove a1") - assert(master.mustGetLocations(GetLocations("a2")).size === 0, "master did not remove a2") + assert(master.getLocations("a1").size === 0, "master did not remove a1") + assert(master.getLocations("a2").size === 0, "master did not remove a2") assert(store.getSingle("a3") != None, "a3 was not in store") - assert(master.mustGetLocations(GetLocations("a3")).size === 0, "master was told about a3") + assert(master.getLocations("a3").size === 0, "master was told about a3") memStatus = master.getMemoryStatus.head._2 assert(memStatus._1 == 2000L, "total memory " + memStatus._1 + " should equal 2000") assert(memStatus._2 == 2000L, "remaining memory " + memStatus._1 + " should equal 2000") @@ -140,13 +183,13 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) assert(store.getSingle("a1") != None, "a1 was not in store") - assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1") + assert(master.getLocations("a1").size > 0, "master was not told about a1") master.notifyADeadHost(store.blockManagerId.ip) - assert(master.mustGetLocations(GetLocations("a1")).size == 0, "a1 was not removed from master") + assert(master.getLocations("a1").size == 0, "a1 was not removed from master") store invokePrivate heartBeat() - assert(master.mustGetLocations(GetLocations("a1")).size > 0, + assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master") } @@ -157,17 +200,15 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1") + assert(master.getLocations("a1").size > 0, "master was not told about a1") master.notifyADeadHost(store.blockManagerId.ip) - assert(master.mustGetLocations(GetLocations("a1")).size == 0, "a1 was not removed from master") + assert(master.getLocations("a1").size == 0, "a1 was not removed from master") store.putSingle("a2", a1, StorageLevel.MEMORY_ONLY) - assert(master.mustGetLocations(GetLocations("a1")).size > 0, - "a1 was not reregistered with master") - assert(master.mustGetLocations(GetLocations("a2")).size > 0, - "master was not told about a2") + assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master") + assert(master.getLocations("a2").size > 0, "master was not told about a2") } test("deregistration on duplicate") { @@ -177,19 +218,19 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1") + assert(master.getLocations("a1").size > 0, "master was not told about a1") store2 = new BlockManager(actorSystem, master, serializer, 2000) - assert(master.mustGetLocations(GetLocations("a1")).size == 0, "a1 was not removed from master") + assert(master.getLocations("a1").size == 0, "a1 was not removed from master") store invokePrivate heartBeat() - assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1") + assert(master.getLocations("a1").size > 0, "master was not told about a1") store2 invokePrivate heartBeat() - assert(master.mustGetLocations(GetLocations("a1")).size == 0, "a2 was not removed from master") + assert(master.getLocations("a1").size == 0, "a2 was not removed from master") } test("in-memory LRU storage") { -- cgit v1.2.3 From 8c01295b859c35f4034528d4487a45c34728d0fb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 14 Dec 2012 00:26:36 -0800 Subject: Fixed conflicts from merging Charles' and TD's block manager changes. --- .../scala/spark/storage/BlockManagerMaster.scala | 1 - .../spark/storage/BlockManagerMasterActor.scala | 299 +++++++++++---------- .../scala/spark/storage/BlockManagerSuite.scala | 31 +-- 3 files changed, 158 insertions(+), 173 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index cf11393a03..e8a1e5889f 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -41,7 +41,6 @@ private[spark] class BlockManagerMaster( } } - /** Remove a dead host from the master actor. This is only called on the master side. */ def notifyADeadHost(host: String) { tell(RemoveHost(host)) diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala index 0d84e559cb..e3de8d8e4e 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -16,91 +16,6 @@ import spark.{Logging, Utils} * BlockManagerMasterActor is an actor on the master node to track statuses of * all slaves' block managers. */ - -private[spark] -object BlockManagerMasterActor { - - case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) - - class BlockManagerInfo( - val blockManagerId: BlockManagerId, - timeMs: Long, - val maxMem: Long, - val slaveActor: ActorRef) - extends Logging { - - private var _lastSeenMs: Long = timeMs - private var _remainingMem: Long = maxMem - - // Mapping from block id to its status. - private val _blocks = new JHashMap[String, BlockStatus] - - logInfo("Registering block manager %s:%d with %s RAM".format( - blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem))) - - def updateLastSeenMs() { - _lastSeenMs = System.currentTimeMillis() - } - - def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, diskSize: Long) - : Unit = synchronized { - - updateLastSeenMs() - - if (_blocks.containsKey(blockId)) { - // The block exists on the slave already. - val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel - - if (originalLevel.useMemory) { - _remainingMem += memSize - } - } - - if (storageLevel.isValid) { - // isValid means it is either stored in-memory or on-disk. - _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize)) - if (storageLevel.useMemory) { - _remainingMem -= memSize - logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), - Utils.memoryBytesToString(_remainingMem))) - } - if (storageLevel.useDisk) { - logInfo("Added %s on disk on %s:%d (size: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) - } - } else if (_blocks.containsKey(blockId)) { - // If isValid is not true, drop the block. - val blockStatus: BlockStatus = _blocks.get(blockId) - _blocks.remove(blockId) - if (blockStatus.storageLevel.useMemory) { - _remainingMem += blockStatus.memSize - logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), - Utils.memoryBytesToString(_remainingMem))) - } - if (blockStatus.storageLevel.useDisk) { - logInfo("Removed %s on %s:%d on disk (size: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) - } - } - } - - def remainingMem: Long = _remainingMem - - def lastSeenMs: Long = _lastSeenMs - - def blocks: JHashMap[String, BlockStatus] = _blocks - - override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem - - def clear() { - _blocks.clear() - } - } -} - - private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { @@ -108,8 +23,9 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { private val blockManagerInfo = new HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo] - // Mapping from host name to block manager id. - private val blockManagerIdByHost = new HashMap[String, BlockManagerId] + // Mapping from host name to block manager id. We allow multiple block managers + // on the same host name (ip). + private val blockManagerIdByHost = new HashMap[String, ArrayBuffer[BlockManagerId]] // Mapping from block id to the set of block managers that have the block. private val blockInfo = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]] @@ -132,9 +48,62 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { super.preStart() } + def receive = { + case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => + register(blockManagerId, maxMemSize, slaveActor) + + case UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => + blockUpdate(blockManagerId, blockId, storageLevel, deserializedSize, size) + + case GetLocations(blockId) => + getLocations(blockId) + + case GetLocationsMultipleBlockIds(blockIds) => + getLocationsMultipleBlockIds(blockIds) + + case GetPeers(blockManagerId, size) => + getPeersDeterministic(blockManagerId, size) + /*getPeers(blockManagerId, size)*/ + + case GetMemoryStatus => + getMemoryStatus + + case RemoveBlock(blockId) => + removeBlock(blockId) + + case RemoveHost(host) => + removeHost(host) + sender ! true + + case StopBlockManagerMaster => + logInfo("Stopping BlockManagerMaster") + sender ! true + if (timeoutCheckingTask != null) { + timeoutCheckingTask.cancel + } + context.stop(self) + + case ExpireDeadHosts => + expireDeadHosts() + + case HeartBeat(blockManagerId) => + heartBeat(blockManagerId) + + case other => + logInfo("Got unknown message: " + other) + } + def removeBlockManager(blockManagerId: BlockManagerId) { val info = blockManagerInfo(blockManagerId) - blockManagerIdByHost.remove(blockManagerId.ip) + + // Remove the block manager from blockManagerIdByHost. If the list of block + // managers belonging to the IP is empty, remove the entry from the hash map. + blockManagerIdByHost.get(blockManagerId.ip).foreach { managers: ArrayBuffer[BlockManagerId] => + managers -= blockManagerId + if (managers.size == 0) blockManagerIdByHost.remove(blockManagerId.ip) + } + + // Remove it from blockManagerInfo and remove all the blocks. blockManagerInfo.remove(blockManagerId) var iterator = info.blocks.keySet.iterator while (iterator.hasNext) { @@ -158,14 +127,13 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { toRemove += info.blockManagerId } } - // TODO: Remove corresponding block infos toRemove.foreach(removeBlockManager) } def removeHost(host: String) { logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.") logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq) - blockManagerIdByHost.get(host).foreach(removeBlockManager) + blockManagerIdByHost.get(host).foreach(_.foreach(removeBlockManager)) logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq) sender ! true } @@ -183,51 +151,6 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { } } - def receive = { - case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => - register(blockManagerId, maxMemSize, slaveActor) - - case UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => - blockUpdate(blockManagerId, blockId, storageLevel, deserializedSize, size) - - case GetLocations(blockId) => - getLocations(blockId) - - case GetLocationsMultipleBlockIds(blockIds) => - getLocationsMultipleBlockIds(blockIds) - - case GetPeers(blockManagerId, size) => - getPeersDeterministic(blockManagerId, size) - /*getPeers(blockManagerId, size)*/ - - case GetMemoryStatus => - getMemoryStatus - - case RemoveBlock(blockId) => - removeBlock(blockId) - - case RemoveHost(host) => - removeHost(host) - sender ! true - - case StopBlockManagerMaster => - logInfo("Stopping BlockManagerMaster") - sender ! true - if (timeoutCheckingTask != null) { - timeoutCheckingTask.cancel - } - context.stop(self) - - case ExpireDeadHosts => - expireDeadHosts() - - case HeartBeat(blockManagerId) => - heartBeat(blockManagerId) - - case other => - logInfo("Got unknown message: " + other) - } - // Remove a block from the slaves that have it. This can only be used to remove // blocks that the master knows about. private def removeBlock(blockId: String) { @@ -261,20 +184,22 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { val startTimeMs = System.currentTimeMillis() val tmp = " " + blockManagerId + " " logDebug("Got in register 0" + tmp + Utils.getUsedTimeMs(startTimeMs)) - if (blockManagerIdByHost.contains(blockManagerId.ip) && - blockManagerIdByHost(blockManagerId.ip) != blockManagerId) { - val oldId = blockManagerIdByHost(blockManagerId.ip) - logInfo("Got second registration for host " + blockManagerId + - "; removing old slave " + oldId) - removeBlockManager(oldId) - } + if (blockManagerId.ip == Utils.localHostName() && !isLocal) { logInfo("Got Register Msg from master node, don't register it") } else { + blockManagerIdByHost.get(blockManagerId.ip) match { + case Some(managers) => + // A block manager of the same host name already exists. + logInfo("Got another registration for host " + blockManagerId) + managers += blockManagerId + case None => + blockManagerIdByHost += (blockManagerId.ip -> ArrayBuffer(blockManagerId)) + } + blockManagerInfo += (blockManagerId -> new BlockManagerMasterActor.BlockManagerInfo( blockManagerId, System.currentTimeMillis(), maxMemSize, slaveActor)) } - blockManagerIdByHost += (blockManagerId.ip -> blockManagerId) logDebug("Got in register 1" + tmp + Utils.getUsedTimeMs(startTimeMs)) sender ! true } @@ -387,12 +312,12 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - val peersWithIndices = peers.zipWithIndex - val selfIndex = peersWithIndices.find(_._1 == blockManagerId).map(_._2).getOrElse(-1) + val selfIndex = peers.indexOf(blockManagerId) if (selfIndex == -1) { throw new Exception("Self index for " + blockManagerId + " not found") } + // Note that this logic will select the same node multiple times if there aren't enough peers var index = selfIndex while (res.size < size) { index += 1 @@ -404,3 +329,87 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { sender ! res.toSeq } } + + +private[spark] +object BlockManagerMasterActor { + + case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) + + class BlockManagerInfo( + val blockManagerId: BlockManagerId, + timeMs: Long, + val maxMem: Long, + val slaveActor: ActorRef) + extends Logging { + + private var _lastSeenMs: Long = timeMs + private var _remainingMem: Long = maxMem + + // Mapping from block id to its status. + private val _blocks = new JHashMap[String, BlockStatus] + + logInfo("Registering block manager %s:%d with %s RAM".format( + blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem))) + + def updateLastSeenMs() { + _lastSeenMs = System.currentTimeMillis() + } + + def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, diskSize: Long) + : Unit = synchronized { + + updateLastSeenMs() + + if (_blocks.containsKey(blockId)) { + // The block exists on the slave already. + val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel + + if (originalLevel.useMemory) { + _remainingMem += memSize + } + } + + if (storageLevel.isValid) { + // isValid means it is either stored in-memory or on-disk. + _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize)) + if (storageLevel.useMemory) { + _remainingMem -= memSize + logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format( + blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), + Utils.memoryBytesToString(_remainingMem))) + } + if (storageLevel.useDisk) { + logInfo("Added %s on disk on %s:%d (size: %s)".format( + blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) + } + } else if (_blocks.containsKey(blockId)) { + // If isValid is not true, drop the block. + val blockStatus: BlockStatus = _blocks.get(blockId) + _blocks.remove(blockId) + if (blockStatus.storageLevel.useMemory) { + _remainingMem += blockStatus.memSize + logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format( + blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), + Utils.memoryBytesToString(_remainingMem))) + } + if (blockStatus.storageLevel.useDisk) { + logInfo("Removed %s on %s:%d on disk (size: %s)".format( + blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) + } + } + } + + def remainingMem: Long = _remainingMem + + def lastSeenMs: Long = _lastSeenMs + + def blocks: JHashMap[String, BlockStatus] = _blocks + + override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem + + def clear() { + _blocks.clear() + } + } +} diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index e50ce1430f..4e28a7e2bc 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -122,16 +122,16 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("master + 2 managers interaction") { store = new BlockManager(actorSystem, master, serializer, 2000) - val otherStore = new BlockManager(actorSystem, master, new KryoSerializer, 2000) + store2 = new BlockManager(actorSystem, master, new KryoSerializer, 2000) val peers = master.getPeers(store.blockManagerId, 1) assert(peers.size === 1, "master did not return the other manager as a peer") - assert(peers.head === otherStore.blockManagerId, "peer returned by master is not the other manager") + assert(peers.head === store2.blockManagerId, "peer returned by master is not the other manager") val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_2) - otherStore.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_2) + store2.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_2) assert(master.getLocations("a1").size === 2, "master did not report 2 locations for a1") assert(master.getLocations("a2").size === 2, "master did not report 2 locations for a2") } @@ -189,8 +189,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(master.getLocations("a1").size == 0, "a1 was not removed from master") store invokePrivate heartBeat() - assert(master.getLocations("a1").size > 0, - "a1 was not reregistered with master") + assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master") } test("reregistration on block update") { @@ -211,28 +210,6 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(master.getLocations("a2").size > 0, "master was not told about a2") } - test("deregistration on duplicate") { - val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager(actorSystem, master, serializer, 2000) - val a1 = new Array[Byte](400) - - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - - assert(master.getLocations("a1").size > 0, "master was not told about a1") - - store2 = new BlockManager(actorSystem, master, serializer, 2000) - - assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - - store invokePrivate heartBeat() - - assert(master.getLocations("a1").size > 0, "master was not told about a1") - - store2 invokePrivate heartBeat() - - assert(master.getLocations("a1").size == 0, "a2 was not removed from master") - } - test("in-memory LRU storage") { store = new BlockManager(actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) -- cgit v1.2.3 From 72eed2b95edb3b0b213517c815e09c3886b11669 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 17 Dec 2012 18:52:43 -0800 Subject: Converted CheckpointState in RDDCheckpointData to use scala Enumeration. --- core/src/main/scala/spark/RDDCheckpointData.scala | 48 +++++++++++------------ 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index ff2ed4cdfc..7613b338e6 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -5,45 +5,41 @@ import rdd.CoalescedRDD import scheduler.{ResultTask, ShuffleMapTask} /** - * This class contains all the information of the regarding RDD checkpointing. + * Enumeration to manage state transitions of an RDD through checkpointing + * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ] */ +private[spark] object CheckpointState extends Enumeration { + type CheckpointState = Value + val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value +} +/** + * This class contains all the information of the regarding RDD checkpointing. + */ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) extends Logging with Serializable { - /** - * This class manages the state transition of an RDD through checkpointing - * [ Not checkpointed --> marked for checkpointing --> checkpointing in progress --> checkpointed ] - */ - class CheckpointState extends Serializable { - var state = 0 + import CheckpointState._ - def mark() { if (state == 0) state = 1 } - def start() { assert(state == 1); state = 2 } - def finish() { assert(state == 2); state = 3 } - - def isMarked() = { state == 1 } - def isInProgress = { state == 2 } - def isCheckpointed = { state == 3 } - } - - val cpState = new CheckpointState() + var cpState = Initialized @transient var cpFile: Option[String] = None @transient var cpRDD: Option[RDD[T]] = None @transient var cpRDDSplits: Seq[Split] = Nil // Mark the RDD for checkpointing - def markForCheckpoint() = { - RDDCheckpointData.synchronized { cpState.mark() } + def markForCheckpoint() { + RDDCheckpointData.synchronized { + if (cpState == Initialized) cpState = MarkedForCheckpoint + } } // Is the RDD already checkpointed - def isCheckpointed() = { - RDDCheckpointData.synchronized { cpState.isCheckpointed } + def isCheckpointed(): Boolean = { + RDDCheckpointData.synchronized { cpState == Checkpointed } } - // Get the file to which this RDD was checkpointed to as a Option - def getCheckpointFile() = { + // Get the file to which this RDD was checkpointed to as an Option + def getCheckpointFile(): Option[String] = { RDDCheckpointData.synchronized { cpFile } } @@ -52,8 +48,8 @@ extends Logging with Serializable { // If it is marked for checkpointing AND checkpointing is not already in progress, // then set it to be in progress, else return RDDCheckpointData.synchronized { - if (cpState.isMarked && !cpState.isInProgress) { - cpState.start() + if (cpState == MarkedForCheckpoint) { + cpState = CheckpointingInProgress } else { return } @@ -87,7 +83,7 @@ extends Logging with Serializable { cpRDD = Some(newRDD) cpRDDSplits = newRDD.splits rdd.changeDependencies(newRDD) - cpState.finish() + cpState = Checkpointed RDDCheckpointData.checkpointCompleted() logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) } -- cgit v1.2.3 From bfac06e1f620efcd17beb16750dc57db6b424fb7 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 17 Dec 2012 23:05:52 -0800 Subject: SPARK-616: Logging dead workers in Web UI. This patch keeps track of which workers have died and marks them as such in the master web UI. It also handles workers which die and re-register using different actor ID's. --- core/src/main/scala/spark/deploy/master/Master.scala | 7 +++++-- core/src/main/scala/spark/deploy/master/WorkerInfo.scala | 6 +++++- core/src/main/scala/spark/deploy/master/WorkerState.scala | 7 +++++++ core/src/main/twirl/spark/deploy/master/worker_row.scala.html | 1 + core/src/main/twirl/spark/deploy/master/worker_table.scala.html | 1 + 5 files changed, 19 insertions(+), 3 deletions(-) create mode 100644 core/src/main/scala/spark/deploy/master/WorkerState.scala diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index b30c8e99b5..6ecebe626a 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -156,7 +156,8 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor if (spreadOutJobs) { // Try to spread out each job among all the nodes, until it has all its cores for (job <- waitingJobs if job.coresLeft > 0) { - val usableWorkers = workers.toArray.filter(canUse(job, _)).sortBy(_.coresFree).reverse + val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) + .filter(canUse(job, _)).sortBy(_.coresFree).reverse val numUsable = usableWorkers.length val assigned = new Array[Int](numUsable) // Number of cores to give on each node var toAssign = math.min(job.coresLeft, usableWorkers.map(_.coresFree).sum) @@ -203,6 +204,8 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int, publicAddress: String): WorkerInfo = { + // There may be one or more refs to dead workers on this same node (w/ different ID's), remove them. + workers.filter(w => (w.host == host) && (w.state == WorkerState.DEAD)).foreach(workers -= _) val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress) workers += worker idToWorker(worker.id) = worker @@ -213,7 +216,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor def removeWorker(worker: WorkerInfo) { logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port) - workers -= worker + worker.setState(WorkerState.DEAD) idToWorker -= worker.id actorToWorker -= worker.actor addressToWorker -= worker.actor.path.address diff --git a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala index a0a698ef04..5a7f5fef8a 100644 --- a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala @@ -14,7 +14,7 @@ private[spark] class WorkerInfo( val publicAddress: String) { var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info - + var state: WorkerState.Value = WorkerState.ALIVE var coresUsed = 0 var memoryUsed = 0 @@ -42,4 +42,8 @@ private[spark] class WorkerInfo( def webUiAddress : String = { "http://" + this.publicAddress + ":" + this.webUiPort } + + def setState(state: WorkerState.Value) = { + this.state = state + } } diff --git a/core/src/main/scala/spark/deploy/master/WorkerState.scala b/core/src/main/scala/spark/deploy/master/WorkerState.scala new file mode 100644 index 0000000000..0bf35014c8 --- /dev/null +++ b/core/src/main/scala/spark/deploy/master/WorkerState.scala @@ -0,0 +1,7 @@ +package spark.deploy.master + +private[spark] object WorkerState extends Enumeration("ALIVE", "DEAD", "DECOMMISSIONED") { + type WorkerState = Value + + val ALIVE, DEAD, DECOMMISSIONED = Value +} diff --git a/core/src/main/twirl/spark/deploy/master/worker_row.scala.html b/core/src/main/twirl/spark/deploy/master/worker_row.scala.html index c32ab30401..be69e9bf02 100644 --- a/core/src/main/twirl/spark/deploy/master/worker_row.scala.html +++ b/core/src/main/twirl/spark/deploy/master/worker_row.scala.html @@ -7,6 +7,7 @@ @worker.id @{worker.host}:@{worker.port} + @worker.state @worker.cores (@worker.coresUsed Used) @{Utils.memoryMegabytesToString(worker.memory)} (@{Utils.memoryMegabytesToString(worker.memoryUsed)} Used) diff --git a/core/src/main/twirl/spark/deploy/master/worker_table.scala.html b/core/src/main/twirl/spark/deploy/master/worker_table.scala.html index fad1af41dc..b249411a62 100644 --- a/core/src/main/twirl/spark/deploy/master/worker_table.scala.html +++ b/core/src/main/twirl/spark/deploy/master/worker_table.scala.html @@ -5,6 +5,7 @@ ID Address + State Cores Memory -- cgit v1.2.3 From 5184141936c18f12c6738caae6fceee4d15800e2 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 18 Dec 2012 13:30:53 -0800 Subject: Introduced getSpits, getDependencies, and getPreferredLocations in RDD and RDDCheckpointData. --- core/src/main/scala/spark/PairRDDFunctions.scala | 4 +- core/src/main/scala/spark/ParallelCollection.scala | 9 +- core/src/main/scala/spark/RDD.scala | 123 +++++++++++++-------- core/src/main/scala/spark/RDDCheckpointData.scala | 10 +- core/src/main/scala/spark/rdd/BlockRDD.scala | 9 +- core/src/main/scala/spark/rdd/CartesianRDD.scala | 12 +- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 11 +- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 10 +- core/src/main/scala/spark/rdd/FilteredRDD.scala | 2 +- core/src/main/scala/spark/rdd/FlatMappedRDD.scala | 2 +- core/src/main/scala/spark/rdd/GlommedRDD.scala | 2 +- core/src/main/scala/spark/rdd/HadoopRDD.scala | 4 +- .../main/scala/spark/rdd/MapPartitionsRDD.scala | 2 +- .../spark/rdd/MapPartitionsWithSplitRDD.scala | 2 +- core/src/main/scala/spark/rdd/MappedRDD.scala | 2 +- core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 4 +- core/src/main/scala/spark/rdd/PipedRDD.scala | 2 +- core/src/main/scala/spark/rdd/SampledRDD.scala | 9 +- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 7 +- core/src/main/scala/spark/rdd/UnionRDD.scala | 13 +-- .../main/scala/spark/scheduler/DAGScheduler.scala | 2 +- core/src/test/scala/spark/CheckpointSuite.scala | 6 +- 22 files changed, 134 insertions(+), 113 deletions(-) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 1f82bd3ab8..09ac606cfb 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -628,7 +628,7 @@ private[spark] class MappedValuesRDD[K, V, U](prev: WeakReference[RDD[(K, V)]], f: V => U) extends RDD[(K, U)](prev.get) { - override def splits = firstParent[(K, V)].splits + override def getSplits = firstParent[(K, V)].splits override val partitioner = firstParent[(K, V)].partitioner override def compute(split: Split) = firstParent[(K, V)].iterator(split).map{case (k, v) => (k, f(v))} } @@ -637,7 +637,7 @@ private[spark] class FlatMappedValuesRDD[K, V, U](prev: WeakReference[RDD[(K, V)]], f: V => TraversableOnce[U]) extends RDD[(K, U)](prev.get) { - override def splits = firstParent[(K, V)].splits + override def getSplits = firstParent[(K, V)].splits override val partitioner = firstParent[(K, V)].partitioner override def compute(split: Split) = { firstParent[(K, V)].iterator(split).flatMap { case (k, v) => f(v).map(x => (k, x)) } diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index 9d12af6912..0bc5b2ff11 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -37,15 +37,12 @@ private[spark] class ParallelCollection[T: ClassManifest]( slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray } - override def splits = splits_.asInstanceOf[Array[Split]] + override def getSplits = splits_.asInstanceOf[Array[Split]] override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator - - override def preferredLocations(s: Split): Seq[String] = Nil - override def changeDependencies(newRDD: RDD[_]) { - dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + override def clearDependencies() { + splits_ = null } } diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 6c04769c82..f3e422fa5f 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -81,48 +81,33 @@ abstract class RDD[T: ClassManifest]( def this(@transient oneParent: RDD[_]) = this(oneParent.context , List(new OneToOneDependency(oneParent))) - // Methods that must be implemented by subclasses: - - /** Set of partitions in this RDD. */ - def splits: Array[Split] + // ======================================================================= + // Methods that should be implemented by subclasses of RDD + // ======================================================================= /** Function for computing a given partition. */ def compute(split: Split): Iterator[T] - /** How this RDD depends on any parent RDDs. */ - def dependencies: List[Dependency[_]] = dependencies_ + /** Set of partitions in this RDD. */ + protected def getSplits(): Array[Split] - /** Record user function generating this RDD. */ - private[spark] val origin = Utils.getSparkCallSite - - /** Optionally overridden by subclasses to specify how they are partitioned. */ - val partitioner: Option[Partitioner] = None + /** How this RDD depends on any parent RDDs. */ + protected def getDependencies(): List[Dependency[_]] = dependencies_ /** Optionally overridden by subclasses to specify placement preferences. */ - def preferredLocations(split: Split): Seq[String] = Nil - - /** The [[spark.SparkContext]] that this RDD was created on. */ - def context = sc + protected def getPreferredLocations(split: Split): Seq[String] = Nil - private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] - - /** A unique ID for this RDD (within its SparkContext). */ - val id = sc.newRddId() - - // Variables relating to persistence - private var storageLevel: StorageLevel = StorageLevel.NONE + /** Optionally overridden by subclasses to specify how they are partitioned. */ + val partitioner: Option[Partitioner] = None - protected[spark] var checkpointData: Option[RDDCheckpointData[T]] = None - /** Returns the first parent RDD */ - protected[spark] def firstParent[U: ClassManifest] = { - dependencies.head.rdd.asInstanceOf[RDD[U]] - } - /** Returns the `i` th parent RDD */ - protected[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]] + // ======================================================================= + // Methods and fields available on all RDDs + // ======================================================================= - // Methods available on all RDDs: + /** A unique ID for this RDD (within its SparkContext). */ + val id = sc.newRddId() /** * Set this RDD's storage level to persist its values across operations after the first time @@ -147,11 +132,39 @@ abstract class RDD[T: ClassManifest]( /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel - def getPreferredLocations(split: Split) = { + /** + * Get the preferred location of a split, taking into account whether the + * RDD is checkpointed or not. + */ + final def preferredLocations(split: Split): Seq[String] = { + if (isCheckpointed) { + checkpointData.get.getPreferredLocations(split) + } else { + getPreferredLocations(split) + } + } + + /** + * Get the array of splits of this RDD, taking into account whether the + * RDD is checkpointed or not. + */ + final def splits: Array[Split] = { + if (isCheckpointed) { + checkpointData.get.getSplits + } else { + getSplits + } + } + + /** + * Get the array of splits of this RDD, taking into account whether the + * RDD is checkpointed or not. + */ + final def dependencies: List[Dependency[_]] = { if (isCheckpointed) { - checkpointData.get.preferredLocations(split) + dependencies_ } else { - preferredLocations(split) + getDependencies } } @@ -536,6 +549,27 @@ abstract class RDD[T: ClassManifest]( if (checkpointData.isDefined) checkpointData.get.getCheckpointFile() else None } + // ======================================================================= + // Other internal methods and fields + // ======================================================================= + + private var storageLevel: StorageLevel = StorageLevel.NONE + + /** Record user function generating this RDD. */ + private[spark] val origin = Utils.getSparkCallSite + + private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] + + private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None + + /** Returns the first parent RDD */ + protected[spark] def firstParent[U: ClassManifest] = { + dependencies.head.rdd.asInstanceOf[RDD[U]] + } + + /** The [[spark.SparkContext]] that this RDD was created on. */ + def context = sc + /** * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler * after a job using this RDD has completed (therefore the RDD has been materialized and @@ -548,23 +582,18 @@ abstract class RDD[T: ClassManifest]( /** * Changes the dependencies of this RDD from its original parents to the new RDD - * (`newRDD`) created from the checkpoint file. This method must ensure that all references - * to the original parent RDDs must be removed to enable the parent RDDs to be garbage - * collected. Subclasses of RDD may override this method for implementing their own changing - * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea. + * (`newRDD`) created from the checkpoint file. */ protected[spark] def changeDependencies(newRDD: RDD[_]) { + clearDependencies() dependencies_ = List(new OneToOneDependency(newRDD)) } - @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { - oos.defaultWriteObject() - } - - @throws(classOf[IOException]) - private def readObject(ois: ObjectInputStream) { - ois.defaultReadObject() - } - + /** + * Clears the dependencies of this RDD. This method must ensure that all references + * to the original parent RDDs must be removed to enable the parent RDDs to be garbage + * collected. Subclasses of RDD may override this method for implementing their own changing + * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea. + */ + protected[spark] def clearDependencies() { } } diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index 7613b338e6..e4c0912cdc 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -24,7 +24,6 @@ extends Logging with Serializable { var cpState = Initialized @transient var cpFile: Option[String] = None @transient var cpRDD: Option[RDD[T]] = None - @transient var cpRDDSplits: Seq[Split] = Nil // Mark the RDD for checkpointing def markForCheckpoint() { @@ -81,7 +80,6 @@ extends Logging with Serializable { RDDCheckpointData.synchronized { cpFile = Some(file) cpRDD = Some(newRDD) - cpRDDSplits = newRDD.splits rdd.changeDependencies(newRDD) cpState = Checkpointed RDDCheckpointData.checkpointCompleted() @@ -90,12 +88,18 @@ extends Logging with Serializable { } // Get preferred location of a split after checkpointing - def preferredLocations(split: Split) = { + def getPreferredLocations(split: Split) = { RDDCheckpointData.synchronized { cpRDD.get.preferredLocations(split) } } + def getSplits: Array[Split] = { + RDDCheckpointData.synchronized { + cpRDD.get.splits + } + } + // Get iterator. This is called at the worker nodes. def iterator(split: Split): Iterator[T] = { rdd.firstParent[T].iterator(split) diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index 0c8cdd10dd..68e570eb15 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -29,7 +29,7 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St HashMap(blockIds.zip(locations):_*) } - override def splits = splits_ + override def getSplits = splits_ override def compute(split: Split): Iterator[T] = { val blockManager = SparkEnv.get.blockManager @@ -41,12 +41,11 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St } } - override def preferredLocations(split: Split) = + override def getPreferredLocations(split: Split) = locations_(split.asInstanceOf[BlockRDDSplit].blockId) - override def changeDependencies(newRDD: RDD[_]) { - dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + override def clearDependencies() { + splits_ = null } } diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 9975e79b08..116644bd52 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -45,9 +45,9 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( array } - override def splits = splits_ + override def getSplits = splits_ - override def preferredLocations(split: Split) = { + override def getPreferredLocations(split: Split) = { val currSplit = split.asInstanceOf[CartesianSplit] rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) } @@ -66,11 +66,11 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( } ) - override def dependencies = deps_ + override def getDependencies = deps_ - override def changeDependencies(newRDD: RDD[_]) { - deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + override def clearDependencies() { + deps_ = Nil + splits_ = null rdd1 = null rdd2 = null } diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index bc6d16ee8b..9cc95dc172 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -65,9 +65,7 @@ CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) deps.toList } - // Pre-checkpoint dependencies deps_ should be transient (deps_) - // but post-checkpoint dependencies must not be transient (dependencies_) - override def dependencies = if (isCheckpointed) dependencies_ else deps_ + override def getDependencies = deps_ @transient var splits_ : Array[Split] = { @@ -85,7 +83,7 @@ CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) array } - override def splits = splits_ + override def getSplits = splits_ override val partitioner = Some(part) @@ -117,10 +115,9 @@ CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) map.iterator } - override def changeDependencies(newRDD: RDD[_]) { + override def clearDependencies() { deps_ = null - dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + splits_ = null rdds = null } } diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 088958942e..85d0fa9f6a 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -44,7 +44,7 @@ class CoalescedRDD[T: ClassManifest]( } } - override def splits = splits_ + override def getSplits = splits_ override def compute(split: Split): Iterator[T] = { split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { @@ -59,11 +59,11 @@ class CoalescedRDD[T: ClassManifest]( } ) - override def dependencies = deps_ + override def getDependencies() = deps_ - override def changeDependencies(newRDD: RDD[_]) { - deps_ = List(new OneToOneDependency(newRDD)) - splits_ = newRDD.splits + override def clearDependencies() { + deps_ = Nil + splits_ = null prev = null } } diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala index 02f2e7c246..309ed2399d 100644 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala @@ -9,6 +9,6 @@ class FilteredRDD[T: ClassManifest]( f: T => Boolean) extends RDD[T](prev) { - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split) = firstParent[T].iterator(split).filter(f) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala index cdc8ecdcfe..1160e68bb8 100644 --- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala +++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala @@ -9,6 +9,6 @@ class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( f: T => TraversableOnce[U]) extends RDD[U](prev) { - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split) = firstParent[T].iterator(split).flatMap(f) } diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala index df6f61c69d..4fab1a56fa 100644 --- a/core/src/main/scala/spark/rdd/GlommedRDD.scala +++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala @@ -6,6 +6,6 @@ import spark.Split private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev) { - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split) = Array(firstParent[T].iterator(split).toArray).iterator } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index af54f23ebc..fce190b860 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -67,7 +67,7 @@ class HadoopRDD[K, V]( .asInstanceOf[InputFormat[K, V]] } - override def splits = splits_ + override def getSplits = splits_ override def compute(theSplit: Split) = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[HadoopSplit] @@ -110,7 +110,7 @@ class HadoopRDD[K, V]( } } - override def preferredLocations(split: Split) = { + override def getPreferredLocations(split: Split) = { // TODO: Filtering out "localhost" in case of file:// URLs val hadoopSplit = split.asInstanceOf[HadoopSplit] hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala index 23b9fb023b..5f4acee041 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala @@ -12,6 +12,6 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split) = f(firstParent[T].iterator(split)) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala index 41955c1d7a..f0f3f2c7c7 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala @@ -14,6 +14,6 @@ class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest]( f: (Int, Iterator[T]) => Iterator[U]) extends RDD[U](prev) { - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split) = f(split.index, firstParent[T].iterator(split)) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala index 6f8cb21fd3..44b542db93 100644 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/spark/rdd/MappedRDD.scala @@ -9,6 +9,6 @@ class MappedRDD[U: ClassManifest, T: ClassManifest]( f: T => U) extends RDD[U](prev) { - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split) = firstParent[T].iterator(split).map(f) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index c12df5839e..91f89e3c75 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -55,7 +55,7 @@ class NewHadoopRDD[K, V]( result } - override def splits = splits_ + override def getSplits = splits_ override def compute(theSplit: Split) = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[NewHadoopSplit] @@ -89,7 +89,7 @@ class NewHadoopRDD[K, V]( } } - override def preferredLocations(split: Split) = { + override def getPreferredLocations(split: Split) = { val theSplit = split.asInstanceOf[NewHadoopSplit] theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") } diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index d2047375ea..a88929e55e 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -29,7 +29,7 @@ class PipedRDD[T: ClassManifest]( // using a standard StringTokenizer (i.e. by spaces) def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command)) - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split): Iterator[String] = { val pb = new ProcessBuilder(command) diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index c622e14a66..da6f65765c 100644 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -26,9 +26,9 @@ class SampledRDD[T: ClassManifest]( firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt)) } - override def splits = splits_.asInstanceOf[Array[Split]] + override def getSplits = splits_.asInstanceOf[Array[Split]] - override def preferredLocations(split: Split) = + override def getPreferredLocations(split: Split) = firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev) override def compute(splitIn: Split) = { @@ -51,8 +51,7 @@ class SampledRDD[T: ClassManifest]( } } - override def changeDependencies(newRDD: RDD[_]) { - dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + override def clearDependencies() { + splits_ = null } } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index a9dd3f35ed..2caf33c21e 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -25,15 +25,14 @@ class ShuffledRDD[K, V]( @transient var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) - override def splits = splits_ + override def getSplits = splits_ override def compute(split: Split): Iterator[(K, V)] = { val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index) } - override def changeDependencies(newRDD: RDD[_]) { - dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + override def clearDependencies() { + splits_ = null } } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index a84867492b..05ed6172d1 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -37,7 +37,7 @@ class UnionRDD[T: ClassManifest]( array } - override def splits = splits_ + override def getSplits = splits_ @transient var deps_ = { val deps = new ArrayBuffer[Dependency[_]] @@ -49,19 +49,16 @@ class UnionRDD[T: ClassManifest]( deps.toList } - // Pre-checkpoint dependencies deps_ should be transient (deps_) - // but post-checkpoint dependencies must not be transient (dependencies_) - override def dependencies = if (isCheckpointed) dependencies_ else deps_ + override def getDependencies = deps_ override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator() - override def preferredLocations(s: Split): Seq[String] = + override def getPreferredLocations(s: Split): Seq[String] = s.asInstanceOf[UnionSplit[T]].preferredLocations() - override def changeDependencies(newRDD: RDD[_]) { + override def clearDependencies() { deps_ = null - dependencies_ = List(new OneToOneDependency(newRDD)) - splits_ = newRDD.splits + splits_ = null rdds = null } } diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 33d35b35d1..4b2570fa2b 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -575,7 +575,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with return cached } // If the RDD has some placement preferences (as is the case for input RDDs), get those - val rddPrefs = rdd.getPreferredLocations(rdd.splits(partition)).toList + val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList if (rddPrefs != Nil) { return rddPrefs } diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 0bffedb8db..19626d2450 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -57,7 +57,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(sc.objectFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) assert(parCollection.splits.length === numSplits) - assert(parCollection.splits.toList === parCollection.checkpointData.get.cpRDDSplits.toList) + assert(parCollection.splits.toList === parCollection.checkpointData.get.getSplits.toList) assert(parCollection.collect() === result) } @@ -72,7 +72,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(sc.objectFile[String](blockRDD.getCheckpointFile.get).collect() === result) assert(blockRDD.dependencies != Nil) assert(blockRDD.splits.length === numSplits) - assert(blockRDD.splits.toList === blockRDD.checkpointData.get.cpRDDSplits.toList) + assert(blockRDD.splits.toList === blockRDD.checkpointData.get.getSplits.toList) assert(blockRDD.collect() === result) } @@ -191,7 +191,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(operatedRDD.dependencies.head.rdd != parentRDD) // Test whether the splits have been changed to the new Hadoop splits - assert(operatedRDD.splits.toList === operatedRDD.checkpointData.get.cpRDDSplits.toList) + assert(operatedRDD.splits.toList === operatedRDD.checkpointData.get.getSplits.toList) // Test whether the number of splits is same as before assert(operatedRDD.splits.length === numSplits) -- cgit v1.2.3 From 5488ac67c3ab1b91c8936fcdb421c966aa73bb6e Mon Sep 17 00:00:00 2001 From: Thomas Dudziak Date: Wed, 19 Dec 2012 10:20:43 -0800 Subject: Tweaked debian packaging to be a bit more in line with debian standards --- repl-bin/pom.xml | 5 +++-- repl-bin/src/deb/control/control | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index 72a946f3d7..0667b71cc7 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -15,7 +15,8 @@ http://spark-project.org/ - /usr/share/spark + spark-${classifier} + /usr/share/spark-${classifier} root @@ -183,7 +184,7 @@ jdeb - ${project.build.directory}/${project.artifactId}-${classifier}_${project.version}-${buildNumber}_all.deb + ${project.build.directory}/${deb.pkg.name}_${project.version}-${buildNumber}_all.deb false gzip diff --git a/repl-bin/src/deb/control/control b/repl-bin/src/deb/control/control index afadb3fbfe..a6b4471d48 100644 --- a/repl-bin/src/deb/control/control +++ b/repl-bin/src/deb/control/control @@ -1,8 +1,8 @@ -Package: [[artifactId]] +Package: [[deb.pkg.name]] Version: [[version]]-[[buildNumber]] Section: misc Priority: extra Architecture: all Maintainer: Matei Zaharia -Description: spark repl +Description: [[name]] Distribution: development -- cgit v1.2.3 From 68c52d80ecd5dd173f755bedc813fdc1a52100aa Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 19 Dec 2012 15:27:23 -0800 Subject: Moved BlockManager's IdGenerator into BlockManager object. Removed some excessive debug messages. --- core/src/main/scala/spark/storage/BlockManager.scala | 9 ++++++--- .../main/scala/spark/storage/BlockManagerMaster.scala | 4 ++-- .../scala/spark/storage/BlockManagerMasterActor.scala | 12 ------------ core/src/main/scala/spark/util/GenerationIdUtil.scala | 19 ------------------- core/src/main/scala/spark/util/IdGenerator.scala | 14 ++++++++++++++ 5 files changed, 22 insertions(+), 36 deletions(-) delete mode 100644 core/src/main/scala/spark/util/GenerationIdUtil.scala create mode 100644 core/src/main/scala/spark/util/IdGenerator.scala diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index eedf6d96e2..682ea7baff 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -19,7 +19,7 @@ import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream import spark.{CacheTracker, Logging, SizeEstimator, SparkEnv, SparkException, Utils} import spark.network._ import spark.serializer.Serializer -import spark.util.{ByteBufferInputStream, GenerationIdUtil, MetadataCleaner, TimeStampedHashMap} +import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap} import sun.nio.ch.DirectBuffer @@ -91,7 +91,7 @@ class BlockManager( val host = System.getProperty("spark.hostname", Utils.localHostName()) val slaveActor = master.actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)), - name = "BlockManagerActor" + GenerationIdUtil.BLOCK_MANAGER.next) + name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) @volatile private var shuttingDown = false @@ -865,7 +865,7 @@ class BlockManager( blockInfo.remove(blockId) } else { // The block has already been removed; do nothing. - logWarning("Block " + blockId + " does not exist.") + logWarning("Asked to remove block " + blockId + ", which does not exist") } } @@ -951,6 +951,9 @@ class BlockManager( private[spark] object BlockManager extends Logging { + + val ID_GENERATOR = new IdGenerator + def getMaxMemoryFromSystemProperties: Long = { val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble (Runtime.getRuntime.maxMemory * memoryFraction).toLong diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index e8a1e5889f..cb582633c4 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -20,8 +20,8 @@ private[spark] class BlockManagerMaster( masterPort: Int) extends Logging { - val AKKA_RETRY_ATTEMPS: Int = System.getProperty("spark.akka.num.retries", "5").toInt - val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "100").toInt + val AKKA_RETRY_ATTEMPS: Int = System.getProperty("spark.akka.num.retries", "3").toInt + val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt val MASTER_AKKA_ACTOR_NAME = "BlockMasterManager" val SLAVE_AKKA_ACTOR_NAME = "BlockSlaveManager" diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala index e3de8d8e4e..0a1be98d83 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -183,7 +183,6 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { private def register(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { val startTimeMs = System.currentTimeMillis() val tmp = " " + blockManagerId + " " - logDebug("Got in register 0" + tmp + Utils.getUsedTimeMs(startTimeMs)) if (blockManagerId.ip == Utils.localHostName() && !isLocal) { logInfo("Got Register Msg from master node, don't register it") @@ -200,7 +199,6 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { blockManagerInfo += (blockManagerId -> new BlockManagerMasterActor.BlockManagerInfo( blockManagerId, System.currentTimeMillis(), maxMemSize, slaveActor)) } - logDebug("Got in register 1" + tmp + Utils.getUsedTimeMs(startTimeMs)) sender ! true } @@ -227,7 +225,6 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { if (blockId == null) { blockManagerInfo(blockManagerId).updateLastSeenMs() - logDebug("Got in block update 1" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs)) sender ! true return } @@ -257,15 +254,11 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { private def getLocations(blockId: String) { val startTimeMs = System.currentTimeMillis() val tmp = " " + blockId + " " - logDebug("Got in getLocations 0" + tmp + Utils.getUsedTimeMs(startTimeMs)) if (blockInfo.containsKey(blockId)) { var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] res.appendAll(blockInfo.get(blockId)._2) - logDebug("Got in getLocations 1" + tmp + " as "+ res.toSeq + " at " - + Utils.getUsedTimeMs(startTimeMs)) sender ! res.toSeq } else { - logDebug("Got in getLocations 2" + tmp + Utils.getUsedTimeMs(startTimeMs)) var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] sender ! res } @@ -274,25 +267,20 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { private def getLocationsMultipleBlockIds(blockIds: Array[String]) { def getLocations(blockId: String): Seq[BlockManagerId] = { val tmp = blockId - logDebug("Got in getLocationsMultipleBlockIds Sub 0 " + tmp) if (blockInfo.containsKey(blockId)) { var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] res.appendAll(blockInfo.get(blockId)._2) - logDebug("Got in getLocationsMultipleBlockIds Sub 1 " + tmp + " " + res.toSeq) return res.toSeq } else { - logDebug("Got in getLocationsMultipleBlockIds Sub 2 " + tmp) var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] return res.toSeq } } - logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq) var res: ArrayBuffer[Seq[BlockManagerId]] = new ArrayBuffer[Seq[BlockManagerId]] for (blockId <- blockIds) { res.append(getLocations(blockId)) } - logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq + " : " + res.toSeq) sender ! res.toSeq } diff --git a/core/src/main/scala/spark/util/GenerationIdUtil.scala b/core/src/main/scala/spark/util/GenerationIdUtil.scala deleted file mode 100644 index 8a17b700b0..0000000000 --- a/core/src/main/scala/spark/util/GenerationIdUtil.scala +++ /dev/null @@ -1,19 +0,0 @@ -package spark.util - -import java.util.concurrent.atomic.AtomicInteger - -private[spark] -object GenerationIdUtil { - - val BLOCK_MANAGER = new IdGenerator - - /** - * A util used to get a unique generation ID. This is a wrapper around - * Java's AtomicInteger. - */ - class IdGenerator { - private var id = new AtomicInteger - - def next: Int = id.incrementAndGet - } -} diff --git a/core/src/main/scala/spark/util/IdGenerator.scala b/core/src/main/scala/spark/util/IdGenerator.scala new file mode 100644 index 0000000000..b6e309fe1a --- /dev/null +++ b/core/src/main/scala/spark/util/IdGenerator.scala @@ -0,0 +1,14 @@ +package spark.util + +import java.util.concurrent.atomic.AtomicInteger + +/** + * A util used to get a unique generation ID. This is a wrapper around Java's + * AtomicInteger. An example usage is in BlockManager, where each BlockManager + * instance would start an Akka actor and we use this utility to assign the Akka + * actors unique names. + */ +private[spark] class IdGenerator { + private var id = new AtomicInteger + def next: Int = id.incrementAndGet +} -- cgit v1.2.3 From 9397c5014e17a96c3cf24661c0edb40e524589e7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 20 Dec 2012 01:37:09 -0800 Subject: Let the slave notify the master block removal. --- .../main/scala/spark/storage/BlockManager.scala | 65 ++++++++++------------ .../scala/spark/storage/BlockManagerMaster.scala | 17 +----- .../spark/storage/BlockManagerMasterActor.scala | 34 ++++++----- .../scala/spark/storage/BlockManagerSuite.scala | 59 ++++++++++++-------- project/SparkBuild.scala | 2 +- 5 files changed, 84 insertions(+), 93 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 682ea7baff..7a8ac10cdd 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -59,7 +59,7 @@ class BlockManager( } } - private val blockInfo = new TimeStampedHashMap[String, BlockInfo]() + private val blockInfo = new TimeStampedHashMap[String, BlockInfo] private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) private[storage] val diskStore: BlockStore = @@ -139,8 +139,8 @@ class BlockManager( */ private def reportAllBlocks() { logInfo("Reporting " + blockInfo.size + " blocks to the master.") - for (blockId <- blockInfo.keys) { - if (!tryToReportBlockStatus(blockId)) { + for ((blockId, info) <- blockInfo) { + if (!tryToReportBlockStatus(blockId, info)) { logError("Failed to report " + blockId + " to master; giving up.") return } @@ -168,8 +168,8 @@ class BlockManager( * message reflecting the current status, *not* the desired storage level in its block info. * For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk. */ - def reportBlockStatus(blockId: String) { - val needReregister = !tryToReportBlockStatus(blockId) + def reportBlockStatus(blockId: String, info: BlockInfo) { + val needReregister = !tryToReportBlockStatus(blockId, info) if (needReregister) { logInfo("Got told to reregister updating block " + blockId) // Reregistering will report our new block for free. @@ -179,29 +179,23 @@ class BlockManager( } /** - * Actually send a BlockUpdate message. Returns the mater's response, which will be true if the - * block was successfully recorded and false if the slave needs to re-register. + * Actually send a UpdateBlockInfo message. Returns the mater's response, + * which will be true if the block was successfully recorded and false if + * the slave needs to re-register. */ - private def tryToReportBlockStatus(blockId: String): Boolean = { - val (curLevel, inMemSize, onDiskSize, tellMaster) = blockInfo.get(blockId) match { - case None => - (StorageLevel.NONE, 0L, 0L, false) - case Some(info) => - info.synchronized { - info.level match { - case null => - (StorageLevel.NONE, 0L, 0L, false) - case level => - val inMem = level.useMemory && memoryStore.contains(blockId) - val onDisk = level.useDisk && diskStore.contains(blockId) - ( - new StorageLevel(onDisk, inMem, level.deserialized, level.replication), - if (inMem) memoryStore.getSize(blockId) else 0L, - if (onDisk) diskStore.getSize(blockId) else 0L, - info.tellMaster - ) - } - } + private def tryToReportBlockStatus(blockId: String, info: BlockInfo): Boolean = { + val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized { + info.level match { + case null => + (StorageLevel.NONE, 0L, 0L, false) + case level => + val inMem = level.useMemory && memoryStore.contains(blockId) + val onDisk = level.useDisk && diskStore.contains(blockId) + val storageLevel = new StorageLevel(onDisk, inMem, level.deserialized, level.replication) + val memSize = if (inMem) memoryStore.getSize(blockId) else 0L + val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L + (storageLevel, memSize, diskSize, info.tellMaster) + } } if (tellMaster) { @@ -648,7 +642,7 @@ class BlockManager( // and tell the master about it. myInfo.markReady(size) if (tellMaster) { - reportBlockStatus(blockId) + reportBlockStatus(blockId, myInfo) } } logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs)) @@ -735,7 +729,7 @@ class BlockManager( // and tell the master about it. myInfo.markReady(bytes.limit) if (tellMaster) { - reportBlockStatus(blockId) + reportBlockStatus(blockId, myInfo) } } @@ -834,7 +828,7 @@ class BlockManager( logWarning("Block " + blockId + " could not be dropped from memory as it does not exist") } if (info.tellMaster) { - reportBlockStatus(blockId) + reportBlockStatus(blockId, info) } if (!level.useDisk) { // The block is completely gone from this node; forget it so we can put() it again later. @@ -847,9 +841,7 @@ class BlockManager( } /** - * Remove a block from both memory and disk. This one doesn't report to the master - * because it expects the master to initiate the original block removal command, and - * then the master can update the block tracking itself. + * Remove a block from both memory and disk. */ def removeBlock(blockId: String) { logInfo("Removing block " + blockId) @@ -863,6 +855,9 @@ class BlockManager( "the disk or memory store") } blockInfo.remove(blockId) + if (info.tellMaster) { + reportBlockStatus(blockId, info) + } } else { // The block has already been removed; do nothing. logWarning("Asked to remove block " + blockId + ", which does not exist") @@ -872,7 +867,7 @@ class BlockManager( def dropOldBlocks(cleanupTime: Long) { logInfo("Dropping blocks older than " + cleanupTime) val iterator = blockInfo.internalMap.entrySet().iterator() - while(iterator.hasNext) { + while (iterator.hasNext) { val entry = iterator.next() val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2) if (time < cleanupTime) { @@ -887,7 +882,7 @@ class BlockManager( iterator.remove() logInfo("Dropped block " + id) } - reportBlockStatus(id) + reportBlockStatus(id, info) } } } diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index cb582633c4..a3d8671834 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -101,7 +101,7 @@ private[spark] class BlockManagerMaster( * blocks that the master knows about. */ def removeBlock(blockId: String) { - askMaster(RemoveBlock(blockId)) + askMasterWithRetry(RemoveBlock(blockId)) } /** @@ -130,21 +130,6 @@ private[spark] class BlockManagerMaster( } } - /** - * Send a message to the master actor and get its result within a default timeout, or - * throw a SparkException if this fails. There is no retry logic here so if the Akka - * message is lost, the master actor won't get the command. - */ - private def askMaster[T](message: Any): Any = { - try { - val future = masterActor.ask(message)(timeout) - return Await.result(future, timeout).asInstanceOf[T] - } catch { - case e: Exception => - throw new SparkException("Error communicating with BlockManagerMaster", e) - } - } - /** * Send a message to the master actor and get its result within a default timeout, or * throw a SparkException if this fails. diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala index 0a1be98d83..f4d026da33 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -28,7 +28,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { private val blockManagerIdByHost = new HashMap[String, ArrayBuffer[BlockManagerId]] // Mapping from block id to the set of block managers that have the block. - private val blockInfo = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]] + private val blockLocations = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]] initLogging() @@ -53,7 +53,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { register(blockManagerId, maxMemSize, slaveActor) case UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => - blockUpdate(blockManagerId, blockId, storageLevel, deserializedSize, size) + updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) case GetLocations(blockId) => getLocations(blockId) @@ -108,10 +108,10 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { var iterator = info.blocks.keySet.iterator while (iterator.hasNext) { val blockId = iterator.next - val locations = blockInfo.get(blockId)._2 + val locations = blockLocations.get(blockId)._2 locations -= blockManagerId if (locations.size == 0) { - blockInfo.remove(locations) + blockLocations.remove(locations) } } } @@ -154,7 +154,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { // Remove a block from the slaves that have it. This can only be used to remove // blocks that the master knows about. private def removeBlock(blockId: String) { - val block = blockInfo.get(blockId) + val block = blockLocations.get(blockId) if (block != null) { block._2.foreach { blockManagerId: BlockManagerId => val blockManager = blockManagerInfo.get(blockManagerId) @@ -163,11 +163,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { // Doesn't actually wait for a confirmation and the message might get lost. // If message loss becomes frequent, we should add retry logic here. blockManager.get.slaveActor ! RemoveBlock(blockId) - // Remove the block from the master's BlockManagerInfo. - blockManager.get.updateBlockInfo(blockId, StorageLevel.NONE, 0, 0) } } - blockInfo.remove(blockId) } sender ! true } @@ -202,7 +199,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { sender ! true } - private def blockUpdate( + private def updateBlockInfo( blockManagerId: BlockManagerId, blockId: String, storageLevel: StorageLevel, @@ -232,21 +229,22 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize) var locations: HashSet[BlockManagerId] = null - if (blockInfo.containsKey(blockId)) { - locations = blockInfo.get(blockId)._2 + if (blockLocations.containsKey(blockId)) { + locations = blockLocations.get(blockId)._2 } else { locations = new HashSet[BlockManagerId] - blockInfo.put(blockId, (storageLevel.replication, locations)) + blockLocations.put(blockId, (storageLevel.replication, locations)) } if (storageLevel.isValid) { - locations += blockManagerId + locations.add(blockManagerId) } else { locations.remove(blockManagerId) } + // Remove the block from master tracking if it has been removed on all slaves. if (locations.size == 0) { - blockInfo.remove(blockId) + blockLocations.remove(blockId) } sender ! true } @@ -254,9 +252,9 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { private def getLocations(blockId: String) { val startTimeMs = System.currentTimeMillis() val tmp = " " + blockId + " " - if (blockInfo.containsKey(blockId)) { + if (blockLocations.containsKey(blockId)) { var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - res.appendAll(blockInfo.get(blockId)._2) + res.appendAll(blockLocations.get(blockId)._2) sender ! res.toSeq } else { var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] @@ -267,9 +265,9 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { private def getLocationsMultipleBlockIds(blockIds: Array[String]) { def getLocations(blockId: String): Seq[BlockManagerId] = { val tmp = blockId - if (blockInfo.containsKey(blockId)) { + if (blockLocations.containsKey(blockId)) { var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - res.appendAll(blockInfo.get(blockId)._2) + res.appendAll(blockLocations.get(blockId)._2) return res.toSeq } else { var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index 4e28a7e2bc..8f86e3170e 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -7,6 +7,10 @@ import akka.actor._ import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter import org.scalatest.PrivateMethodTester +import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.matchers.ShouldMatchers._ +import org.scalatest.time.SpanSugar._ import spark.KryoSerializer import spark.SizeEstimator @@ -142,37 +146,46 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) - // Putting a1, a2 and a3 in memory and telling master only about a1 and a2 - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY, false) + // Putting a1, a2 and a3 in memory and telling master only about a1 and a2 + store.putSingle("a1-to-remove", a1, StorageLevel.MEMORY_ONLY) + store.putSingle("a2-to-remove", a2, StorageLevel.MEMORY_ONLY) + store.putSingle("a3-to-remove", a3, StorageLevel.MEMORY_ONLY, false) // Checking whether blocks are in memory and memory size - var memStatus = master.getMemoryStatus.head._2 + val memStatus = master.getMemoryStatus.head._2 assert(memStatus._1 == 2000L, "total memory " + memStatus._1 + " should equal 2000") assert(memStatus._2 <= 1200L, "remaining memory " + memStatus._2 + " should <= 1200") - assert(store.getSingle("a1") != None, "a1 was not in store") - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") != None, "a3 was not in store") + assert(store.getSingle("a1-to-remove") != None, "a1 was not in store") + assert(store.getSingle("a2-to-remove") != None, "a2 was not in store") + assert(store.getSingle("a3-to-remove") != None, "a3 was not in store") // Checking whether master knows about the blocks or not - assert(master.getLocations("a1").size > 0, "master was not told about a1") - assert(master.getLocations("a2").size > 0, "master was not told about a2") - assert(master.getLocations("a3").size === 0, "master was told about a3") + assert(master.getLocations("a1-to-remove").size > 0, "master was not told about a1") + assert(master.getLocations("a2-to-remove").size > 0, "master was not told about a2") + assert(master.getLocations("a3-to-remove").size === 0, "master was told about a3") // Remove a1 and a2 and a3. Should be no-op for a3. - master.removeBlock("a1") - master.removeBlock("a2") - master.removeBlock("a3") - assert(store.getSingle("a1") === None, "a1 not removed from store") - assert(store.getSingle("a2") === None, "a2 not removed from store") - assert(master.getLocations("a1").size === 0, "master did not remove a1") - assert(master.getLocations("a2").size === 0, "master did not remove a2") - assert(store.getSingle("a3") != None, "a3 was not in store") - assert(master.getLocations("a3").size === 0, "master was told about a3") - memStatus = master.getMemoryStatus.head._2 - assert(memStatus._1 == 2000L, "total memory " + memStatus._1 + " should equal 2000") - assert(memStatus._2 == 2000L, "remaining memory " + memStatus._1 + " should equal 2000") + master.removeBlock("a1-to-remove") + master.removeBlock("a2-to-remove") + master.removeBlock("a3-to-remove") + + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + store.getSingle("a1-to-remove") should be (None) + master.getLocations("a1-to-remove") should have size 0 + } + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + store.getSingle("a2-to-remove") should be (None) + master.getLocations("a2-to-remove") should have size 0 + } + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + store.getSingle("a3-to-remove") should not be (None) + master.getLocations("a3-to-remove") should have size 0 + } + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + val memStatus = master.getMemoryStatus.head._2 + memStatus._1 should equal (2000L) + memStatus._2 should equal (2000L) + } } test("reregistration on heart beat") { diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 2f67bb9921..34b93fb694 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -87,7 +87,7 @@ object SparkBuild extends Build { libraryDependencies ++= Seq( "org.eclipse.jetty" % "jetty-server" % "7.5.3.v20111011", - "org.scalatest" %% "scalatest" % "1.6.1" % "test", + "org.scalatest" %% "scalatest" % "1.8" % "test", "org.scalacheck" %% "scalacheck" % "1.9" % "test", "com.novocode" % "junit-interface" % "0.8" % "test" ), -- cgit v1.2.3 From f9c5b0a6fe8d728e16c60c0cf51ced0054e3a387 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 20 Dec 2012 11:52:23 -0800 Subject: Changed checkpoint writing and reading process. --- core/src/main/scala/spark/RDDCheckpointData.scala | 27 +---- core/src/main/scala/spark/rdd/CheckpointRDD.scala | 117 ++++++++++++++++++++++ core/src/main/scala/spark/rdd/HadoopRDD.scala | 5 +- 3 files changed, 124 insertions(+), 25 deletions(-) create mode 100644 core/src/main/scala/spark/rdd/CheckpointRDD.scala diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index e4c0912cdc..1aa9b9aa1e 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -1,7 +1,7 @@ package spark import org.apache.hadoop.fs.Path -import rdd.CoalescedRDD +import rdd.{CheckpointRDD, CoalescedRDD} import scheduler.{ResultTask, ShuffleMapTask} /** @@ -55,30 +55,13 @@ extends Logging with Serializable { } // Save to file, and reload it as an RDD - val file = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString - rdd.saveAsObjectFile(file) - - val newRDD = { - val hadoopRDD = rdd.context.objectFile[T](file, rdd.splits.size) - - val oldSplits = rdd.splits.size - val newSplits = hadoopRDD.splits.size - - logDebug("RDD splits = " + oldSplits + " --> " + newSplits) - if (newSplits < oldSplits) { - throw new Exception("# splits after checkpointing is less than before " + - "[" + oldSplits + " --> " + newSplits) - } else if (newSplits > oldSplits) { - new CoalescedRDD(hadoopRDD, rdd.splits.size) - } else { - hadoopRDD - } - } - logDebug("New RDD has " + newRDD.splits.size + " splits") + val path = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString + rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _) + val newRDD = new CheckpointRDD[T](rdd.context, path) // Change the dependencies and splits of the RDD RDDCheckpointData.synchronized { - cpFile = Some(file) + cpFile = Some(path) cpRDD = Some(newRDD) rdd.changeDependencies(newRDD) cpState = Checkpointed diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala new file mode 100644 index 0000000000..c673ab6aaa --- /dev/null +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -0,0 +1,117 @@ +package spark.rdd + +import spark._ +import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.io.{NullWritable, BytesWritable} +import org.apache.hadoop.util.ReflectionUtils +import org.apache.hadoop.fs.Path +import java.io.{File, IOException, EOFException} +import java.text.NumberFormat + +private[spark] class CheckpointRDDSplit(idx: Int, val splitFile: String) extends Split { + override val index: Int = idx +} + +class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String) + extends RDD[T](sc, Nil) { + + @transient val path = new Path(checkpointPath) + @transient val fs = path.getFileSystem(new Configuration()) + + @transient val splits_ : Array[Split] = { + val splitFiles = fs.listStatus(path).map(_.getPath.toString).filter(_.contains("part-")).sorted + splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray + } + + override def getSplits = splits_ + + override def getPreferredLocations(split: Split): Seq[String] = { + val status = fs.getFileStatus(path) + val locations = fs.getFileBlockLocations(status, 0, status.getLen) + locations.firstOption.toList.flatMap(_.getHosts).filter(_ != "localhost") + } + + override def compute(split: Split): Iterator[T] = { + CheckpointRDD.readFromFile(split.asInstanceOf[CheckpointRDDSplit].splitFile) + } + + override def checkpoint() { + // Do nothing. Hadoop RDD should not be checkpointed. + } +} + +private[spark] object CheckpointRDD extends Logging { + + def splitIdToFileName(splitId: Int): String = { + val numfmt = NumberFormat.getInstance() + numfmt.setMinimumIntegerDigits(5) + numfmt.setGroupingUsed(false) + "part-" + numfmt.format(splitId) + } + + def writeToFile[T](path: String, blockSize: Int = -1)(context: TaskContext, iterator: Iterator[T]) { + val outputDir = new Path(path) + val fs = outputDir.getFileSystem(new Configuration()) + + val finalOutputName = splitIdToFileName(context.splitId) + val finalOutputPath = new Path(outputDir, finalOutputName) + val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptId) + + if (fs.exists(tempOutputPath)) { + throw new IOException("Checkpoint failed: temporary path " + + tempOutputPath + " already exists") + } + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + + val fileOutputStream = if (blockSize < 0) { + fs.create(tempOutputPath, false, bufferSize) + } else { + // This is mainly for testing purpose + fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize) + } + val serializer = SparkEnv.get.serializer.newInstance() + val serializeStream = serializer.serializeStream(fileOutputStream) + serializeStream.writeAll(iterator) + fileOutputStream.close() + + if (!fs.rename(tempOutputPath, finalOutputPath)) { + if (!fs.delete(finalOutputPath, true)) { + throw new IOException("Checkpoint failed: failed to delete earlier output of task " + + context.attemptId); + } + if (!fs.rename(tempOutputPath, finalOutputPath)) { + throw new IOException("Checkpoint failed: failed to save output of task: " + + context.attemptId) + } + } + } + + def readFromFile[T](path: String): Iterator[T] = { + val inputPath = new Path(path) + val fs = inputPath.getFileSystem(new Configuration()) + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val fileInputStream = fs.open(inputPath, bufferSize) + val serializer = SparkEnv.get.serializer.newInstance() + val deserializeStream = serializer.deserializeStream(fileInputStream) + deserializeStream.asIterator.asInstanceOf[Iterator[T]] + } + + // Test whether CheckpointRDD generate expected number of splits despite + // each split file having multiple blocks. This needs to be run on a + // cluster (mesos or standalone) using HDFS. + def main(args: Array[String]) { + import spark._ + + val Array(cluster, hdfsPath) = args + val sc = new SparkContext(cluster, "CheckpointRDD Test") + val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000) + val path = new Path(hdfsPath, "temp") + val fs = path.getFileSystem(new Configuration()) + sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 10) _) + val cpRDD = new CheckpointRDD[Int](sc, path.toString) + assert(cpRDD.splits.length == rdd.splits.length, "Number of splits is not the same") + assert(cpRDD.collect.toList == rdd.collect.toList, "Data of splits not the same") + fs.delete(path) + } +} diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index fce190b860..eca51758e4 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -25,8 +25,7 @@ import spark.Split * A Spark split class that wraps around a Hadoop InputSplit. */ private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit) - extends Split - with Serializable { + extends Split { val inputSplit = new SerializableWritable[InputSplit](s) @@ -117,6 +116,6 @@ class HadoopRDD[K, V]( } override def checkpoint() { - // Do nothing. Hadoop RDD cannot be checkpointed. + // Do nothing. Hadoop RDD should not be checkpointed. } } -- cgit v1.2.3 From fe777eb77dee3c5bc5a7a332098d27f517ad3fe4 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 20 Dec 2012 13:39:27 -0800 Subject: Fixed bugs in CheckpointRDD and spark.CheckpointSuite. --- core/src/main/scala/spark/SparkContext.scala | 12 +++--------- core/src/main/scala/spark/rdd/CheckpointRDD.scala | 3 +++ core/src/test/scala/spark/CheckpointSuite.scala | 6 +++--- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 71ed4ef058..362aa04e66 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -37,9 +37,7 @@ import spark.broadcast._ import spark.deploy.LocalSparkCluster import spark.partial.ApproximateEvaluator import spark.partial.PartialResult -import spark.rdd.HadoopRDD -import spark.rdd.NewHadoopRDD -import spark.rdd.UnionRDD +import rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD} import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler} import spark.scheduler.local.LocalScheduler import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} @@ -368,13 +366,9 @@ class SparkContext( protected[spark] def checkpointFile[T: ClassManifest]( - path: String, - minSplits: Int = defaultMinSplits + path: String ): RDD[T] = { - val rdd = objectFile[T](path, minSplits) - rdd.checkpointData = Some(new RDDCheckpointData(rdd)) - rdd.checkpointData.get.cpFile = Some(path) - rdd + new CheckpointRDD[T](this, path) } /** Build the union of a list of RDDs. */ diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala index c673ab6aaa..fbf8a9ef83 100644 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -24,6 +24,9 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String) splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray } + checkpointData = Some(new RDDCheckpointData[T](this)) + checkpointData.get.cpFile = Some(checkpointPath) + override def getSplits = splits_ override def getPreferredLocations(split: Split): Seq[String] = { diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 19626d2450..6bc667bd4c 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -54,7 +54,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { parCollection.checkpoint() assert(parCollection.dependencies === Nil) val result = parCollection.collect() - assert(sc.objectFile[Int](parCollection.getCheckpointFile.get).collect() === result) + assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) assert(parCollection.splits.length === numSplits) assert(parCollection.splits.toList === parCollection.checkpointData.get.getSplits.toList) @@ -69,7 +69,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val numSplits = blockRDD.splits.size blockRDD.checkpoint() val result = blockRDD.collect() - assert(sc.objectFile[String](blockRDD.getCheckpointFile.get).collect() === result) + assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result) assert(blockRDD.dependencies != Nil) assert(blockRDD.splits.length === numSplits) assert(blockRDD.splits.toList === blockRDD.checkpointData.get.getSplits.toList) @@ -185,7 +185,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) // Test whether the checkpoint file has been created - assert(sc.objectFile[U](operatedRDD.getCheckpointFile.get).collect() === result) + assert(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get).collect() === result) // Test whether dependencies have been changed from its earlier parent RDD assert(operatedRDD.dependencies.head.rdd != parentRDD) -- cgit v1.2.3 From 60f7338092ad0c3a608c0e466f66047a508a35be Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 21 Dec 2012 15:49:33 -0800 Subject: Remove the call to close input stream in Kryo serializer. --- core/src/main/scala/spark/KryoSerializer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index f24196ea49..93d7327324 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -46,8 +46,8 @@ class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends Deser } def close() { + // Kryo's Input automatically closes the input stream it is using. input.close() - inStream.close() } } -- cgit v1.2.3 From c68a0760379ff8d8a1ae194934ae54d19f1eb213 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 21 Dec 2012 16:03:17 -0800 Subject: Updated Kryo documentation for Kryo version update. --- docs/tuning.md | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/docs/tuning.md b/docs/tuning.md index f18de8ff3a..9aaa53cd65 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -33,7 +33,7 @@ in your operations) and performance. It provides two serialization libraries: Java serialization is flexible but often quite slow, and leads to large serialized formats for many classes. * [Kryo serialization](http://code.google.com/p/kryo/wiki/V1Documentation): Spark can also use - the Kryo library (currently just version 1) to serialize objects more quickly. Kryo is significantly + the Kryo library (version 2) to serialize objects more quickly. Kryo is significantly faster and more compact than Java serialization (often as much as 10x), but does not support all `Serializable` types and requires you to *register* the classes you'll use in the program in advance for best performance. @@ -47,6 +47,8 @@ Finally, to register your classes with Kryo, create a public class that extends `spark.kryo.registrator` system property to point to it, as follows: {% highlight scala %} +import com.esotericsoftware.kryo.Kryo + class MyRegistrator extends KryoRegistrator { override def registerClasses(kryo: Kryo) { kryo.register(classOf[MyClass1]) @@ -60,7 +62,7 @@ System.setProperty("spark.kryo.registrator", "mypackage.MyRegistrator") val sc = new SparkContext(...) {% endhighlight %} -The [Kryo documentation](http://code.google.com/p/kryo/wiki/V1Documentation) describes more advanced +The [Kryo documentation](http://code.google.com/p/kryo/) describes more advanced registration options, such as adding custom serialization code. If your objects are large, you may also need to increase the `spark.kryoserializer.buffer.mb` @@ -147,7 +149,7 @@ the space allocated to the RDD cache to mitigate this. **Measuring the Impact of GC** -The first step in GC tuning is to collect statistics on how frequently garbage collection occurs and the amount of +The first step in GC tuning is to collect statistics on how frequently garbage collection occurs and the amount of time spent GC. This can be done by adding `-verbose:gc -XX:+PrintGCDetails -XX:+PrintGCTimeStamps` to your `SPARK_JAVA_OPTS` environment variable. Next time your Spark job is run, you will see messages printed in the worker's logs each time a garbage collection occurs. Note these logs will be on your cluster's worker nodes (in the `stdout` files in @@ -155,15 +157,15 @@ their work directories), *not* on your driver program. **Cache Size Tuning** -One important configuration parameter for GC is the amount of memory that should be used for -caching RDDs. By default, Spark uses 66% of the configured memory (`SPARK_MEM`) to cache RDDs. This means that +One important configuration parameter for GC is the amount of memory that should be used for +caching RDDs. By default, Spark uses 66% of the configured memory (`SPARK_MEM`) to cache RDDs. This means that 33% of memory is available for any objects created during task execution. In case your tasks slow down and you find that your JVM is garbage-collecting frequently or running out of -memory, lowering this value will help reduce the memory consumption. To change this to say 50%, you can call -`System.setProperty("spark.storage.memoryFraction", "0.5")`. Combined with the use of serialized caching, -using a smaller cache should be sufficient to mitigate most of the garbage collection problems. -In case you are interested in further tuning the Java GC, continue reading below. +memory, lowering this value will help reduce the memory consumption. To change this to say 50%, you can call +`System.setProperty("spark.storage.memoryFraction", "0.5")`. Combined with the use of serialized caching, +using a smaller cache should be sufficient to mitigate most of the garbage collection problems. +In case you are interested in further tuning the Java GC, continue reading below. **Advanced GC Tuning** @@ -172,9 +174,9 @@ To further tune garbage collection, we first need to understand some basic infor * Java Heap space is divided in to two regions Young and Old. The Young generation is meant to hold short-lived objects while the Old generation is intended for objects with longer lifetimes. -* The Young generation is further divided into three regions [Eden, Survivor1, Survivor2]. +* The Young generation is further divided into three regions [Eden, Survivor1, Survivor2]. -* A simplified description of the garbage collection procedure: When Eden is full, a minor GC is run on Eden and objects +* A simplified description of the garbage collection procedure: When Eden is full, a minor GC is run on Eden and objects that are alive from Eden and Survivor1 are copied to Survivor2. The Survivor regions are swapped. If an object is old enough or Survivor2 is full, it is moved to Old. Finally when Old is close to full, a full GC is invoked. @@ -186,7 +188,7 @@ temporary objects created during task execution. Some steps which may be useful before a task completes, it means that there isn't enough memory available for executing tasks. * In the GC stats that are printed, if the OldGen is close to being full, reduce the amount of memory used for caching. - This can be done using the `spark.storage.memoryFraction` property. It is better to cache fewer objects than to slow + This can be done using the `spark.storage.memoryFraction` property. It is better to cache fewer objects than to slow down task execution! * If there are too many minor collections but not many major GCs, allocating more memory for Eden would help. You @@ -195,8 +197,8 @@ temporary objects created during task execution. Some steps which may be useful up by 4/3 is to account for space used by survivor regions as well.) * As an example, if your task is reading data from HDFS, the amount of memory used by the task can be estimated using - the size of the data block read from HDFS. Note that the size of a decompressed block is often 2 or 3 times the - size of the block. So if we wish to have 3 or 4 tasks worth of working space, and the HDFS block size is 64 MB, + the size of the data block read from HDFS. Note that the size of a decompressed block is often 2 or 3 times the + size of the block. So if we wish to have 3 or 4 tasks worth of working space, and the HDFS block size is 64 MB, we can estimate size of Eden to be `4*3*64MB`. * Monitor how the frequency and time taken by garbage collection changes with the new settings. -- cgit v1.2.3 From a6bb41c6d389f1b98d5542000a7a9705ba282273 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 21 Dec 2012 16:25:50 -0800 Subject: Updated Kryo version for Maven pom file. --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 52a4e9d932..b33cee26b8 100644 --- a/pom.xml +++ b/pom.xml @@ -185,7 +185,7 @@ de.javakaffee kryo-serializers - 0.9 + 0.20 com.typesafe.akka -- cgit v1.2.3 From 9ac4cb1c5fd7637ff9936f1ef54fa27f6f6aa214 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 21 Dec 2012 17:11:39 -0800 Subject: Adding a Twitter InputDStream with an example --- project/SparkBuild.scala | 5 +- .../spark/streaming/TwitterInputDStream.scala | 59 ++++++++++++++++++++++ .../spark/streaming/examples/TwitterBasic.scala | 37 ++++++++++++++ 3 files changed, 100 insertions(+), 1 deletion(-) create mode 100644 streaming/src/main/scala/spark/streaming/TwitterInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/examples/TwitterBasic.scala diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 6ef2ac477a..618c7afc37 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -115,7 +115,8 @@ object SparkBuild extends Build { "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/", "JBoss Repository" at "http://repository.jboss.org/nexus/content/repositories/releases/", "Spray Repository" at "http://repo.spray.cc/", - "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/" + "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/", + "Twitter4J Repository" at "http://twitter4j.org/maven2/" ), libraryDependencies ++= Seq( @@ -133,6 +134,8 @@ object SparkBuild extends Build { "com.typesafe.akka" % "akka-slf4j" % "2.0.3", "it.unimi.dsi" % "fastutil" % "6.4.4", "colt" % "colt" % "1.2.0", + "org.twitter4j" % "twitter4j-core" % "3.0.2", + "org.twitter4j" % "twitter4j-stream" % "3.0.2", "cc.spray" % "spray-can" % "1.0-M2.1", "cc.spray" % "spray-server" % "1.0-M2.1", "org.apache.mesos" % "mesos" % "0.9.0-incubating" diff --git a/streaming/src/main/scala/spark/streaming/TwitterInputDStream.scala b/streaming/src/main/scala/spark/streaming/TwitterInputDStream.scala new file mode 100644 index 0000000000..5d177e96de --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/TwitterInputDStream.scala @@ -0,0 +1,59 @@ +package spark.streaming + +import spark.RDD +import spark.streaming.{Time, InputDStream} +import twitter4j._ +import twitter4j.auth.BasicAuthorization +import collection.mutable.ArrayBuffer +import collection.JavaConversions._ + +/* A stream of Twitter statuses, potentially filtered by one or more keywords. +* +* @constructor create a new Twitter stream using the supplied username and password to authenticate. +* An optional set of string filters can be used to restrict the set of tweets. The Twitter API is +* such that this may return a sampled subset of all tweets during each interval. +*/ +class TwitterInputDStream( + @transient ssc_ : StreamingContext, + username: String, + password: String, + filters: Seq[String] + ) extends InputDStream[Status](ssc_) { + val statuses: ArrayBuffer[Status] = ArrayBuffer() + var twitterStream: TwitterStream = _ + + override def start() = { + twitterStream = new TwitterStreamFactory() + .getInstance(new BasicAuthorization(username, password)) + twitterStream.addListener(new StatusListener { + def onStatus(status: Status) = { + statuses += status + } + // Unimplemented + def onDeletionNotice(statusDeletionNotice: StatusDeletionNotice) {} + def onTrackLimitationNotice(i: Int) {} + def onScrubGeo(l: Long, l1: Long) {} + def onStallWarning(stallWarning: StallWarning) {} + def onException(e: Exception) {} + }) + + val query: FilterQuery = new FilterQuery + if (filters.size > 0) { + query.track(filters.toArray) + twitterStream.filter(query) + } else { + twitterStream.sample() + } + } + + override def stop() = { + twitterStream.shutdown() + } + + override def compute(validTime: Time): Option[RDD[Status]] = { + // Flush the current tweet buffer + val rdd = Some(ssc.sc.parallelize(statuses)) + statuses.foreach(x => statuses -= x) + rdd + } +} diff --git a/streaming/src/main/scala/spark/streaming/examples/TwitterBasic.scala b/streaming/src/main/scala/spark/streaming/examples/TwitterBasic.scala new file mode 100644 index 0000000000..c7e380fbe1 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/TwitterBasic.scala @@ -0,0 +1,37 @@ +package spark.streaming.examples + +import spark.streaming.StreamingContext._ +import spark.streaming.{TwitterInputDStream, Seconds, StreamingContext} + +object TwitterBasic { + def main(args: Array[String]) { + if (args.length != 3) { + System.err.println("Usage: TwitterBasic ") + System.exit(1) + } + + val Array(master, username, password) = args + + val ssc = new StreamingContext(master, "TwitterBasic", Seconds(2)) + val stream = new TwitterInputDStream(ssc, username, password, Seq()) + ssc.graph.addInputStream(stream) + + val hashTags = stream.flatMap( + status => status.getText.split(" ").filter(_.startsWith("#"))) + + // Word count over hashtags + val counts = hashTags.map((_, 1)).reduceByKeyAndWindow(_ + _, Seconds(60)) + + // TODO: Sorts on one node - should do with global sorting once streaming supports it + val topCounts = counts.collect().map(_.sortBy(-_._2).take(5)) + + topCounts.foreachRDD(rdd => { + val topList = rdd.take(1)(0) + println("\nPopular topics in last 60 seconds:") + topList.foreach(t => println("%s (%s tweets)".format(t._1, t._2))) + } + ) + + ssc.start() + } +} -- cgit v1.2.3 From bce84ceabb6e7be92568bc4933410dd095b7936c Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 21 Dec 2012 20:57:46 -0800 Subject: Minor changes after review and general cleanup. - Added filters to Twitter example - Removed un-used import - Some code clean-up --- .../spark/streaming/TwitterInputDStream.scala | 1 - .../spark/streaming/examples/TwitterBasic.scala | 29 ++++++++++++++-------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/TwitterInputDStream.scala b/streaming/src/main/scala/spark/streaming/TwitterInputDStream.scala index 5d177e96de..adf1ed15c9 100644 --- a/streaming/src/main/scala/spark/streaming/TwitterInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/TwitterInputDStream.scala @@ -1,7 +1,6 @@ package spark.streaming import spark.RDD -import spark.streaming.{Time, InputDStream} import twitter4j._ import twitter4j.auth.BasicAuthorization import collection.mutable.ArrayBuffer diff --git a/streaming/src/main/scala/spark/streaming/examples/TwitterBasic.scala b/streaming/src/main/scala/spark/streaming/examples/TwitterBasic.scala index c7e380fbe1..19b3cad6ad 100644 --- a/streaming/src/main/scala/spark/streaming/examples/TwitterBasic.scala +++ b/streaming/src/main/scala/spark/streaming/examples/TwitterBasic.scala @@ -5,15 +5,17 @@ import spark.streaming.{TwitterInputDStream, Seconds, StreamingContext} object TwitterBasic { def main(args: Array[String]) { - if (args.length != 3) { - System.err.println("Usage: TwitterBasic ") + if (args.length < 3) { + System.err.println("Usage: TwitterBasic " + + " [filter1] [filter2] ... [filter n]") System.exit(1) } - val Array(master, username, password) = args + val Array(master, username, password) = args.slice(0, 3) + val filters = args.slice(3, args.length) val ssc = new StreamingContext(master, "TwitterBasic", Seconds(2)) - val stream = new TwitterInputDStream(ssc, username, password, Seq()) + val stream = new TwitterInputDStream(ssc, username, password, filters) ssc.graph.addInputStream(stream) val hashTags = stream.flatMap( @@ -21,17 +23,24 @@ object TwitterBasic { // Word count over hashtags val counts = hashTags.map((_, 1)).reduceByKeyAndWindow(_ + _, Seconds(60)) - // TODO: Sorts on one node - should do with global sorting once streaming supports it val topCounts = counts.collect().map(_.sortBy(-_._2).take(5)) + // Print popular hashtags topCounts.foreachRDD(rdd => { - val topList = rdd.take(1)(0) - println("\nPopular topics in last 60 seconds:") - topList.foreach(t => println("%s (%s tweets)".format(t._1, t._2))) - } + if (rdd.count() != 0) { + val topList = rdd.take(1)(0) + println("\nPopular topics in last 60 seconds:") + topList.foreach{case (tag, count) => println("%s (%s tweets)".format(tag, count))} + } + }) + + // Print number of tweets in the window + stream.window(Seconds(60)).count().foreachRDD(rdd => + if (rdd.count() != 0) { + println("Window size: %s tweets".format(rdd.take(1)(0))) + } ) - ssc.start() } } -- cgit v1.2.3 From 61be8566e24c664442780154debfea884d81f46b Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Mon, 24 Dec 2012 02:26:11 -0800 Subject: Allow distinct() to be called without parentheses when using the default number of splits. --- core/src/main/scala/spark/RDD.scala | 4 +++- core/src/test/scala/spark/RDDSuite.scala | 12 ++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index bb4c13c494..d15c6f7396 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -185,9 +185,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial /** * Return a new RDD containing the distinct elements in this RDD. */ - def distinct(numSplits: Int = splits.size): RDD[T] = + def distinct(numSplits: Int): RDD[T] = map(x => (x, null)).reduceByKey((x, y) => x, numSplits).map(_._1) + def distinct(): RDD[T] = distinct(splits.size) + /** * Return a sampled subset of this RDD. */ diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index b3c820ed94..08da9a1c4d 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -8,9 +8,9 @@ import spark.rdd.CoalescedRDD import SparkContext._ class RDDSuite extends FunSuite with BeforeAndAfter { - + var sc: SparkContext = _ - + after { if (sc != null) { sc.stop() @@ -19,11 +19,15 @@ class RDDSuite extends FunSuite with BeforeAndAfter { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.master.port") } - + test("basic operations") { sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) assert(nums.collect().toList === List(1, 2, 3, 4)) + val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) + assert(dups.distinct.count === 4) + assert(dups.distinct().collect === dups.distinct.collect) + assert(dups.distinct(2).collect === dups.distinct.collect) assert(nums.reduce(_ + _) === 10) assert(nums.fold(0)(_ + _) === 10) assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4")) @@ -121,7 +125,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter { val zipped = nums.zip(nums.map(_ + 1.0)) assert(zipped.glom().map(_.toList).collect().toList === List(List((1, 2.0), (2, 3.0)), List((3, 4.0), (4, 5.0)))) - + intercept[IllegalArgumentException] { nums.zip(sc.parallelize(1 to 4, 1)).collect() } -- cgit v1.2.3 From 903f3518dfcd686cda2256b07fbc1dde6aec0178 Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Mon, 24 Dec 2012 13:18:45 -0800 Subject: fall back to filter-map-collect when calling lookup() on an RDD without a partitioner --- core/src/main/scala/spark/PairRDDFunctions.scala | 2 +- core/src/test/scala/spark/JavaAPISuite.java | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 08ae06e865..d3e206b353 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -438,7 +438,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( val res = self.context.runJob(self, process _, Array(index), false) res(0) case None => - throw new UnsupportedOperationException("lookup() called on an RDD without a partitioner") + self.filter(_._1 == key).map(_._2).collect } } diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 46a0b68f89..33d5fc2d89 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -130,6 +130,17 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(2, foreachCalls); } + @Test + public void lookup() { + JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( + new Tuple2("Apples", "Fruit"), + new Tuple2("Oranges", "Fruit"), + new Tuple2("Oranges", "Citrus") + )); + Assert.assertEquals(2, categories.lookup("Oranges").size()); + Assert.assertEquals(2, categories.groupByKey().lookup("Oranges").get(0).size()); + } + @Test public void groupBy() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); -- cgit v1.2.3 From ccd075cf960df6c6c449b709515cdd81499a52be Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 24 Dec 2012 15:01:13 -0800 Subject: Reduce object overhead in Pyspark shuffle and collect --- pyspark/pyspark/rdd.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 85a24c6854..708ea6eb55 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -145,8 +145,10 @@ class RDD(object): self.map(f).collect() # Force evaluation def collect(self): - pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().collect()) - return load_pickle(bytes(pickle)) + def asList(iterator): + yield list(iterator) + pickles = self.mapPartitions(asList)._jrdd.rdd().collect() + return list(chain.from_iterable(load_pickle(bytes(p)) for p in pickles)) def reduce(self, f): """ @@ -319,16 +321,23 @@ class RDD(object): if numSplits is None: numSplits = self.ctx.defaultParallelism def add_shuffle_key(iterator): + buckets = defaultdict(list) for (k, v) in iterator: - yield str(hashFunc(k)) - yield dump_pickle((k, v)) + buckets[hashFunc(k) % numSplits].append((k, v)) + for (split, items) in buckets.iteritems(): + yield str(split) + yield dump_pickle(items) keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits) + # Transferring O(n) objects to Java is too expensive. Instead, we'll + # form the hash buckets in Python, transferring O(numSplits) objects + # to Java. Each object is a (splitNumber, [objects]) pair. jrdd = pairRDD.partitionBy(partitioner) jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) - return RDD(jrdd, self.ctx) + # Flatten the resulting RDD: + return RDD(jrdd, self.ctx).flatMap(lambda items: items) def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numSplits=None): -- cgit v1.2.3 From 4608902fb87af64a15b97ab21fe6382cd6e5a644 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 24 Dec 2012 17:20:10 -0800 Subject: Use filesystem to collect RDDs in PySpark. Passing large volumes of data through Py4J seems to be slow. It appears to be faster to write the data to the local filesystem and read it back from Python. --- .../main/scala/spark/api/python/PythonRDD.scala | 66 ++++++++-------------- pyspark/pyspark/context.py | 9 ++- pyspark/pyspark/rdd.py | 34 +++++++++-- pyspark/pyspark/serializers.py | 8 +++ pyspark/pyspark/worker.py | 12 +--- 5 files changed, 66 insertions(+), 63 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 50094d6b0f..4f870e837a 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -1,6 +1,7 @@ package spark.api.python import java.io._ +import java.util.{List => JList} import scala.collection.Map import scala.collection.JavaConversions._ @@ -59,36 +60,7 @@ trait PythonRDDBase { } out.flush() for (elem <- parent.iterator(split)) { - if (elem.isInstanceOf[Array[Byte]]) { - val arr = elem.asInstanceOf[Array[Byte]] - dOut.writeInt(arr.length) - dOut.write(arr) - } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) { - val t = elem.asInstanceOf[scala.Tuple2[_, _]] - val t1 = t._1.asInstanceOf[Array[Byte]] - val t2 = t._2.asInstanceOf[Array[Byte]] - val length = t1.length + t2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.write(PythonRDD.stripPickle(t1)) - dOut.write(PythonRDD.stripPickle(t2)) - dOut.writeByte(Pickle.TUPLE2) - dOut.writeByte(Pickle.STOP) - } else if (elem.isInstanceOf[String]) { - // For uniformity, strings are wrapped into Pickles. - val s = elem.asInstanceOf[String].getBytes("UTF-8") - val length = 2 + 1 + 4 + s.length + 1 - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.writeByte(Pickle.BINUNICODE) - dOut.writeInt(Integer.reverseBytes(s.length)) - dOut.write(s) - dOut.writeByte(Pickle.STOP) - } else { - throw new Exception("Unexpected RDD type") - } + PythonRDD.writeAsPickle(elem, dOut) } dOut.flush() out.flush() @@ -174,36 +146,45 @@ object PythonRDD { arr.slice(2, arr.length - 1) } - def asPickle(elem: Any) : Array[Byte] = { - val baos = new ByteArrayOutputStream(); - val dOut = new DataOutputStream(baos); + /** + * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream. + * The data format is a 32-bit integer representing the pickled object's length (in bytes), + * followed by the pickled data. + * @param elem the object to write + * @param dOut a data output stream + */ + def writeAsPickle(elem: Any, dOut: DataOutputStream) { if (elem.isInstanceOf[Array[Byte]]) { - elem.asInstanceOf[Array[Byte]] + val arr = elem.asInstanceOf[Array[Byte]] + dOut.writeInt(arr.length) + dOut.write(arr) } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) { val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]] + val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes + dOut.writeInt(length) dOut.writeByte(Pickle.PROTO) dOut.writeByte(Pickle.TWO) dOut.write(PythonRDD.stripPickle(t._1)) dOut.write(PythonRDD.stripPickle(t._2)) dOut.writeByte(Pickle.TUPLE2) dOut.writeByte(Pickle.STOP) - baos.toByteArray() } else if (elem.isInstanceOf[String]) { // For uniformity, strings are wrapped into Pickles. val s = elem.asInstanceOf[String].getBytes("UTF-8") + val length = 2 + 1 + 4 + s.length + 1 + dOut.writeInt(length) dOut.writeByte(Pickle.PROTO) dOut.writeByte(Pickle.TWO) dOut.write(Pickle.BINUNICODE) dOut.writeInt(Integer.reverseBytes(s.length)) dOut.write(s) dOut.writeByte(Pickle.STOP) - baos.toByteArray() } else { throw new Exception("Unexpected RDD type") } } - def pickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) : + def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) : JavaRDD[Array[Byte]] = { val file = new DataInputStream(new FileInputStream(filename)) val objs = new collection.mutable.ArrayBuffer[Array[Byte]] @@ -221,11 +202,12 @@ object PythonRDD { JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } - def arrayAsPickle(arr : Any) : Array[Byte] = { - val pickles : Array[Byte] = arr.asInstanceOf[Array[Any]].map(asPickle).map(stripPickle).flatten - - Array[Byte](Pickle.PROTO, Pickle.TWO, Pickle.EMPTY_LIST, Pickle.MARK) ++ pickles ++ - Array[Byte] (Pickle.APPENDS, Pickle.STOP) + def writeArrayToPickleFile[T](items: Array[T], filename: String) { + val file = new DataOutputStream(new FileOutputStream(filename)) + for (item <- items) { + writeAsPickle(item, file) + } + file.close() } } diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index 50d57e5317..19f9f9e133 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -14,9 +14,8 @@ class SparkContext(object): gateway = launch_gateway() jvm = gateway.jvm - pickleFile = jvm.spark.api.python.PythonRDD.pickleFile - asPickle = jvm.spark.api.python.PythonRDD.asPickle - arrayAsPickle = jvm.spark.api.python.PythonRDD.arrayAsPickle + readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile + writeArrayToPickleFile = jvm.PythonRDD.writeArrayToPickleFile def __init__(self, master, name, defaultParallelism=None): self.master = master @@ -45,11 +44,11 @@ class SparkContext(object): # because it sends O(n) Py4J commands. As an alternative, serialized # objects are written to a file and loaded through textFile(). tempFile = NamedTemporaryFile(delete=False) + atexit.register(lambda: os.unlink(tempFile.name)) for x in c: write_with_length(dump_pickle(x), tempFile) tempFile.close() - atexit.register(lambda: os.unlink(tempFile.name)) - jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices) + jrdd = self.readRDDFromPickleFile(self._jsc, tempFile.name, numSlices) return RDD(jrdd, self) def textFile(self, name, minSplits=None): diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 708ea6eb55..01908cff96 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -1,13 +1,15 @@ +import atexit from base64 import standard_b64encode as b64enc from collections import defaultdict from itertools import chain, ifilter, imap import os import shlex from subprocess import Popen, PIPE +from tempfile import NamedTemporaryFile from threading import Thread from pyspark import cloudpickle -from pyspark.serializers import dump_pickle, load_pickle +from pyspark.serializers import dump_pickle, load_pickle, read_from_pickle_file from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup @@ -145,10 +147,30 @@ class RDD(object): self.map(f).collect() # Force evaluation def collect(self): + # To minimize the number of transfers between Python and Java, we'll + # flatten each partition into a list before collecting it. Due to + # pipelining, this should add minimal overhead. def asList(iterator): yield list(iterator) - pickles = self.mapPartitions(asList)._jrdd.rdd().collect() - return list(chain.from_iterable(load_pickle(bytes(p)) for p in pickles)) + picklesInJava = self.mapPartitions(asList)._jrdd.rdd().collect() + return list(chain.from_iterable(self._collect_array_through_file(picklesInJava))) + + def _collect_array_through_file(self, array): + # Transferring lots of data through Py4J can be slow because + # socket.readline() is inefficient. Instead, we'll dump the data to a + # file and read it back. + tempFile = NamedTemporaryFile(delete=False) + tempFile.close() + def clean_up_file(): + try: os.unlink(tempFile.name) + except: pass + atexit.register(clean_up_file) + self.ctx.writeArrayToPickleFile(array, tempFile.name) + # Read the data into Python and deserialize it: + with open(tempFile.name, 'rb') as tempFile: + for item in read_from_pickle_file(tempFile): + yield item + os.unlink(tempFile.name) def reduce(self, f): """ @@ -220,15 +242,15 @@ class RDD(object): >>> sc.parallelize([2, 3, 4]).take(2) [2, 3] """ - pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().take(num)) - return load_pickle(bytes(pickle)) + picklesInJava = self._jrdd.rdd().take(num) + return list(self._collect_array_through_file(picklesInJava)) def first(self): """ >>> sc.parallelize([2, 3, 4]).first() 2 """ - return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first()))) + return self.take(1)[0] def saveAsTextFile(self, path): def func(iterator): diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py index 21ef8b106c..bfcdda8f12 100644 --- a/pyspark/pyspark/serializers.py +++ b/pyspark/pyspark/serializers.py @@ -33,3 +33,11 @@ def read_with_length(stream): if obj == "": raise EOFError return obj + + +def read_from_pickle_file(stream): + try: + while True: + yield load_pickle(read_with_length(stream)) + except EOFError: + return diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index 62824a1c9b..9f6b507dbd 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -8,7 +8,7 @@ from base64 import standard_b64decode from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.serializers import write_with_length, read_with_length, \ - read_long, read_int, dump_pickle, load_pickle + read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file # Redirect stdout to stderr so that users must return values from functions. @@ -20,14 +20,6 @@ def load_obj(): return load_pickle(standard_b64decode(sys.stdin.readline().strip())) -def read_input(): - try: - while True: - yield load_pickle(read_with_length(sys.stdin)) - except EOFError: - return - - def main(): num_broadcast_variables = read_int(sys.stdin) for _ in range(num_broadcast_variables): @@ -40,7 +32,7 @@ def main(): dumps = lambda x: x else: dumps = dump_pickle - for obj in func(read_input()): + for obj in func(read_from_pickle_file(sys.stdin)): write_with_length(dumps(obj), old_stdout) -- cgit v1.2.3 From e2dad15621f5dc15275b300df05483afde5025a0 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 26 Dec 2012 17:34:24 -0800 Subject: Add support for batched serialization of Python objects in PySpark. --- pyspark/pyspark/context.py | 3 ++- pyspark/pyspark/rdd.py | 57 +++++++++++++++++++++++++++++------------- pyspark/pyspark/serializers.py | 34 ++++++++++++++++++++++++- 3 files changed, 74 insertions(+), 20 deletions(-) diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index 19f9f9e133..032619693a 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -17,13 +17,14 @@ class SparkContext(object): readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile writeArrayToPickleFile = jvm.PythonRDD.writeArrayToPickleFile - def __init__(self, master, name, defaultParallelism=None): + def __init__(self, master, name, defaultParallelism=None, batchSize=-1): self.master = master self.name = name self._jsc = self.jvm.JavaSparkContext(master, name) self.defaultParallelism = \ defaultParallelism or self._jsc.sc().defaultParallelism() self.pythonExec = os.environ.get("PYSPARK_PYTHON_EXEC", 'python') + self.batchSize = batchSize # -1 represents a unlimited batch size # Broadcast's __reduce__ method stores Broadcast instances here. # This allows other code to determine which Broadcast instances have # been pickled, so it can determine which Java broadcast objects to diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 01908cff96..d7081dffd2 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -2,6 +2,7 @@ import atexit from base64 import standard_b64encode as b64enc from collections import defaultdict from itertools import chain, ifilter, imap +import operator import os import shlex from subprocess import Popen, PIPE @@ -9,7 +10,8 @@ from tempfile import NamedTemporaryFile from threading import Thread from pyspark import cloudpickle -from pyspark.serializers import dump_pickle, load_pickle, read_from_pickle_file +from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \ + read_from_pickle_file from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup @@ -83,6 +85,11 @@ class RDD(object): >>> rdd = sc.parallelize([1, 1, 2, 3]) >>> rdd.union(rdd).collect() [1, 1, 2, 3, 1, 1, 2, 3] + + # Union of batched and unbatched RDDs: + >>> batchedRDD = sc.parallelize([Batch([1, 2, 3, 4, 5])]) + >>> rdd.union(batchedRDD).collect() + [1, 1, 2, 3, 1, 2, 3, 4, 5] """ return RDD(self._jrdd.union(other._jrdd), self.ctx) @@ -147,13 +154,8 @@ class RDD(object): self.map(f).collect() # Force evaluation def collect(self): - # To minimize the number of transfers between Python and Java, we'll - # flatten each partition into a list before collecting it. Due to - # pipelining, this should add minimal overhead. - def asList(iterator): - yield list(iterator) - picklesInJava = self.mapPartitions(asList)._jrdd.rdd().collect() - return list(chain.from_iterable(self._collect_array_through_file(picklesInJava))) + picklesInJava = self._jrdd.rdd().collect() + return list(self._collect_array_through_file(picklesInJava)) def _collect_array_through_file(self, array): # Transferring lots of data through Py4J can be slow because @@ -214,12 +216,21 @@ class RDD(object): # TODO: aggregate + def sum(self): + """ + >>> sc.parallelize([1.0, 2.0, 3.0]).sum() + 6.0 + """ + return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add) + def count(self): """ >>> sc.parallelize([2, 3, 4]).count() - 3L + 3 + >>> sc.parallelize([Batch([2, 3, 4])]).count() + 3 """ - return self._jrdd.count() + return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum() def countByValue(self): """ @@ -342,24 +353,23 @@ class RDD(object): """ if numSplits is None: numSplits = self.ctx.defaultParallelism + # Transferring O(n) objects to Java is too expensive. Instead, we'll + # form the hash buckets in Python, transferring O(numSplits) objects + # to Java. Each object is a (splitNumber, [objects]) pair. def add_shuffle_key(iterator): buckets = defaultdict(list) for (k, v) in iterator: buckets[hashFunc(k) % numSplits].append((k, v)) for (split, items) in buckets.iteritems(): yield str(split) - yield dump_pickle(items) + yield dump_pickle(Batch(items)) keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits) - # Transferring O(n) objects to Java is too expensive. Instead, we'll - # form the hash buckets in Python, transferring O(numSplits) objects - # to Java. Each object is a (splitNumber, [objects]) pair. jrdd = pairRDD.partitionBy(partitioner) jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) - # Flatten the resulting RDD: - return RDD(jrdd, self.ctx).flatMap(lambda items: items) + return RDD(jrdd, self.ctx) def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numSplits=None): @@ -478,8 +488,19 @@ class PipelinedRDD(RDD): def _jrdd(self): if self._jrdd_val: return self._jrdd_val - funcs = [self.func, self._bypass_serializer] - pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in funcs) + func = self.func + if not self._bypass_serializer and self.ctx.batchSize != 1: + oldfunc = self.func + batchSize = self.ctx.batchSize + if batchSize == -1: # unlimited batch size + def batched_func(iterator): + yield Batch(list(oldfunc(iterator))) + else: + def batched_func(iterator): + return batched(oldfunc(iterator), batchSize) + func = batched_func + cmds = [func, self._bypass_serializer] + pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], self.ctx.gateway._gateway_client) diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py index bfcdda8f12..4ed925697c 100644 --- a/pyspark/pyspark/serializers.py +++ b/pyspark/pyspark/serializers.py @@ -2,6 +2,33 @@ import struct import cPickle +class Batch(object): + """ + Used to store multiple RDD entries as a single Java object. + + This relieves us from having to explicitly track whether an RDD + is stored as batches of objects and avoids problems when processing + the union() of batched and unbatched RDDs (e.g. the union() of textFile() + with another RDD). + """ + def __init__(self, items): + self.items = items + + +def batched(iterator, batchSize): + items = [] + count = 0 + for item in iterator: + items.append(item) + count += 1 + if count == batchSize: + yield Batch(items) + items = [] + count = [] + if items: + yield Batch(items) + + def dump_pickle(obj): return cPickle.dumps(obj, 2) @@ -38,6 +65,11 @@ def read_with_length(stream): def read_from_pickle_file(stream): try: while True: - yield load_pickle(read_with_length(stream)) + obj = load_pickle(read_with_length(stream)) + if type(obj) == Batch: # We don't care about inheritance + for item in obj.items: + yield item + else: + yield obj except EOFError: return -- cgit v1.2.3 From 1dca0c51804b9c94709ec9cc0544b8dfb7afe59f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 26 Dec 2012 18:23:06 -0800 Subject: Remove debug output from PythonPartitioner. --- core/src/main/scala/spark/api/python/PythonPartitioner.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala index ef9f808fb2..606a80d1eb 100644 --- a/core/src/main/scala/spark/api/python/PythonPartitioner.scala +++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala @@ -16,8 +16,6 @@ class PythonPartitioner(override val numPartitions: Int) extends Partitioner { else { val hashCode = { if (key.isInstanceOf[Array[Byte]]) { - System.err.println("Dumping a byte array!" + Arrays.hashCode(key.asInstanceOf[Array[Byte]]) - ) Arrays.hashCode(key.asInstanceOf[Array[Byte]]) } else -- cgit v1.2.3 From 2d98fff0651af4d527f41ba50c01f453fa049464 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 27 Dec 2012 10:13:29 -0800 Subject: Add IPython support to pyspark-shell. Suggested by / based on code from @MLnick --- pyspark/README | 3 +++ pyspark/pyspark/shell.py | 25 +++++++++++++++++-------- pyspark/requirements.txt | 1 + 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/pyspark/README b/pyspark/README index 55490e1a83..461176de7d 100644 --- a/pyspark/README +++ b/pyspark/README @@ -38,6 +38,9 @@ interacting with Java processes. It can be installed from https://github.com/bartdag/py4j; make sure to install a version that contains at least the commits through b7924aabe9. +PySpark requires the `argparse` module, which is included in Python 2.7 +and is is available for Python 2.6 through `pip` or `easy_install`. + PySpark uses the `PYTHONPATH` environment variable to search for Python classes; Py4J should be on this path, along with any libraries used by PySpark programs. `PYTHONPATH` will be automatically shipped to worker diff --git a/pyspark/pyspark/shell.py b/pyspark/pyspark/shell.py index 7ef30894cb..7012884abc 100644 --- a/pyspark/pyspark/shell.py +++ b/pyspark/pyspark/shell.py @@ -1,21 +1,30 @@ """ An interactive shell. """ +import argparse # argparse is avaiable for Python < 2.7 through easy_install. import code import sys from pyspark.context import SparkContext -def main(master='local'): +def main(master='local', ipython=False): sc = SparkContext(master, 'PySparkShell') - print "Spark context available as sc." - code.interact(local={'sc': sc}) + user_ns = {'sc' : sc} + banner = "Spark context avaiable as sc." + if ipython: + import IPython + IPython.embed(user_ns=user_ns, banner2=banner) + else: + print banner + code.interact(local=user_ns) if __name__ == '__main__': - if len(sys.argv) > 1: - master = sys.argv[1] - else: - master = 'local' - main(master) + parser = argparse.ArgumentParser() + parser.add_argument("master", help="Spark master host (default='local')", + nargs='?', type=str, default="local") + parser.add_argument("-i", "--ipython", help="Run IPython shell", + action="store_true") + args = parser.parse_args() + main(args.master, args.ipython) diff --git a/pyspark/requirements.txt b/pyspark/requirements.txt index 48fa2ab105..2464ca0074 100644 --- a/pyspark/requirements.txt +++ b/pyspark/requirements.txt @@ -4,3 +4,4 @@ # install Py4J from git once https://github.com/pypa/pip/pull/526 is merged. # git+git://github.com/bartdag/py4j.git@b7924aabe9c5e63f0a4d8bbd17019534c7ec014e +argparse -- cgit v1.2.3 From 0bc0a60d3001dd231e13057a838d4b6550e5a2b9 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 27 Dec 2012 15:37:33 -0800 Subject: Modifications to make sure LocalScheduler terminate cleanly without errors when SparkContext is shutdown, to minimize spurious exception during master failure tests. --- core/src/main/scala/spark/SparkContext.scala | 22 ++++++++++++---------- .../spark/scheduler/local/LocalScheduler.scala | 8 ++++++-- core/src/test/resources/log4j.properties | 2 +- .../src/test/scala/spark/ClosureCleanerSuite.scala | 2 ++ streaming/src/test/resources/log4j.properties | 13 ++++++++----- 5 files changed, 29 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index caa9a1794b..0c8b0078a3 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -488,17 +488,19 @@ class SparkContext( if (dagScheduler != null) { dagScheduler.stop() dagScheduler = null + taskScheduler = null + // TODO: Cache.stop()? + env.stop() + // Clean up locally linked files + clearFiles() + clearJars() + SparkEnv.set(null) + ShuffleMapTask.clearCache() + ResultTask.clearCache() + logInfo("Successfully stopped SparkContext") + } else { + logInfo("SparkContext already stopped") } - taskScheduler = null - // TODO: Cache.stop()? - env.stop() - // Clean up locally linked files - clearFiles() - clearJars() - SparkEnv.set(null) - ShuffleMapTask.clearCache() - ResultTask.clearCache() - logInfo("Successfully stopped SparkContext") } /** diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index eb20fe41b2..17a0a4b103 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -81,7 +81,10 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( ser.serialize(Accumulators.values)) logInfo("Finished task " + idInJob) - listener.taskEnded(task, Success, resultToReturn, accumUpdates) + + // If the threadpool has not already been shutdown, notify DAGScheduler + if (!Thread.currentThread().isInterrupted) + listener.taskEnded(task, Success, resultToReturn, accumUpdates) } catch { case t: Throwable => { logError("Exception in task " + idInJob, t) @@ -91,7 +94,8 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon submitTask(task, idInJob) } else { // TODO: Do something nicer here to return all the way to the user - listener.taskEnded(task, new ExceptionFailure(t), null, null) + if (!Thread.currentThread().isInterrupted) + listener.taskEnded(task, new ExceptionFailure(t), null, null) } } } diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties index 4c99e450bc..5ed388e91b 100644 --- a/core/src/test/resources/log4j.properties +++ b/core/src/test/resources/log4j.properties @@ -1,4 +1,4 @@ -# Set everything to be logged to the console +# Set everything to be logged to the file spark-tests.log log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false diff --git a/core/src/test/scala/spark/ClosureCleanerSuite.scala b/core/src/test/scala/spark/ClosureCleanerSuite.scala index 7c0334d957..dfa2de80e6 100644 --- a/core/src/test/scala/spark/ClosureCleanerSuite.scala +++ b/core/src/test/scala/spark/ClosureCleanerSuite.scala @@ -47,6 +47,8 @@ object TestObject { val nums = sc.parallelize(Array(1, 2, 3, 4)) val answer = nums.map(_ + x).reduce(_ + _) sc.stop() + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") return answer } } diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index 02fe16866e..33bafebaab 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -1,8 +1,11 @@ -# Set everything to be logged to the console -log4j.rootCategory=WARN, console -log4j.appender.console=org.apache.log4j.ConsoleAppender -log4j.appender.console.layout=org.apache.log4j.PatternLayout -log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n +# Set everything to be logged to the file streaming-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=false +log4j.appender.file.file=streaming-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN + -- cgit v1.2.3 From 85b8f2c64f0fc4be5645d8736629fc082cb3587b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 27 Dec 2012 17:55:33 -0800 Subject: Add epydoc API documentation for PySpark. --- docs/README.md | 8 +- docs/_layouts/global.html | 10 ++- docs/_plugins/copy_api_dirs.rb | 17 ++++ pyspark/epydoc.conf | 19 ++++ pyspark/pyspark/context.py | 24 +++++ pyspark/pyspark/rdd.py | 195 ++++++++++++++++++++++++++++++++++++++--- 6 files changed, 254 insertions(+), 19 deletions(-) create mode 100644 pyspark/epydoc.conf diff --git a/docs/README.md b/docs/README.md index 092153070e..887f407f18 100644 --- a/docs/README.md +++ b/docs/README.md @@ -25,10 +25,12 @@ To mark a block of code in your markdown to be syntax highlighted by jekyll duri // supported languages too. {% endhighlight %} -## Scaladoc +## API Docs (Scaladoc and Epydoc) You can build just the Spark scaladoc by running `sbt/sbt doc` from the SPARK_PROJECT_ROOT directory. -When you run `jekyll` in the docs directory, it will also copy over the scala doc for the various Spark subprojects into the docs directory (and then also into the _site directory). We use a jekyll plugin to run `sbt/sbt doc` before building the site so if you haven't run it (recently) it may take some time as it generates all of the scaladoc. +Similarly, you can build just the PySpark epydoc by running `epydoc --config epydoc.conf` from the SPARK_PROJECT_ROOT/pyspark directory. -NOTE: To skip the step of building and copying over the scaladoc when you build the docs, run `SKIP_SCALADOC=1 jekyll`. +When you run `jekyll` in the docs directory, it will also copy over the scaladoc for the various Spark subprojects into the docs directory (and then also into the _site directory). We use a jekyll plugin to run `sbt/sbt doc` before building the site so if you haven't run it (recently) it may take some time as it generates all of the scaladoc. The jekyll plugin also generates the PySpark docs using [epydoc](http://epydoc.sourceforge.net/). + +NOTE: To skip the step of building and copying over the scaladoc when you build the docs, run `SKIP_SCALADOC=1 jekyll`. Similarly, `SKIP_EPYDOC=1 jekyll` will skip PySpark API doc generation. diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 41ad5242c9..43a5fa3e1c 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -49,8 +49,14 @@
  • Java
  • - -
  • API (Scaladoc)
  • + +
  • Quick Start
  • Scala
  • Java
  • +
  • Python
  • diff --git a/docs/api.md b/docs/api.md index 43548b223c..b9c93ac5e8 100644 --- a/docs/api.md +++ b/docs/api.md @@ -8,3 +8,4 @@ Here you can find links to the Scaladoc generated for the Spark sbt subprojects. - [Core](api/core/index.html) - [Examples](api/examples/index.html) - [Bagel](api/bagel/index.html) +- [PySpark](api/pyspark/index.html) diff --git a/docs/index.md b/docs/index.md index ed9953a590..33ab58a962 100644 --- a/docs/index.md +++ b/docs/index.md @@ -7,11 +7,11 @@ title: Spark Overview TODO(andyk): Rewrite to make the Java API a first class part of the story. {% endcomment %} -Spark is a MapReduce-like cluster computing framework designed for low-latency iterative jobs and interactive use from an -interpreter. It provides clean, language-integrated APIs in Scala and Java, with a rich array of parallel operators. Spark can -run on top of the [Apache Mesos](http://incubator.apache.org/mesos/) cluster manager, +Spark is a MapReduce-like cluster computing framework designed for low-latency iterative jobs and interactive use from an interpreter. +It provides clean, language-integrated APIs in Scala, Java, and Python, with a rich array of parallel operators. +Spark can run on top of the [Apache Mesos](http://incubator.apache.org/mesos/) cluster manager, [Hadoop YARN](http://hadoop.apache.org/docs/r2.0.1-alpha/hadoop-yarn/hadoop-yarn-site/YARN.html), -Amazon EC2, or without an independent resource manager ("standalone mode"). +Amazon EC2, or without an independent resource manager ("standalone mode"). # Downloading @@ -59,6 +59,7 @@ of `project/SparkBuild.scala`, then rebuilding Spark (`sbt/sbt clean compile`). * [Quick Start](quick-start.html): a quick introduction to the Spark API; start here! * [Spark Programming Guide](scala-programming-guide.html): an overview of Spark concepts, and details on the Scala API * [Java Programming Guide](java-programming-guide.html): using Spark from Java +* [Python Programming Guide](python-programming-guide.html): using Spark from Python **Deployment guides:** @@ -72,7 +73,7 @@ of `project/SparkBuild.scala`, then rebuilding Spark (`sbt/sbt clean compile`). * [Configuration](configuration.html): customize Spark via its configuration system * [Tuning Guide](tuning.html): best practices to optimize performance and memory use -* [API Docs (Scaladoc)](api/core/index.html) +* API Docs: [Java/Scala (Scaladoc)](api/core/index.html) and [Python (Epydoc)](api/pyspark/index.html) * [Bagel](bagel-programming-guide.html): an implementation of Google's Pregel on Spark * [Contributing to Spark](contributing-to-spark.html) diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md new file mode 100644 index 0000000000..b7c747f905 --- /dev/null +++ b/docs/python-programming-guide.md @@ -0,0 +1,74 @@ +--- +layout: global +title: Python Programming Guide +--- + + +The Spark Python API (PySpark) exposes most of the Spark features available in the Scala version to Python. +To learn the basics of Spark, we recommend reading through the +[Scala programming guide](scala-programming-guide.html) first; it should be +easy to follow even if you don't know Scala. +This guide will show how to use the Spark features described there in Python. + +# Key Differences in the Python API + +There are a few key differences between the Python and Scala APIs: + +* Python is dynamically typed, so RDDs can hold objects of different types. +* PySpark does not currently support the following Spark features: + - Accumulators + - Special functions on RRDs of doubles, such as `mean` and `stdev` + - Approximate jobs / functions, such as `countApprox` and `sumApprox`. + - `lookup` + - `mapPartitionsWithSplit` + - `persist` at storage levels other than `MEMORY_ONLY` + - `sample` + - `sort` + + +# Installing and Configuring PySpark + +PySpark requires Python 2.6 or higher. +PySpark jobs are executed using a standard cPython interpreter in order to support Python modules that use C extensions. +We have not tested PySpark with Python 3 or with alternative Python interpreters, such as [PyPy](http://pypy.org/) or [Jython](http://www.jython.org/). +By default, PySpark's scripts will run programs using `python`; an alternate Python executable may be specified by setting the `PYSPARK_PYTHON` environment variable in `conf/spark-env.sh`. + +All of PySpark's library dependencies, including [Py4J](http://py4j.sourceforge.net/), are bundled with PySpark and automatically imported. + +Standalone PySpark jobs should be run using the `run-pyspark` script, which automatically configures the Java and Python environmnt using the settings in `conf/spark-env.sh`. +The script automatically adds the `pyspark` package to the `PYTHONPATH`. + + +# Interactive Use + +PySpark's `pyspark-shell` script provides a simple way to learn the API: + +{% highlight python %} +>>> words = sc.textFile("/usr/share/dict/words") +>>> words.filter(lambda w: w.startswith("spar")).take(5) +[u'spar', u'sparable', u'sparada', u'sparadrap', u'sparagrass'] +{% endhighlight %} + +# Standalone Use + +PySpark can also be used from standalone Python scripts by creating a SparkContext in the script and running the script using the `run-pyspark` script in the `pyspark` directory. +The Quick Start guide includes a [complete example](quick-start.html#a-standalone-job-in-python) of a standalone Python job. + +Code dependencies can be deployed by listing them in the `pyFiles` option in the SparkContext constructor: + +{% highlight python %} +from pyspark import SparkContext +sc = SparkContext("local", "Job Name", pyFiles=['MyFile.py', 'lib.zip', 'app.egg']) +{% endhighlight %} + +Files listed here will be added to the `PYTHONPATH` and shipped to remote worker machines. +Code dependencies can be added to an existing SparkContext using its `addPyFile()` method. + +# Where to Go from Here + +PySpark includes several sample programs using the Python API in `pyspark/examples`. +You can run them by passing the files to the `pyspark-run` script included in PySpark -- for example `./pyspark-run examples/wordcount.py`. +Each example program prints usage help when run without any arguments. + +We currently provide [API documentation](api/pyspark/index.html) for the Python API as Epydoc. +Many of the RDD method descriptions contain [doctests](http://docs.python.org/2/library/doctest.html) that provide additional usage examples. diff --git a/docs/quick-start.md b/docs/quick-start.md index defdb34836..c859c31b09 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -6,7 +6,8 @@ title: Quick Start * This will become a table of contents (this text will be scraped). {:toc} -This tutorial provides a quick introduction to using Spark. We will first introduce the API through Spark's interactive Scala shell (don't worry if you don't know Scala -- you will need much for this), then show how to write standalone jobs in Scala and Java. See the [programming guide](scala-programming-guide.html) for a fuller reference. +This tutorial provides a quick introduction to using Spark. We will first introduce the API through Spark's interactive Scala shell (don't worry if you don't know Scala -- you will need much for this), then show how to write standalone jobs in Scala, Java, and Python. +See the [programming guide](scala-programming-guide.html) for a more complete reference. To follow along with this guide, you only need to have successfully built Spark on one machine. Simply go into your Spark directory and run: @@ -230,3 +231,40 @@ Lines with a: 8422, Lines with b: 1836 {% endhighlight %} This example only runs the job locally; for a tutorial on running jobs across several machines, see the [Standalone Mode](spark-standalone.html) documentation, and consider using a distributed input source, such as HDFS. + +# A Standalone Job In Python +Now we will show how to write a standalone job using the Python API (PySpark). + +As an example, we'll create a simple Spark job, `SimpleJob.py`: + +{% highlight python %} +"""SimpleJob.py""" +from pyspark import SparkContext + +logFile = "/var/log/syslog" # Should be some file on your system +sc = SparkContext("local", "Simple job") +logData = sc.textFile(logFile).cache() + +numAs = logData.filter(lambda s: 'a' in s).count() +numBs = logData.filter(lambda s: 'b' in s).count() + +print "Lines with a: %i, lines with b: %i" % (numAs, numBs) +{% endhighlight %} + + +This job simply counts the number of lines containing 'a' and the number containing 'b' in a system log file. +Like in the Scala and Java examples, we use a SparkContext to create RDDs. +We can pass Python functions to Spark, which are automatically serialized along with any variables that they reference. +For jobs that use custom classes or third-party libraries, we can add those code dependencies to SparkContext to ensure that they will be available on remote machines; this is described in more detail in the [Python programming guide](python-programming-guide). +`SimpleJob` is simple enough that we do not need to specify any code dependencies. + +We can run this job using the `run-pyspark` script in `$SPARK_HOME/pyspark`: + +{% highlight python %} +$ cd $SPARK_HOME +$ ./pyspark/run-pyspark SimpleJob.py +... +Lines with a: 8422, Lines with b: 1836 +{% endhighlight python %} + +This example only runs the job locally; for a tutorial on running jobs across several machines, see the [Standalone Mode](spark-standalone.html) documentation, and consider using a distributed input source, such as HDFS. diff --git a/pyspark/README b/pyspark/README deleted file mode 100644 index d8d521c72c..0000000000 --- a/pyspark/README +++ /dev/null @@ -1,42 +0,0 @@ -# PySpark - -PySpark is a Python API for Spark. - -PySpark jobs are writen in Python and executed using a standard Python -interpreter; this supports modules that use Python C extensions. The -API is based on the Spark Scala API and uses regular Python functions -and lambdas to support user-defined functions. PySpark supports -interactive use through a standard Python interpreter; it can -automatically serialize closures and ship them to worker processes. - -PySpark is built on top of the Spark Java API. Data is uniformly -represented as serialized Python objects and stored in Spark Java -processes, which communicate with PySpark worker processes over pipes. - -## Features - -PySpark supports most of the Spark API, including broadcast variables. -RDDs are dynamically typed and can hold any Python object. - -PySpark does not support: - -- Special functions on RDDs of doubles -- Accumulators - -## Examples and Documentation - -The PySpark source contains docstrings and doctests that document its -API. The public classes are in `context.py` and `rdd.py`. - -The `pyspark/pyspark/examples` directory contains a few complete -examples. - -## Installing PySpark -# -To use PySpark, `SPARK_HOME` should be set to the location of the Spark -package. - -## Running PySpark - -The easiest way to run PySpark is to use the `run-pyspark` and -`pyspark-shell` scripts, which are included in the `pyspark` directory. diff --git a/pyspark/examples/kmeans.py b/pyspark/examples/kmeans.py new file mode 100644 index 0000000000..9cc366f03c --- /dev/null +++ b/pyspark/examples/kmeans.py @@ -0,0 +1,49 @@ +import sys + +from pyspark.context import SparkContext +from numpy import array, sum as np_sum + + +def parseVector(line): + return array([float(x) for x in line.split(' ')]) + + +def closestPoint(p, centers): + bestIndex = 0 + closest = float("+inf") + for i in range(len(centers)): + tempDist = np_sum((p - centers[i]) ** 2) + if tempDist < closest: + closest = tempDist + bestIndex = i + return bestIndex + + +if __name__ == "__main__": + if len(sys.argv) < 5: + print >> sys.stderr, \ + "Usage: PythonKMeans " + exit(-1) + sc = SparkContext(sys.argv[1], "PythonKMeans") + lines = sc.textFile(sys.argv[2]) + data = lines.map(parseVector).cache() + K = int(sys.argv[3]) + convergeDist = float(sys.argv[4]) + + kPoints = data.takeSample(False, K, 34) + tempDist = 1.0 + + while tempDist > convergeDist: + closest = data.map( + lambda p : (closestPoint(p, kPoints), (p, 1))) + pointStats = closest.reduceByKey( + lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2)) + newPoints = pointStats.map( + lambda (x, (y, z)): (x, y / z)).collect() + + tempDist = sum(np_sum((kPoints[x] - y) ** 2) for (x, y) in newPoints) + + for (x, y) in newPoints: + kPoints[x] = y + + print "Final centers: " + str(kPoints) diff --git a/pyspark/examples/pi.py b/pyspark/examples/pi.py new file mode 100644 index 0000000000..348bbc5dce --- /dev/null +++ b/pyspark/examples/pi.py @@ -0,0 +1,20 @@ +import sys +from random import random +from operator import add +from pyspark.context import SparkContext + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonPi []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonPi") + slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 + n = 100000 * slices + def f(_): + x = random() * 2 - 1 + y = random() * 2 - 1 + return 1 if x ** 2 + y ** 2 < 1 else 0 + count = sc.parallelize(xrange(1, n+1), slices).map(f).reduce(add) + print "Pi is roughly %f" % (4.0 * count / n) diff --git a/pyspark/examples/tc.py b/pyspark/examples/tc.py new file mode 100644 index 0000000000..9630e72b47 --- /dev/null +++ b/pyspark/examples/tc.py @@ -0,0 +1,49 @@ +import sys +from random import Random +from pyspark.context import SparkContext + +numEdges = 200 +numVertices = 100 +rand = Random(42) + + +def generateGraph(): + edges = set() + while len(edges) < numEdges: + src = rand.randrange(0, numEdges) + dst = rand.randrange(0, numEdges) + if src != dst: + edges.add((src, dst)) + return edges + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonTC []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonTC") + slices = sys.argv[2] if len(sys.argv) > 2 else 2 + tc = sc.parallelize(generateGraph(), slices).cache() + + # Linear transitive closure: each round grows paths by one edge, + # by joining the graph's edges with the already-discovered paths. + # e.g. join the path (y, z) from the TC with the edge (x, y) from + # the graph to obtain the path (x, z). + + # Because join() joins on keys, the edges are stored in reversed order. + edges = tc.map(lambda (x, y): (y, x)) + + oldCount = 0L + nextCount = tc.count() + while True: + oldCount = nextCount + # Perform the join, obtaining an RDD of (y, (z, x)) pairs, + # then project the result to obtain the new (x, z) paths. + new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a)) + tc = tc.union(new_edges).distinct().cache() + nextCount = tc.count() + if nextCount == oldCount: + break + + print "TC has %i edges" % tc.count() diff --git a/pyspark/examples/wordcount.py b/pyspark/examples/wordcount.py new file mode 100644 index 0000000000..8365c070e8 --- /dev/null +++ b/pyspark/examples/wordcount.py @@ -0,0 +1,17 @@ +import sys +from operator import add +from pyspark.context import SparkContext + +if __name__ == "__main__": + if len(sys.argv) < 3: + print >> sys.stderr, \ + "Usage: PythonWordCount " + exit(-1) + sc = SparkContext(sys.argv[1], "PythonWordCount") + lines = sc.textFile(sys.argv[2], 1) + counts = lines.flatMap(lambda x: x.split(' ')) \ + .map(lambda x: (x, 1)) \ + .reduceByKey(add) + output = counts.collect() + for (word, count) in output: + print "%s : %i" % (word, count) diff --git a/pyspark/pyspark/__init__.py b/pyspark/pyspark/__init__.py index 549c2d2711..8f8402b62b 100644 --- a/pyspark/pyspark/__init__.py +++ b/pyspark/pyspark/__init__.py @@ -1,3 +1,9 @@ import sys import os sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "pyspark/lib/py4j0.7.egg")) + + +from pyspark.context import SparkContext + + +__all__ = ["SparkContext"] diff --git a/pyspark/pyspark/examples/__init__.py b/pyspark/pyspark/examples/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/pyspark/pyspark/examples/kmeans.py b/pyspark/pyspark/examples/kmeans.py deleted file mode 100644 index 9cc366f03c..0000000000 --- a/pyspark/pyspark/examples/kmeans.py +++ /dev/null @@ -1,49 +0,0 @@ -import sys - -from pyspark.context import SparkContext -from numpy import array, sum as np_sum - - -def parseVector(line): - return array([float(x) for x in line.split(' ')]) - - -def closestPoint(p, centers): - bestIndex = 0 - closest = float("+inf") - for i in range(len(centers)): - tempDist = np_sum((p - centers[i]) ** 2) - if tempDist < closest: - closest = tempDist - bestIndex = i - return bestIndex - - -if __name__ == "__main__": - if len(sys.argv) < 5: - print >> sys.stderr, \ - "Usage: PythonKMeans " - exit(-1) - sc = SparkContext(sys.argv[1], "PythonKMeans") - lines = sc.textFile(sys.argv[2]) - data = lines.map(parseVector).cache() - K = int(sys.argv[3]) - convergeDist = float(sys.argv[4]) - - kPoints = data.takeSample(False, K, 34) - tempDist = 1.0 - - while tempDist > convergeDist: - closest = data.map( - lambda p : (closestPoint(p, kPoints), (p, 1))) - pointStats = closest.reduceByKey( - lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2)) - newPoints = pointStats.map( - lambda (x, (y, z)): (x, y / z)).collect() - - tempDist = sum(np_sum((kPoints[x] - y) ** 2) for (x, y) in newPoints) - - for (x, y) in newPoints: - kPoints[x] = y - - print "Final centers: " + str(kPoints) diff --git a/pyspark/pyspark/examples/pi.py b/pyspark/pyspark/examples/pi.py deleted file mode 100644 index 348bbc5dce..0000000000 --- a/pyspark/pyspark/examples/pi.py +++ /dev/null @@ -1,20 +0,0 @@ -import sys -from random import random -from operator import add -from pyspark.context import SparkContext - - -if __name__ == "__main__": - if len(sys.argv) == 1: - print >> sys.stderr, \ - "Usage: PythonPi []" - exit(-1) - sc = SparkContext(sys.argv[1], "PythonPi") - slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 - n = 100000 * slices - def f(_): - x = random() * 2 - 1 - y = random() * 2 - 1 - return 1 if x ** 2 + y ** 2 < 1 else 0 - count = sc.parallelize(xrange(1, n+1), slices).map(f).reduce(add) - print "Pi is roughly %f" % (4.0 * count / n) diff --git a/pyspark/pyspark/examples/tc.py b/pyspark/pyspark/examples/tc.py deleted file mode 100644 index 9630e72b47..0000000000 --- a/pyspark/pyspark/examples/tc.py +++ /dev/null @@ -1,49 +0,0 @@ -import sys -from random import Random -from pyspark.context import SparkContext - -numEdges = 200 -numVertices = 100 -rand = Random(42) - - -def generateGraph(): - edges = set() - while len(edges) < numEdges: - src = rand.randrange(0, numEdges) - dst = rand.randrange(0, numEdges) - if src != dst: - edges.add((src, dst)) - return edges - - -if __name__ == "__main__": - if len(sys.argv) == 1: - print >> sys.stderr, \ - "Usage: PythonTC []" - exit(-1) - sc = SparkContext(sys.argv[1], "PythonTC") - slices = sys.argv[2] if len(sys.argv) > 2 else 2 - tc = sc.parallelize(generateGraph(), slices).cache() - - # Linear transitive closure: each round grows paths by one edge, - # by joining the graph's edges with the already-discovered paths. - # e.g. join the path (y, z) from the TC with the edge (x, y) from - # the graph to obtain the path (x, z). - - # Because join() joins on keys, the edges are stored in reversed order. - edges = tc.map(lambda (x, y): (y, x)) - - oldCount = 0L - nextCount = tc.count() - while True: - oldCount = nextCount - # Perform the join, obtaining an RDD of (y, (z, x)) pairs, - # then project the result to obtain the new (x, z) paths. - new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a)) - tc = tc.union(new_edges).distinct().cache() - nextCount = tc.count() - if nextCount == oldCount: - break - - print "TC has %i edges" % tc.count() diff --git a/pyspark/pyspark/examples/wordcount.py b/pyspark/pyspark/examples/wordcount.py deleted file mode 100644 index 8365c070e8..0000000000 --- a/pyspark/pyspark/examples/wordcount.py +++ /dev/null @@ -1,17 +0,0 @@ -import sys -from operator import add -from pyspark.context import SparkContext - -if __name__ == "__main__": - if len(sys.argv) < 3: - print >> sys.stderr, \ - "Usage: PythonWordCount " - exit(-1) - sc = SparkContext(sys.argv[1], "PythonWordCount") - lines = sc.textFile(sys.argv[2], 1) - counts = lines.flatMap(lambda x: x.split(' ')) \ - .map(lambda x: (x, 1)) \ - .reduceByKey(add) - output = counts.collect() - for (word, count) in output: - print "%s : %i" % (word, count) -- cgit v1.2.3 From 6ee1ff2663cf1f776dd33e448548a8ddcf974dc6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 29 Dec 2012 22:22:56 +0000 Subject: Fix bug in pyspark.serializers.batch; add .gitignore. --- pyspark/.gitignore | 2 ++ pyspark/pyspark/rdd.py | 4 +++- pyspark/pyspark/serializers.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) create mode 100644 pyspark/.gitignore diff --git a/pyspark/.gitignore b/pyspark/.gitignore new file mode 100644 index 0000000000..5c56e638f9 --- /dev/null +++ b/pyspark/.gitignore @@ -0,0 +1,2 @@ +*.pyc +docs/ diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 111476d274..20f84b2dd0 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -695,7 +695,9 @@ def _test(): import doctest from pyspark.context import SparkContext globs = globals().copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest') + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) doctest.testmod(globs=globs) globs['sc'].stop() diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py index 4ed925697c..8b08f7ef0f 100644 --- a/pyspark/pyspark/serializers.py +++ b/pyspark/pyspark/serializers.py @@ -24,7 +24,7 @@ def batched(iterator, batchSize): if count == batchSize: yield Batch(items) items = [] - count = [] + count = 0 if items: yield Batch(items) -- cgit v1.2.3 From 26186e2d259f3aa2db9c8594097fd342107ce147 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 29 Dec 2012 15:34:57 -0800 Subject: Use batching in pyspark parallelize(); fix cartesian() --- pyspark/pyspark/context.py | 4 +++- pyspark/pyspark/rdd.py | 31 +++++++++++++++---------------- pyspark/pyspark/serializers.py | 23 +++++++++++++---------- 3 files changed, 31 insertions(+), 27 deletions(-) diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index b90596ecc2..6172d69dcf 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -4,7 +4,7 @@ from tempfile import NamedTemporaryFile from pyspark.broadcast import Broadcast from pyspark.java_gateway import launch_gateway -from pyspark.serializers import dump_pickle, write_with_length +from pyspark.serializers import dump_pickle, write_with_length, batched from pyspark.rdd import RDD from py4j.java_collections import ListConverter @@ -91,6 +91,8 @@ class SparkContext(object): # objects are written to a file and loaded through textFile(). tempFile = NamedTemporaryFile(delete=False) atexit.register(lambda: os.unlink(tempFile.name)) + if self.batchSize != 1: + c = batched(c, self.batchSize) for x in c: write_with_length(dump_pickle(x), tempFile) tempFile.close() diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 20f84b2dd0..203f7377d2 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -2,7 +2,7 @@ import atexit from base64 import standard_b64encode as b64enc import copy from collections import defaultdict -from itertools import chain, ifilter, imap +from itertools import chain, ifilter, imap, product import operator import os import shlex @@ -123,12 +123,6 @@ class RDD(object): >>> rdd = sc.parallelize([1, 1, 2, 3]) >>> rdd.union(rdd).collect() [1, 1, 2, 3, 1, 1, 2, 3] - - Union of batched and unbatched RDDs (internal test): - - >>> batchedRDD = sc.parallelize([Batch([1, 2, 3, 4, 5])]) - >>> rdd.union(batchedRDD).collect() - [1, 1, 2, 3, 1, 2, 3, 4, 5] """ return RDD(self._jrdd.union(other._jrdd), self.ctx) @@ -168,7 +162,18 @@ class RDD(object): >>> sorted(rdd.cartesian(rdd).collect()) [(1, 1), (1, 2), (2, 1), (2, 2)] """ - return RDD(self._jrdd.cartesian(other._jrdd), self.ctx) + # Due to batching, we can't use the Java cartesian method. + java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx) + def unpack_batches(pair): + (x, y) = pair + if type(x) == Batch or type(y) == Batch: + xs = x.items if type(x) == Batch else [x] + ys = y.items if type(y) == Batch else [y] + for pair in product(xs, ys): + yield pair + else: + yield pair + return java_cartesian.flatMap(unpack_batches) def groupBy(self, f, numSplits=None): """ @@ -293,8 +298,6 @@ class RDD(object): >>> sc.parallelize([2, 3, 4]).count() 3 - >>> sc.parallelize([Batch([2, 3, 4])]).count() - 3 """ return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum() @@ -667,12 +670,8 @@ class PipelinedRDD(RDD): if not self._bypass_serializer and self.ctx.batchSize != 1: oldfunc = self.func batchSize = self.ctx.batchSize - if batchSize == -1: # unlimited batch size - def batched_func(iterator): - yield Batch(list(oldfunc(iterator))) - else: - def batched_func(iterator): - return batched(oldfunc(iterator), batchSize) + def batched_func(iterator): + return batched(oldfunc(iterator), batchSize) func = batched_func cmds = [func, self._bypass_serializer] pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py index 8b08f7ef0f..9a5151ea00 100644 --- a/pyspark/pyspark/serializers.py +++ b/pyspark/pyspark/serializers.py @@ -16,17 +16,20 @@ class Batch(object): def batched(iterator, batchSize): - items = [] - count = 0 - for item in iterator: - items.append(item) - count += 1 - if count == batchSize: + if batchSize == -1: # unlimited batch size + yield Batch(list(iterator)) + else: + items = [] + count = 0 + for item in iterator: + items.append(item) + count += 1 + if count == batchSize: + yield Batch(items) + items = [] + count = 0 + if items: yield Batch(items) - items = [] - count = 0 - if items: - yield Batch(items) def dump_pickle(obj): -- cgit v1.2.3 From 59195c68ec37acf20d527189ed757397b273a207 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 29 Dec 2012 16:01:03 -0800 Subject: Update PySpark for compatibility with TaskContext. --- core/src/main/scala/spark/api/python/PythonRDD.scala | 13 +++++-------- pyspark/pyspark/rdd.py | 3 ++- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index f76616a4c4..dc48378fdc 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -8,10 +8,7 @@ import scala.io.Source import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import spark.broadcast.Broadcast -import spark.SparkEnv -import spark.Split -import spark.RDD -import spark.OneToOneDependency +import spark._ import spark.rdd.PipedRDD @@ -34,7 +31,7 @@ private[spark] class PythonRDD[T: ClassManifest]( override val partitioner = if (preservePartitoning) parent.partitioner else None - override def compute(split: Split): Iterator[Array[Byte]] = { + override def compute(split: Split, context: TaskContext): Iterator[Array[Byte]] = { val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME") val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py")) @@ -74,7 +71,7 @@ private[spark] class PythonRDD[T: ClassManifest]( out.println(elem) } out.flush() - for (elem <- parent.iterator(split)) { + for (elem <- parent.iterator(split, context)) { PythonRDD.writeAsPickle(elem, dOut) } dOut.flush() @@ -123,8 +120,8 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Array[Byte], Array[Byte])](prev.context) { override def splits = prev.splits override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = - prev.iterator(split).grouped(2).map { + override def compute(split: Split, context: TaskContext) = + prev.iterator(split, context).grouped(2).map { case Seq(a, b) => (a, b) case x => throw new Exception("PairwiseRDD: unexpected value: " + x) } diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 203f7377d2..21dda31c4e 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -335,9 +335,10 @@ class RDD(object): """ items = [] splits = self._jrdd.splits() + taskContext = self.ctx.jvm.spark.TaskContext(0, 0, 0) while len(items) < num and splits: split = splits.pop(0) - iterator = self._jrdd.iterator(split) + iterator = self._jrdd.iterator(split, taskContext) items.extend(self._collect_iterator_through_file(iterator)) return items[:num] -- cgit v1.2.3 From 39dd953fd88e9aa7335603ab452d9c1bed4ba67a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 29 Dec 2012 17:06:50 -0800 Subject: Add test for pyspark.RDD.saveAsTextFile(). --- pyspark/pyspark/rdd.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 21dda31c4e..cbffb6cc1f 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -351,10 +351,17 @@ class RDD(object): """ return self.take(1)[0] - # TODO: add test and fix for use with Batch def saveAsTextFile(self, path): """ Save this RDD as a text file, using string representations of elements. + + >>> tempFile = NamedTemporaryFile(delete=True) + >>> tempFile.close() + >>> sc.parallelize(range(10)).saveAsTextFile(tempFile.name) + >>> from fileinput import input + >>> from glob import glob + >>> ''.join(input(glob(tempFile.name + "/part-0000*"))) + '0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n' """ def func(iterator): return (str(x).encode("utf-8") for x in iterator) -- cgit v1.2.3 From 099898b43955d99351ec94d4a373de854bf7edf7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 29 Dec 2012 17:52:47 -0800 Subject: Port LR example to PySpark using numpy. This version of the example crashes after the first iteration with "OverflowError: math range error" because Python's math.exp() behaves differently than Scala's; see SPARK-646. --- pyspark/examples/lr.py | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100755 pyspark/examples/lr.py diff --git a/pyspark/examples/lr.py b/pyspark/examples/lr.py new file mode 100755 index 0000000000..5fca0266b8 --- /dev/null +++ b/pyspark/examples/lr.py @@ -0,0 +1,57 @@ +""" +This example requires numpy (http://www.numpy.org/) +""" +from collections import namedtuple +from math import exp +from os.path import realpath +import sys + +import numpy as np +from pyspark.context import SparkContext + + +N = 100000 # Number of data points +D = 10 # Number of dimensions +R = 0.7 # Scaling factor +ITERATIONS = 5 +np.random.seed(42) + + +DataPoint = namedtuple("DataPoint", ['x', 'y']) +from lr import DataPoint # So that DataPoint is properly serialized + + +def generateData(): + def generatePoint(i): + y = -1 if i % 2 == 0 else 1 + x = np.random.normal(size=D) + (y * R) + return DataPoint(x, y) + return [generatePoint(i) for i in range(N)] + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonLR []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonLR", pyFiles=[realpath(__file__)]) + slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 + points = sc.parallelize(generateData(), slices).cache() + + # Initialize w to a random value + w = 2 * np.random.ranf(size=D) - 1 + print "Initial w: " + str(w) + + def add(x, y): + x += y + return x + + for i in range(1, ITERATIONS + 1): + print "On iteration %i" % i + + gradient = points.map(lambda p: + (1.0 / (1.0 + exp(-p.y * np.dot(w, p.x)))) * p.y * p.x + ).reduce(add) + w -= gradient + + print "Final w: " + str(w) -- cgit v1.2.3 From 9e644402c155b5fc68794a17c36ddd19d3242f4f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 29 Dec 2012 18:31:51 -0800 Subject: Improved jekyll and scala docs. Made many classes and method private to remove them from scala docs. --- core/src/main/scala/spark/RDD.scala | 1 - docs/_plugins/copy_api_dirs.rb | 4 +- docs/streaming-programming-guide.md | 56 ++--- .../main/scala/spark/streaming/Checkpoint.scala | 5 +- .../src/main/scala/spark/streaming/DStream.scala | 249 +++++++++++++-------- .../scala/spark/streaming/FlumeInputDStream.scala | 2 +- .../src/main/scala/spark/streaming/Interval.scala | 1 + streaming/src/main/scala/spark/streaming/Job.scala | 2 + .../main/scala/spark/streaming/JobManager.scala | 1 + .../spark/streaming/NetworkInputDStream.scala | 8 +- .../spark/streaming/NetworkInputTracker.scala | 8 +- .../spark/streaming/PairDStreamFunctions.scala | 4 +- .../src/main/scala/spark/streaming/Scheduler.scala | 7 +- .../scala/spark/streaming/StreamingContext.scala | 43 ++-- .../scala/spark/streaming/examples/GrepRaw.scala | 2 +- .../streaming/examples/TopKWordCountRaw.scala | 2 +- .../spark/streaming/examples/WordCountRaw.scala | 2 +- .../examples/clickstream/PageViewStream.scala | 2 +- .../test/scala/spark/streaming/TestSuiteBase.scala | 2 +- 19 files changed, 233 insertions(+), 168 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 59e50a0b6b..1574533430 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -101,7 +101,6 @@ abstract class RDD[T: ClassManifest]( val partitioner: Option[Partitioner] = None - // ======================================================================= // Methods and fields available on all RDDs // ======================================================================= diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index e61c105449..7654511eeb 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -2,7 +2,7 @@ require 'fileutils' include FileUtils if ENV['SKIP_SCALADOC'] != '1' - projects = ["core", "examples", "repl", "bagel"] + projects = ["core", "examples", "repl", "bagel", "streaming"] puts "Moving to project root and building scaladoc." curr_dir = pwd @@ -11,7 +11,7 @@ if ENV['SKIP_SCALADOC'] != '1' puts "Running sbt/sbt doc from " + pwd + "; this may take a few minutes..." puts `sbt/sbt doc` - puts "moving back into docs dir." + puts "Moving back into docs dir." cd("docs") # Copy over the scaladoc from each project into the docs directory. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 90916545bc..7c421ac70f 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -2,33 +2,44 @@ layout: global title: Streaming (Alpha) Programming Guide --- + +{:toc} + +# Overview +A Spark Streaming application is very similar to a Spark application; it consists of a *driver program* that runs the user's `main` function and continuous executes various *parallel operations* on input streams of data. The main abstraction Spark Streaming provides is a *discretized stream* (DStream), which is a continuous sequence of RDDs (distributed collection of elements) representing a continuous stream of data. DStreams can created from live incoming data (such as data from a socket, Kafka, etc.) or it can be generated by transformation of existing DStreams using parallel operators like map, reduce, and window. The basic processing model is as follows: +(i) While a Spark Streaming driver program is running, the system receives data from various sources and and divides the data into batches. Each batch of data is treated as a RDD, that is a immutable and parallel collection of data. These input data RDDs are automatically persisted in memory (serialized by default) and replicated to two nodes for fault-tolerance. This sequence of RDDs is collectively referred to as an InputDStream. +(ii) Data received by InputDStreams are processed processed using DStream operations. Since all data is represented as RDDs and all DStream operations as RDD operations, data is automatically recovered in the event of node failures. + +This guide shows some how to start programming with DStreams. + # Initializing Spark Streaming The first thing a Spark Streaming program must do is create a `StreamingContext` object, which tells Spark how to access a cluster. A `StreamingContext` can be created from an existing `SparkContext`, or directly: {% highlight scala %} -new StreamingContext(master, jobName, [sparkHome], [jars]) -new StreamingContext(sparkContext) -{% endhighlight %} - -Once a context is instantiated, the batch interval must be set: +import spark.SparkContext +import SparkContext._ -{% highlight scala %} -context.setBatchDuration(Milliseconds(2000)) +new StreamingContext(master, frameworkName, batchDuration) +new StreamingContext(sparkContext, batchDuration) {% endhighlight %} +The `master` parameter is either the [Mesos master URL](running-on-mesos.html) (for running on a cluster)or the special "local" string (for local mode) that is used to create a Spark Context. For more information about this please refer to the [Spark programming guide](scala-programming-guide.html). -# DStreams - Discretized Streams -The primary abstraction in Spark Streaming is a DStream. A DStream represents distributed collection which is computed periodically according to a specified batch interval. DStream's can be chained together to create complex chains of transformation on streaming data. DStreams can be created by operating on existing DStreams or from an input source. To creating DStreams from an input source, use the StreamingContext: + +# Creating Input Sources - InputDStreams +The StreamingContext is used to creating InputDStreams from input sources: {% highlight scala %} -context.neworkStream(host, port) // A stream that reads from a socket -context.flumeStream(hosts, ports) // A stream populated by a Flume flow +context.neworkStream(host, port) // Creates a stream that uses a TCP socket to read data from : +context.flumeStream(host, port) // Creates a stream populated by a Flume flow {% endhighlight %} -# DStream Operators +A complete list of input sources is available in the [DStream API doc](api/streaming/index.html#spark.streaming.StreamingContext). + +## DStream Operations Once an input stream has been created, you can transform it using _stream operators_. Most of these operators return new DStreams which you can further transform. Eventually, you'll need to call an _output operator_, which forces evaluation of the stream by writing data out to an external source. -## Transformations +### Transformations DStreams support many of the transformations available on normal Spark RDD's: @@ -73,20 +84,13 @@ DStreams support many of the transformations available on normal Spark RDD's: cogroup(otherStream, [numTasks]) When called on streams of type (K, V) and (K, W), returns a stream of (K, Seq[V], Seq[W]) tuples. This operation is also called groupWith. - - -DStreams also support the following additional transformations: - -
    reduce(func) Create a new single-element stream by aggregating the elements of the stream using a function func (which takes two arguments and returns one). The function should be associative so that it can be computed correctly in parallel.
    - -## Windowed Transformations -Spark streaming features windowed computations, which allow you to report statistics over a sliding window of data. All window functions take a windowTime, which represents the width of the window and a slideTime, which represents the frequency during which the window is calculated. +Spark Streaming features windowed computations, which allow you to report statistics over a sliding window of data. All window functions take a windowTime, which represents the width of the window and a slideTime, which represents the frequency during which the window is calculated. @@ -128,7 +132,7 @@ Spark streaming features windowed computations, which allow you to report statis
    TransformationMeaning
    -## Output Operators +### Output Operators When an output operator is called, it triggers the computation of a stream. Currently the following output operators are defined: @@ -140,22 +144,22 @@ When an output operator is called, it triggers the computation of a stream. Curr - + - + - + - + diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index 770f7b0cc0..11a7232d7b 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -8,6 +8,7 @@ import org.apache.hadoop.conf.Configuration import java.io._ +private[streaming] class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) extends Logging with Serializable { val master = ssc.sc.master @@ -30,6 +31,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) /** * Convenience class to speed up the writing of graph checkpoint to file */ +private[streaming] class CheckpointWriter(checkpointDir: String) extends Logging { val file = new Path(checkpointDir, "graph") val conf = new Configuration() @@ -65,7 +67,7 @@ class CheckpointWriter(checkpointDir: String) extends Logging { } - +private[streaming] object CheckpointReader extends Logging { def read(path: String): Checkpoint = { @@ -103,6 +105,7 @@ object CheckpointReader extends Logging { } } +private[streaming] class ObjectInputStreamWithLoader(inputStream_ : InputStream, loader: ClassLoader) extends ObjectInputStream(inputStream_) { override def resolveClass(desc: ObjectStreamClass): Class[_] = { try { diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index d5048aeed7..3834b57ed3 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -21,7 +21,7 @@ import org.apache.hadoop.conf.Configuration * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous * sequence of RDDs (of the same type) representing a continuous stream of data (see [[spark.RDD]] * for more details on RDDs). DStreams can either be created from live data (such as, data from - * HDFS. Kafka or Flume) or it can be generated by transformation existing DStreams using operations + * HDFS, Kafka or Flume) or it can be generated by transformation existing DStreams using operations * such as `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming program is running, each * DStream periodically generates a RDD, either from live data or by transforming the RDD generated * by a parent DStream. @@ -38,33 +38,28 @@ import org.apache.hadoop.conf.Configuration * - A function that is used to generate an RDD after each time interval */ -case class DStreamCheckpointData(rdds: HashMap[Time, Any]) - -abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext) -extends Serializable with Logging { +abstract class DStream[T: ClassManifest] ( + @transient protected[streaming] var ssc: StreamingContext + ) extends Serializable with Logging { initLogging() - /** - * ---------------------------------------------- - * Methods that must be implemented by subclasses - * ---------------------------------------------- - */ + // ======================================================================= + // Methods that should be implemented by subclasses of DStream + // ======================================================================= - // Time interval at which the DStream generates an RDD + /** Time interval after which the DStream generates a RDD */ def slideTime: Time - // List of parent DStreams on which this DStream depends on + /** List of parent DStreams on which this DStream depends on */ def dependencies: List[DStream[_]] - // Key method that computes RDD for a valid time + /** Method that generates a RDD for the given time */ def compute (validTime: Time): Option[RDD[T]] - /** - * --------------------------------------- - * Other general fields and methods of DStream - * --------------------------------------- - */ + // ======================================================================= + // Methods and fields available on all DStreams + // ======================================================================= // RDDs generated, marked as protected[streaming] so that testsuites can access it @transient @@ -87,12 +82,15 @@ extends Serializable with Logging { // Reference to whole DStream graph protected[streaming] var graph: DStreamGraph = null - def isInitialized = (zeroTime != null) + protected[streaming] def isInitialized = (zeroTime != null) // Duration for which the DStream requires its parent DStream to remember each RDD created - def parentRememberDuration = rememberDuration + protected[streaming] def parentRememberDuration = rememberDuration + + /** Returns the StreamingContext associated with this DStream */ + def context() = ssc - // Set caching level for the RDDs created by this DStream + /** Persists the RDDs of this DStream with the given storage level */ def persist(level: StorageLevel): DStream[T] = { if (this.isInitialized) { throw new UnsupportedOperationException( @@ -102,11 +100,16 @@ extends Serializable with Logging { this } + /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ def persist(): DStream[T] = persist(StorageLevel.MEMORY_ONLY_SER) - - // Turn on the default caching level for this RDD + + /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ def cache(): DStream[T] = persist() + /** + * Enable periodic checkpointing of RDDs of this DStream + * @param interval Time interval after which generated RDD will be checkpointed + */ def checkpoint(interval: Time): DStream[T] = { if (isInitialized) { throw new UnsupportedOperationException( @@ -285,7 +288,7 @@ extends Serializable with Logging { * Generates a SparkStreaming job for the given time. This is an internal method that * should not be called directly. This default implementation creates a job * that materializes the corresponding RDD. Subclasses of DStream may override this - * (eg. PerRDDForEachDStream). + * (eg. ForEachDStream). */ protected[streaming] def generateJob(time: Time): Option[Job] = { getOrCompute(time) match { @@ -420,65 +423,96 @@ extends Serializable with Logging { generatedRDDs = new HashMap[Time, RDD[T]] () } - /** - * -------------- - * DStream operations - * -------------- - */ + // ======================================================================= + // DStream operations + // ======================================================================= + + /** Returns a new DStream by applying a function to all elements of this DStream. */ def map[U: ClassManifest](mapFunc: T => U): DStream[U] = { new MappedDStream(this, ssc.sc.clean(mapFunc)) } + /** + * Returns a new DStream by applying a function to all elements of this DStream, + * and then flattening the results + */ def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]): DStream[U] = { new FlatMappedDStream(this, ssc.sc.clean(flatMapFunc)) } + /** Returns a new DStream containing only the elements that satisfy a predicate. */ def filter(filterFunc: T => Boolean): DStream[T] = new FilteredDStream(this, filterFunc) + /** + * Return a new DStream in which each RDD is generated by applying glom() to each RDD of + * this DStream. Applying glom() to an RDD coalesces all elements within each partition into + * an array. + */ def glom(): DStream[Array[T]] = new GlommedDStream(this) - def mapPartitions[U: ClassManifest](mapPartFunc: Iterator[T] => Iterator[U]): DStream[U] = { - new MapPartitionedDStream(this, ssc.sc.clean(mapPartFunc)) + /** + * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs + * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition + * of the RDD. + */ + def mapPartitions[U: ClassManifest]( + mapPartFunc: Iterator[T] => Iterator[U], + preservePartitioning: Boolean = false + ): DStream[U] = { + new MapPartitionedDStream(this, ssc.sc.clean(mapPartFunc), preservePartitioning) } - def reduce(reduceFunc: (T, T) => T): DStream[T] = this.map(x => (null, x)).reduceByKey(reduceFunc, 1).map(_._2) + /** + * Returns a new DStream in which each RDD has a single element generated by reducing each RDD + * of this DStream. + */ + def reduce(reduceFunc: (T, T) => T): DStream[T] = + this.map(x => (null, x)).reduceByKey(reduceFunc, 1).map(_._2) + /** + * Returns a new DStream in which each RDD has a single element generated by counting each RDD + * of this DStream. + */ def count(): DStream[Int] = this.map(_ => 1).reduce(_ + _) - - def collect(): DStream[Seq[T]] = this.map(x => (null, x)).groupByKey(1).map(_._2) - - def foreach(foreachFunc: T => Unit) { - val newStream = new PerElementForEachDStream(this, ssc.sc.clean(foreachFunc)) - ssc.registerOutputStream(newStream) - newStream - } - def foreachRDD(foreachFunc: RDD[T] => Unit) { - foreachRDD((r: RDD[T], t: Time) => foreachFunc(r)) + /** + * Applies a function to each RDD in this DStream. This is an output operator, so + * this DStream will be registered as an output stream and therefore materialized. + */ + def foreach(foreachFunc: RDD[T] => Unit) { + foreach((r: RDD[T], t: Time) => foreachFunc(r)) } - def foreachRDD(foreachFunc: (RDD[T], Time) => Unit) { - val newStream = new PerRDDForEachDStream(this, ssc.sc.clean(foreachFunc)) + /** + * Applies a function to each RDD in this DStream. This is an output operator, so + * this DStream will be registered as an output stream and therefore materialized. + */ + def foreach(foreachFunc: (RDD[T], Time) => Unit) { + val newStream = new ForEachDStream(this, ssc.sc.clean(foreachFunc)) ssc.registerOutputStream(newStream) newStream } - def transformRDD[U: ClassManifest](transformFunc: RDD[T] => RDD[U]): DStream[U] = { - transformRDD((r: RDD[T], t: Time) => transformFunc(r)) + /** + * Returns a new DStream in which each RDD is generated by applying a function + * on each RDD of this DStream. + */ + def transform[U: ClassManifest](transformFunc: RDD[T] => RDD[U]): DStream[U] = { + transform((r: RDD[T], t: Time) => transformFunc(r)) } - def transformRDD[U: ClassManifest](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = { + /** + * Returns a new DStream in which each RDD is generated by applying a function + * on each RDD of this DStream. + */ + def transform[U: ClassManifest](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = { new TransformedDStream(this, ssc.sc.clean(transformFunc)) } - def toBlockingQueue() = { - val queue = new ArrayBlockingQueue[RDD[T]](10000) - this.foreachRDD(rdd => { - queue.add(rdd) - }) - queue - } - + /** + * Prints the first ten elements of each RDD generated in this DStream. This is an output + * operator, so this DStream will be registered as an output stream and there materialized. + */ def print() { def foreachFunc = (rdd: RDD[T], time: Time) => { val first11 = rdd.take(11) @@ -489,18 +523,42 @@ extends Serializable with Logging { if (first11.size > 10) println("...") println() } - val newStream = new PerRDDForEachDStream(this, ssc.sc.clean(foreachFunc)) + val newStream = new ForEachDStream(this, ssc.sc.clean(foreachFunc)) ssc.registerOutputStream(newStream) } + /** + * Return a new DStream which is computed based on windowed batches of this DStream. + * The new DStream generates RDDs with the same interval as this DStream. + * @param windowTime width of the window; must be a multiple of this DStream's interval. + * @return + */ def window(windowTime: Time): DStream[T] = window(windowTime, this.slideTime) + /** + * Return a new DStream which is computed based on windowed batches of this DStream. + * @param windowTime duration (i.e., width) of the window; + * must be a multiple of this DStream's interval + * @param slideTime sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's interval + */ def window(windowTime: Time, slideTime: Time): DStream[T] = { new WindowedDStream(this, windowTime, slideTime) } + /** + * Returns a new DStream which computed based on tumbling window on this DStream. + * This is equivalent to window(batchTime, batchTime). + * @param batchTime tumbling window duration; must be a multiple of this DStream's interval + */ def tumble(batchTime: Time): DStream[T] = window(batchTime, batchTime) + /** + * Returns a new DStream in which each RDD has a single element generated by reducing all + * elements in a window over this DStream. windowTime and slideTime are as defined in the + * window() operation. This is equivalent to window(windowTime, slideTime).reduce(reduceFunc) + */ def reduceByWindow(reduceFunc: (T, T) => T, windowTime: Time, slideTime: Time): DStream[T] = { this.window(windowTime, slideTime).reduce(reduceFunc) } @@ -516,17 +574,31 @@ extends Serializable with Logging { .map(_._2) } + /** + * Returns a new DStream in which each RDD has a single element generated by counting the number + * of elements in a window over this DStream. windowTime and slideTime are as defined in the + * window() operation. This is equivalent to window(windowTime, slideTime).count() + */ def countByWindow(windowTime: Time, slideTime: Time): DStream[Int] = { this.map(_ => 1).reduceByWindow(_ + _, _ - _, windowTime, slideTime) } + /** + * Returns a new DStream by unifying data of another DStream with this DStream. + * @param that Another DStream having the same interval (i.e., slideTime) as this DStream. + */ def union(that: DStream[T]): DStream[T] = new UnionDStream[T](Array(this, that)) - def slice(interval: Interval): Seq[RDD[T]] = { + /** + * Returns all the RDDs defined by the Interval object (both end times included) + */ + protected[streaming] def slice(interval: Interval): Seq[RDD[T]] = { slice(interval.beginTime, interval.endTime) } - // Get all the RDDs between fromTime to toTime (both included) + /** + * Returns all the RDDs between 'fromTime' to 'toTime' (both included) + */ def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = { val rdds = new ArrayBuffer[RDD[T]]() var time = toTime.floor(slideTime) @@ -540,20 +612,26 @@ extends Serializable with Logging { rdds.toSeq } + /** + * Saves each RDD in this DStream as a Sequence file of serialized objects. + */ def saveAsObjectFiles(prefix: String, suffix: String = "") { val saveFunc = (rdd: RDD[T], time: Time) => { val file = rddToFileName(prefix, suffix, time) rdd.saveAsObjectFile(file) } - this.foreachRDD(saveFunc) + this.foreach(saveFunc) } + /** + * Saves each RDD in this DStream as at text file, using string representation of elements. + */ def saveAsTextFiles(prefix: String, suffix: String = "") { val saveFunc = (rdd: RDD[T], time: Time) => { val file = rddToFileName(prefix, suffix, time) rdd.saveAsTextFile(file) } - this.foreachRDD(saveFunc) + this.foreach(saveFunc) } def register() { @@ -561,6 +639,8 @@ extends Serializable with Logging { } } +private[streaming] +case class DStreamCheckpointData(rdds: HashMap[Time, Any]) abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext) extends DStream[T](ssc_) { @@ -583,6 +663,7 @@ abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContex * TODO */ +private[streaming] class MappedDStream[T: ClassManifest, U: ClassManifest] ( parent: DStream[T], mapFunc: T => U @@ -602,6 +683,7 @@ class MappedDStream[T: ClassManifest, U: ClassManifest] ( * TODO */ +private[streaming] class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( parent: DStream[T], flatMapFunc: T => Traversable[U] @@ -621,6 +703,7 @@ class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( * TODO */ +private[streaming] class FilteredDStream[T: ClassManifest]( parent: DStream[T], filterFunc: T => Boolean @@ -640,9 +723,11 @@ class FilteredDStream[T: ClassManifest]( * TODO */ +private[streaming] class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( parent: DStream[T], - mapPartFunc: Iterator[T] => Iterator[U] + mapPartFunc: Iterator[T] => Iterator[U], + preservePartitioning: Boolean ) extends DStream[U](parent.ssc) { override def dependencies = List(parent) @@ -650,7 +735,7 @@ class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( override def slideTime: Time = parent.slideTime override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc)) + parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc, preservePartitioning)) } } @@ -659,6 +744,7 @@ class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( * TODO */ +private[streaming] class GlommedDStream[T: ClassManifest](parent: DStream[T]) extends DStream[Array[T]](parent.ssc) { @@ -676,6 +762,7 @@ class GlommedDStream[T: ClassManifest](parent: DStream[T]) * TODO */ +private[streaming] class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( parent: DStream[(K,V)], createCombiner: V => C, @@ -702,6 +789,7 @@ class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( * TODO */ +private[streaming] class MapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( parent: DStream[(K, V)], mapValueFunc: V => U @@ -720,7 +808,7 @@ class MapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( /** * TODO */ - +private[streaming] class FlatMapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( parent: DStream[(K, V)], flatMapValueFunc: V => TraversableOnce[U] @@ -779,38 +867,8 @@ class UnionDStream[T: ClassManifest](parents: Array[DStream[T]]) * TODO */ -class PerElementForEachDStream[T: ClassManifest] ( - parent: DStream[T], - foreachFunc: T => Unit - ) extends DStream[Unit](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[Unit]] = None - - override def generateJob(time: Time): Option[Job] = { - parent.getOrCompute(time) match { - case Some(rdd) => - val jobFunc = () => { - val sparkJobFunc = { - (iterator: Iterator[T]) => iterator.foreach(foreachFunc) - } - ssc.sc.runJob(rdd, sparkJobFunc) - } - Some(new Job(time, jobFunc)) - case None => None - } - } -} - - -/** - * TODO - */ - -class PerRDDForEachDStream[T: ClassManifest] ( +private[streaming] +class ForEachDStream[T: ClassManifest] ( parent: DStream[T], foreachFunc: (RDD[T], Time) => Unit ) extends DStream[Unit](parent.ssc) { @@ -838,6 +896,7 @@ class PerRDDForEachDStream[T: ClassManifest] ( * TODO */ +private[streaming] class TransformedDStream[T: ClassManifest, U: ClassManifest] ( parent: DStream[T], transformFunc: (RDD[T], Time) => RDD[U] diff --git a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala index 2959ce4540..5ac7e5b08e 100644 --- a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala @@ -79,7 +79,7 @@ class SparkFlumeEvent() extends Externalizable { } } -object SparkFlumeEvent { +private[streaming] object SparkFlumeEvent { def fromAvroFlumeEvent(in : AvroFlumeEvent) : SparkFlumeEvent = { val event = new SparkFlumeEvent event.event = in diff --git a/streaming/src/main/scala/spark/streaming/Interval.scala b/streaming/src/main/scala/spark/streaming/Interval.scala index ffb7725ac9..fa0b7ce19d 100644 --- a/streaming/src/main/scala/spark/streaming/Interval.scala +++ b/streaming/src/main/scala/spark/streaming/Interval.scala @@ -1,5 +1,6 @@ package spark.streaming +private[streaming] case class Interval(beginTime: Time, endTime: Time) { def this(beginMs: Long, endMs: Long) = this(Time(beginMs), new Time(endMs)) diff --git a/streaming/src/main/scala/spark/streaming/Job.scala b/streaming/src/main/scala/spark/streaming/Job.scala index 0bcb6fd8dc..67bd8388bc 100644 --- a/streaming/src/main/scala/spark/streaming/Job.scala +++ b/streaming/src/main/scala/spark/streaming/Job.scala @@ -2,6 +2,7 @@ package spark.streaming import java.util.concurrent.atomic.AtomicLong +private[streaming] class Job(val time: Time, func: () => _) { val id = Job.getNewId() def run(): Long = { @@ -14,6 +15,7 @@ class Job(val time: Time, func: () => _) { override def toString = "streaming job " + id + " @ " + time } +private[streaming] object Job { val id = new AtomicLong(0) diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index 9bf9251519..fda7264a27 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -5,6 +5,7 @@ import spark.SparkEnv import java.util.concurrent.Executors +private[streaming] class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { class JobHandler(ssc: StreamingContext, job: Job) extends Runnable { diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala index 4e4e9fc942..4bf13dd50c 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala @@ -40,10 +40,10 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming } -sealed trait NetworkReceiverMessage -case class StopReceiver(msg: String) extends NetworkReceiverMessage -case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage -case class ReportError(msg: String) extends NetworkReceiverMessage +private[streaming] sealed trait NetworkReceiverMessage +private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage +private[streaming] case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage +private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Serializable with Logging { diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index b421f795ee..658498dfc1 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -11,10 +11,10 @@ import akka.pattern.ask import akka.util.duration._ import akka.dispatch._ -trait NetworkInputTrackerMessage -case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage -case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage -case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage +private[streaming] sealed trait NetworkInputTrackerMessage +private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage +private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage +private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage class NetworkInputTracker( diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index 720e63bba0..f9fef14196 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -281,7 +281,7 @@ extends Serializable { val file = rddToFileName(prefix, suffix, time) rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, conf) } - self.foreachRDD(saveFunc) + self.foreach(saveFunc) } def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]]( @@ -303,7 +303,7 @@ extends Serializable { val file = rddToFileName(prefix, suffix, time) rdd.saveAsNewAPIHadoopFile(file, keyClass, valueClass, outputFormatClass, conf) } - self.foreachRDD(saveFunc) + self.foreach(saveFunc) } private def getKeyClass() = implicitly[ClassManifest[K]].erasure diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index 014021be61..fd1fa77a24 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -7,11 +7,8 @@ import spark.Logging import scala.collection.mutable.HashMap -sealed trait SchedulerMessage -case class InputGenerated(inputName: String, interval: Interval, reference: AnyRef = null) extends SchedulerMessage - -class Scheduler(ssc: StreamingContext) -extends Logging { +private[streaming] +class Scheduler(ssc: StreamingContext) extends Logging { initLogging() diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index ce47bcb2da..998fea849f 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -48,7 +48,7 @@ class StreamingContext private ( this(StreamingContext.createNewSparkContext(master, frameworkName), null, batchDuration) /** - * Recreates the StreamingContext from a checkpoint file. + * Re-creates a StreamingContext from a checkpoint file. * @param path Path either to the directory that was specified as the checkpoint directory, or * to the checkpoint file 'graph' or 'graph.bk'. */ @@ -61,7 +61,7 @@ class StreamingContext private ( "both SparkContext and checkpoint as null") } - val isCheckpointPresent = (cp_ != null) + protected[streaming] val isCheckpointPresent = (cp_ != null) val sc: SparkContext = { if (isCheckpointPresent) { @@ -71,9 +71,9 @@ class StreamingContext private ( } } - val env = SparkEnv.get + protected[streaming] val env = SparkEnv.get - val graph: DStreamGraph = { + protected[streaming] val graph: DStreamGraph = { if (isCheckpointPresent) { cp_.graph.setContext(this) cp_.graph.restoreCheckpointData() @@ -86,10 +86,10 @@ class StreamingContext private ( } } - private[streaming] val nextNetworkInputStreamId = new AtomicInteger(0) - private[streaming] var networkInputTracker: NetworkInputTracker = null + protected[streaming] val nextNetworkInputStreamId = new AtomicInteger(0) + protected[streaming] var networkInputTracker: NetworkInputTracker = null - private[streaming] var checkpointDir: String = { + protected[streaming] var checkpointDir: String = { if (isCheckpointPresent) { sc.setCheckpointDir(StreamingContext.getSparkCheckpointDir(cp_.checkpointDir), true) cp_.checkpointDir @@ -98,9 +98,9 @@ class StreamingContext private ( } } - private[streaming] var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null - private[streaming] var receiverJobThread: Thread = null - private[streaming] var scheduler: Scheduler = null + protected[streaming] var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null + protected[streaming] var receiverJobThread: Thread = null + protected[streaming] var scheduler: Scheduler = null def remember(duration: Time) { graph.remember(duration) @@ -117,11 +117,11 @@ class StreamingContext private ( } } - private[streaming] def getInitialCheckpoint(): Checkpoint = { + protected[streaming] def getInitialCheckpoint(): Checkpoint = { if (isCheckpointPresent) cp_ else null } - private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() + protected[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() /** * Create an input stream that pulls messages form a Kafka Broker. @@ -188,7 +188,7 @@ class StreamingContext private ( } /** - * This function creates a input stream that monitors a Hadoop-compatible filesystem + * Creates a input stream that monitors a Hadoop-compatible filesystem * for new files and executes the necessary processing on them. */ def fileStream[ @@ -206,7 +206,7 @@ class StreamingContext private ( } /** - * This function create a input stream from an queue of RDDs. In each batch, + * Creates a input stream from an queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue */ def queueStream[T: ClassManifest]( @@ -231,22 +231,21 @@ class StreamingContext private ( } /** - * This function registers a InputDStream as an input stream that will be - * started (InputDStream.start() called) to get the input data streams. + * Registers an input stream that will be started (InputDStream.start() called) to get the + * input data. */ def registerInputStream(inputStream: InputDStream[_]) { graph.addInputStream(inputStream) } /** - * This function registers a DStream as an output stream that will be - * computed every interval. + * Registers an output stream that will be computed every interval */ def registerOutputStream(outputStream: DStream[_]) { graph.addOutputStream(outputStream) } - def validate() { + protected def validate() { assert(graph != null, "Graph is null") graph.validate() @@ -304,7 +303,7 @@ class StreamingContext private ( object StreamingContext { - def createNewSparkContext(master: String, frameworkName: String): SparkContext = { + protected[streaming] def createNewSparkContext(master: String, frameworkName: String): SparkContext = { // Set the default cleaner delay to an hour if not already set. // This should be sufficient for even 1 second interval. @@ -318,7 +317,7 @@ object StreamingContext { new PairDStreamFunctions[K, V](stream) } - def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { + protected[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { if (prefix == null) { time.millis.toString } else if (suffix == null || suffix.length ==0) { @@ -328,7 +327,7 @@ object StreamingContext { } } - def getSparkCheckpointDir(sscCheckpointDir: String): String = { + protected[streaming] def getSparkCheckpointDir(sscCheckpointDir: String): String = { new Path(sscCheckpointDir, UUID.randomUUID.toString).toString } } diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala index 6cb2b4c042..7c4ee3b34c 100644 --- a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala @@ -26,7 +26,7 @@ object GrepRaw { val rawStreams = (1 to numStreams).map(_ => ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray val union = new UnionDStream(rawStreams) - union.filter(_.contains("Alice")).count().foreachRDD(r => + union.filter(_.contains("Alice")).count().foreach(r => println("Grep count: " + r.collect().mkString)) ssc.start() } diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala index fe4c2bf155..182dfd8a52 100644 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala @@ -38,7 +38,7 @@ object TopKWordCountRaw { val counts = union.mapPartitions(splitAndCountPartitions) val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) - partialTopKWindowedCounts.foreachRDD(rdd => { + partialTopKWindowedCounts.foreach(rdd => { val collectedCounts = rdd.collect println("Collected " + collectedCounts.size + " words from partial top words") println("Top " + k + " words are " + topK(collectedCounts.toIterator, k).mkString(",")) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala index a29c81d437..9bcd30f4d7 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -36,7 +36,7 @@ object WordCountRaw { val union = new UnionDStream(lines.toArray) val counts = union.mapPartitions(splitAndCountPartitions) val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) - windowedCounts.foreachRDD(r => println("# unique words = " + r.count())) + windowedCounts.foreach(r => println("# unique words = " + r.count())) ssc.start() } diff --git a/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala b/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala index 68be6b7893..a191321d91 100644 --- a/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala +++ b/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala @@ -72,7 +72,7 @@ object PageViewStream { case "popularUsersSeen" => // Look for users in our existing dataset and print it out if we have a match pageViews.map(view => (view.userID, 1)) - .foreachRDD((rdd, time) => rdd.join(userList) + .foreach((rdd, time) => rdd.join(userList) .map(_._2._2) .take(10) .foreach(u => println("Saw user %s at time %s".format(u, time)))) diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index 8cc2f8ccfc..a44f738957 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -35,7 +35,7 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[ * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. */ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]]) - extends PerRDDForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { + extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.collect() output += collected }) { -- cgit v1.2.3 From 7e0271b4387eaf27cd96f3057ce2465b1271a480 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 30 Dec 2012 15:19:55 -0800 Subject: Refactored a whole lot to push all DStreams into the spark.streaming.dstream package. --- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 1 + .../scala/spark/streaming/CoGroupedDStream.scala | 38 --- .../spark/streaming/ConstantInputDStream.scala | 18 -- .../src/main/scala/spark/streaming/DStream.scala | 276 +-------------------- .../main/scala/spark/streaming/DStreamGraph.scala | 1 + .../main/scala/spark/streaming/DataHandler.scala | 83 ------- .../scala/spark/streaming/FileInputDStream.scala | 109 -------- .../scala/spark/streaming/FlumeInputDStream.scala | 130 ---------- .../spark/streaming/NetworkInputDStream.scala | 156 ------------ .../spark/streaming/NetworkInputTracker.scala | 2 + .../spark/streaming/PairDStreamFunctions.scala | 7 +- .../scala/spark/streaming/QueueInputDStream.scala | 40 --- .../scala/spark/streaming/RawInputDStream.scala | 85 ------- .../spark/streaming/ReducedWindowedDStream.scala | 149 ----------- .../src/main/scala/spark/streaming/Scheduler.scala | 3 - .../scala/spark/streaming/SocketInputDStream.scala | 107 -------- .../main/scala/spark/streaming/StateDStream.scala | 84 ------- .../scala/spark/streaming/StreamingContext.scala | 13 +- .../src/main/scala/spark/streaming/Time.scala | 11 +- .../scala/spark/streaming/WindowedDStream.scala | 39 --- .../spark/streaming/dstream/CoGroupedDStream.scala | 39 +++ .../streaming/dstream/ConstantInputDStream.scala | 19 ++ .../spark/streaming/dstream/DataHandler.scala | 83 +++++++ .../spark/streaming/dstream/FileInputDStream.scala | 110 ++++++++ .../spark/streaming/dstream/FilteredDStream.scala | 21 ++ .../streaming/dstream/FlatMapValuedDStream.scala | 20 ++ .../streaming/dstream/FlatMappedDStream.scala | 20 ++ .../streaming/dstream/FlumeInputDStream.scala | 135 ++++++++++ .../spark/streaming/dstream/ForEachDStream.scala | 28 +++ .../spark/streaming/dstream/GlommedDStream.scala | 17 ++ .../spark/streaming/dstream/InputDStream.scala | 19 ++ .../streaming/dstream/KafkaInputDStream.scala | 197 +++++++++++++++ .../streaming/dstream/MapPartitionedDStream.scala | 21 ++ .../spark/streaming/dstream/MapValuedDStream.scala | 21 ++ .../spark/streaming/dstream/MappedDStream.scala | 20 ++ .../streaming/dstream/NetworkInputDStream.scala | 157 ++++++++++++ .../streaming/dstream/QueueInputDStream.scala | 41 +++ .../spark/streaming/dstream/RawInputDStream.scala | 88 +++++++ .../streaming/dstream/ReducedWindowedDStream.scala | 148 +++++++++++ .../spark/streaming/dstream/ShuffledDStream.scala | 27 ++ .../streaming/dstream/SocketInputDStream.scala | 103 ++++++++ .../spark/streaming/dstream/StateDStream.scala | 83 +++++++ .../streaming/dstream/TransformedDStream.scala | 19 ++ .../spark/streaming/dstream/UnionDStream.scala | 39 +++ .../spark/streaming/dstream/WindowedDStream.scala | 40 +++ .../scala/spark/streaming/examples/GrepRaw.scala | 2 +- .../streaming/examples/TopKWordCountRaw.scala | 2 +- .../spark/streaming/examples/WordCountRaw.scala | 2 +- .../spark/streaming/input/KafkaInputDStream.scala | 193 -------------- .../scala/spark/streaming/CheckpointSuite.scala | 2 +- .../test/scala/spark/streaming/FailureSuite.scala | 2 +- .../scala/spark/streaming/InputStreamsSuite.scala | 1 + .../test/scala/spark/streaming/TestSuiteBase.scala | 48 +++- .../spark/streaming/WindowOperationsSuite.scala | 12 +- 54 files changed, 1600 insertions(+), 1531 deletions(-) delete mode 100644 streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/DataHandler.scala delete mode 100644 streaming/src/main/scala/spark/streaming/FileInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/QueueInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/RawInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/SocketInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/StateDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/WindowedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/DataHandler.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index f40b56be64..1b219473e0 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -1,6 +1,7 @@ package spark.rdd import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Split, TaskContext} +import spark.SparkContext._ private[spark] class ShuffledRDDSplit(val idx: Int) extends Split { override val index = idx diff --git a/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala b/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala deleted file mode 100644 index 61d088eddb..0000000000 --- a/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala +++ /dev/null @@ -1,38 +0,0 @@ -package spark.streaming - -import spark.{RDD, Partitioner} -import spark.rdd.CoGroupedRDD - -class CoGroupedDStream[K : ClassManifest]( - parents: Seq[DStream[(_, _)]], - partitioner: Partitioner - ) extends DStream[(K, Seq[Seq[_]])](parents.head.ssc) { - - if (parents.length == 0) { - throw new IllegalArgumentException("Empty array of parents") - } - - if (parents.map(_.ssc).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different StreamingContexts") - } - - if (parents.map(_.slideTime).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different slide times") - } - - override def dependencies = parents.toList - - override def slideTime = parents.head.slideTime - - override def compute(validTime: Time): Option[RDD[(K, Seq[Seq[_]])]] = { - val part = partitioner - val rdds = parents.flatMap(_.getOrCompute(validTime)) - if (rdds.size > 0) { - val q = new CoGroupedRDD[K](rdds, part) - Some(q) - } else { - None - } - } - -} diff --git a/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala b/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala deleted file mode 100644 index 80150708fd..0000000000 --- a/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala +++ /dev/null @@ -1,18 +0,0 @@ -package spark.streaming - -import spark.RDD - -/** - * An input stream that always returns the same RDD on each timestep. Useful for testing. - */ -class ConstantInputDStream[T: ClassManifest](ssc_ : StreamingContext, rdd: RDD[T]) - extends InputDStream[T](ssc_) { - - override def start() {} - - override def stop() {} - - override def compute(validTime: Time): Option[RDD[T]] = { - Some(rdd) - } -} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 3834b57ed3..292ad3b9f9 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -1,17 +1,15 @@ package spark.streaming +import spark.streaming.dstream._ import StreamingContext._ import Time._ -import spark._ -import spark.SparkContext._ -import spark.rdd._ +import spark.{RDD, Logging} import spark.storage.StorageLevel import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap -import java.util.concurrent.ArrayBlockingQueue import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import org.apache.hadoop.fs.Path @@ -197,7 +195,7 @@ abstract class DStream[T: ClassManifest] ( "than " + rememberDuration.milliseconds + " milliseconds. But the Spark's metadata cleanup" + "delay is set to " + (metadataCleanerDelay / 60.0) + " minutes, which is not sufficient. Please set " + "the Java property 'spark.cleaner.delay' to more than " + - math.ceil(rememberDuration.millis.toDouble / 60000.0).toInt + " minutes." + math.ceil(rememberDuration.milliseconds.toDouble / 60000.0).toInt + " minutes." ) dependencies.foreach(_.validate()) @@ -642,271 +640,3 @@ abstract class DStream[T: ClassManifest] ( private[streaming] case class DStreamCheckpointData(rdds: HashMap[Time, Any]) -abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext) - extends DStream[T](ssc_) { - - override def dependencies = List() - - override def slideTime = { - if (ssc == null) throw new Exception("ssc is null") - if (ssc.graph.batchDuration == null) throw new Exception("batchDuration is null") - ssc.graph.batchDuration - } - - def start() - - def stop() -} - - -/** - * TODO - */ - -private[streaming] -class MappedDStream[T: ClassManifest, U: ClassManifest] ( - parent: DStream[T], - mapFunc: T => U - ) extends DStream[U](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(_.map[U](mapFunc)) - } -} - - -/** - * TODO - */ - -private[streaming] -class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( - parent: DStream[T], - flatMapFunc: T => Traversable[U] - ) extends DStream[U](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(_.flatMap(flatMapFunc)) - } -} - - -/** - * TODO - */ - -private[streaming] -class FilteredDStream[T: ClassManifest]( - parent: DStream[T], - filterFunc: T => Boolean - ) extends DStream[T](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[T]] = { - parent.getOrCompute(validTime).map(_.filter(filterFunc)) - } -} - - -/** - * TODO - */ - -private[streaming] -class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( - parent: DStream[T], - mapPartFunc: Iterator[T] => Iterator[U], - preservePartitioning: Boolean - ) extends DStream[U](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc, preservePartitioning)) - } -} - - -/** - * TODO - */ - -private[streaming] -class GlommedDStream[T: ClassManifest](parent: DStream[T]) - extends DStream[Array[T]](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[Array[T]]] = { - parent.getOrCompute(validTime).map(_.glom()) - } -} - - -/** - * TODO - */ - -private[streaming] -class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( - parent: DStream[(K,V)], - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiner: (C, C) => C, - partitioner: Partitioner - ) extends DStream [(K,C)] (parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[(K,C)]] = { - parent.getOrCompute(validTime) match { - case Some(rdd) => - Some(rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner, partitioner)) - case None => None - } - } -} - - -/** - * TODO - */ - -private[streaming] -class MapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( - parent: DStream[(K, V)], - mapValueFunc: V => U - ) extends DStream[(K, U)](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[(K, U)]] = { - parent.getOrCompute(validTime).map(_.mapValues[U](mapValueFunc)) - } -} - - -/** - * TODO - */ -private[streaming] -class FlatMapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( - parent: DStream[(K, V)], - flatMapValueFunc: V => TraversableOnce[U] - ) extends DStream[(K, U)](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[(K, U)]] = { - parent.getOrCompute(validTime).map(_.flatMapValues[U](flatMapValueFunc)) - } -} - - - -/** - * TODO - */ - -class UnionDStream[T: ClassManifest](parents: Array[DStream[T]]) - extends DStream[T](parents.head.ssc) { - - if (parents.length == 0) { - throw new IllegalArgumentException("Empty array of parents") - } - - if (parents.map(_.ssc).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different StreamingContexts") - } - - if (parents.map(_.slideTime).distinct.size > 1) { - throw new IllegalArgumentException("Array of parents have different slide times") - } - - override def dependencies = parents.toList - - override def slideTime: Time = parents.head.slideTime - - override def compute(validTime: Time): Option[RDD[T]] = { - val rdds = new ArrayBuffer[RDD[T]]() - parents.map(_.getOrCompute(validTime)).foreach(_ match { - case Some(rdd) => rdds += rdd - case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime) - }) - if (rdds.size > 0) { - Some(new UnionRDD(ssc.sc, rdds)) - } else { - None - } - } -} - - -/** - * TODO - */ - -private[streaming] -class ForEachDStream[T: ClassManifest] ( - parent: DStream[T], - foreachFunc: (RDD[T], Time) => Unit - ) extends DStream[Unit](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[Unit]] = None - - override def generateJob(time: Time): Option[Job] = { - parent.getOrCompute(time) match { - case Some(rdd) => - val jobFunc = () => { - foreachFunc(rdd, time) - } - Some(new Job(time, jobFunc)) - case None => None - } - } -} - - -/** - * TODO - */ - -private[streaming] -class TransformedDStream[T: ClassManifest, U: ClassManifest] ( - parent: DStream[T], - transformFunc: (RDD[T], Time) => RDD[U] - ) extends DStream[U](parent.ssc) { - - override def dependencies = List(parent) - - override def slideTime: Time = parent.slideTime - - override def compute(validTime: Time): Option[RDD[U]] = { - parent.getOrCompute(validTime).map(transformFunc(_, validTime)) - } -} diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index d0a9ade61d..c72429370e 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -1,5 +1,6 @@ package spark.streaming +import dstream.InputDStream import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import collection.mutable.ArrayBuffer import spark.Logging diff --git a/streaming/src/main/scala/spark/streaming/DataHandler.scala b/streaming/src/main/scala/spark/streaming/DataHandler.scala deleted file mode 100644 index 05f307a8d1..0000000000 --- a/streaming/src/main/scala/spark/streaming/DataHandler.scala +++ /dev/null @@ -1,83 +0,0 @@ -package spark.streaming - -import java.util.concurrent.ArrayBlockingQueue -import scala.collection.mutable.ArrayBuffer -import spark.Logging -import spark.streaming.util.{RecurringTimer, SystemClock} -import spark.storage.StorageLevel - - -/** - * This is a helper object that manages the data received from the socket. It divides - * the object received into small batches of 100s of milliseconds, pushes them as - * blocks into the block manager and reports the block IDs to the network input - * tracker. It starts two threads, one to periodically start a new batch and prepare - * the previous batch of as a block, the other to push the blocks into the block - * manager. - */ - class DataHandler[T](receiver: NetworkReceiver[T], storageLevel: StorageLevel) - extends Serializable with Logging { - - case class Block(id: String, iterator: Iterator[T], metadata: Any = null) - - val clock = new SystemClock() - val blockInterval = 200L - val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) - val blockStorageLevel = storageLevel - val blocksForPushing = new ArrayBlockingQueue[Block](1000) - val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } - - var currentBuffer = new ArrayBuffer[T] - - def createBlock(blockId: String, iterator: Iterator[T]) : Block = { - new Block(blockId, iterator) - } - - def start() { - blockIntervalTimer.start() - blockPushingThread.start() - logInfo("Data handler started") - } - - def stop() { - blockIntervalTimer.stop() - blockPushingThread.interrupt() - logInfo("Data handler stopped") - } - - def += (obj: T) { - currentBuffer += obj - } - - def updateCurrentBuffer(time: Long) { - try { - val newBlockBuffer = currentBuffer - currentBuffer = new ArrayBuffer[T] - if (newBlockBuffer.size > 0) { - val blockId = "input-" + receiver.streamId + "- " + (time - blockInterval) - val newBlock = createBlock(blockId, newBlockBuffer.toIterator) - blocksForPushing.add(newBlock) - } - } catch { - case ie: InterruptedException => - logInfo("Block interval timer thread interrupted") - case e: Exception => - receiver.stop() - } - } - - def keepPushingBlocks() { - logInfo("Block pushing thread started") - try { - while(true) { - val block = blocksForPushing.take() - receiver.pushBlock(block.id, block.iterator, block.metadata, storageLevel) - } - } catch { - case ie: InterruptedException => - logInfo("Block pushing thread interrupted") - case e: Exception => - receiver.stop() - } - } - } \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/FileInputDStream.scala deleted file mode 100644 index 88856364d2..0000000000 --- a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala +++ /dev/null @@ -1,109 +0,0 @@ -package spark.streaming - -import spark.RDD -import spark.rdd.UnionRDD - -import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} - -import scala.collection.mutable.HashSet - - -class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( - @transient ssc_ : StreamingContext, - directory: String, - filter: PathFilter = FileInputDStream.defaultPathFilter, - newFilesOnly: Boolean = true) - extends InputDStream[(K, V)](ssc_) { - - @transient private var path_ : Path = null - @transient private var fs_ : FileSystem = null - - var lastModTime = 0L - val lastModTimeFiles = new HashSet[String]() - - def path(): Path = { - if (path_ == null) path_ = new Path(directory) - path_ - } - - def fs(): FileSystem = { - if (fs_ == null) fs_ = path.getFileSystem(new Configuration()) - fs_ - } - - override def start() { - if (newFilesOnly) { - lastModTime = System.currentTimeMillis() - } else { - lastModTime = 0 - } - } - - override def stop() { } - - /** - * Finds the files that were modified since the last time this method was called and makes - * a union RDD out of them. Note that this maintains the list of files that were processed - * in the latest modification time in the previous call to this method. This is because the - * modification time returned by the FileStatus API seems to return times only at the - * granularity of seconds. Hence, new files may have the same modification time as the - * latest modification time in the previous call to this method and the list of files - * maintained is used to filter the one that have been processed. - */ - override def compute(validTime: Time): Option[RDD[(K, V)]] = { - // Create the filter for selecting new files - val newFilter = new PathFilter() { - var latestModTime = 0L - val latestModTimeFiles = new HashSet[String]() - - def accept(path: Path): Boolean = { - if (!filter.accept(path)) { - return false - } else { - val modTime = fs.getFileStatus(path).getModificationTime() - if (modTime < lastModTime){ - return false - } else if (modTime == lastModTime && lastModTimeFiles.contains(path.toString)) { - return false - } - if (modTime > latestModTime) { - latestModTime = modTime - latestModTimeFiles.clear() - } - latestModTimeFiles += path.toString - return true - } - } - } - - val newFiles = fs.listStatus(path, newFilter) - logInfo("New files: " + newFiles.map(_.getPath).mkString(", ")) - if (newFiles.length > 0) { - // Update the modification time and the files processed for that modification time - if (lastModTime != newFilter.latestModTime) { - lastModTime = newFilter.latestModTime - lastModTimeFiles.clear() - } - lastModTimeFiles ++= newFilter.latestModTimeFiles - } - val newRDD = new UnionRDD(ssc.sc, newFiles.map( - file => ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString))) - Some(newRDD) - } -} - -object FileInputDStream { - val defaultPathFilter = new PathFilter with Serializable { - def accept(path: Path): Boolean = { - val file = path.getName() - if (file.startsWith(".") || file.endsWith("_tmp")) { - return false - } else { - return true - } - } - } -} - diff --git a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala deleted file mode 100644 index 5ac7e5b08e..0000000000 --- a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala +++ /dev/null @@ -1,130 +0,0 @@ -package spark.streaming - -import java.io.{ObjectInput, ObjectOutput, Externalizable} -import spark.storage.StorageLevel -import org.apache.flume.source.avro.AvroSourceProtocol -import org.apache.flume.source.avro.AvroFlumeEvent -import org.apache.flume.source.avro.Status -import org.apache.avro.ipc.specific.SpecificResponder -import org.apache.avro.ipc.NettyServer -import java.net.InetSocketAddress -import collection.JavaConversions._ -import spark.Utils -import java.nio.ByteBuffer - -class FlumeInputDStream[T: ClassManifest]( - @transient ssc_ : StreamingContext, - host: String, - port: Int, - storageLevel: StorageLevel -) extends NetworkInputDStream[SparkFlumeEvent](ssc_) { - - override def createReceiver(): NetworkReceiver[SparkFlumeEvent] = { - new FlumeReceiver(id, host, port, storageLevel) - } -} - -/** - * A wrapper class for AvroFlumeEvent's with a custom serialization format. - * - * This is necessary because AvroFlumeEvent uses inner data structures - * which are not serializable. - */ -class SparkFlumeEvent() extends Externalizable { - var event : AvroFlumeEvent = new AvroFlumeEvent() - - /* De-serialize from bytes. */ - def readExternal(in: ObjectInput) { - val bodyLength = in.readInt() - val bodyBuff = new Array[Byte](bodyLength) - in.read(bodyBuff) - - val numHeaders = in.readInt() - val headers = new java.util.HashMap[CharSequence, CharSequence] - - for (i <- 0 until numHeaders) { - val keyLength = in.readInt() - val keyBuff = new Array[Byte](keyLength) - in.read(keyBuff) - val key : String = Utils.deserialize(keyBuff) - - val valLength = in.readInt() - val valBuff = new Array[Byte](valLength) - in.read(valBuff) - val value : String = Utils.deserialize(valBuff) - - headers.put(key, value) - } - - event.setBody(ByteBuffer.wrap(bodyBuff)) - event.setHeaders(headers) - } - - /* Serialize to bytes. */ - def writeExternal(out: ObjectOutput) { - val body = event.getBody.array() - out.writeInt(body.length) - out.write(body) - - val numHeaders = event.getHeaders.size() - out.writeInt(numHeaders) - for ((k, v) <- event.getHeaders) { - val keyBuff = Utils.serialize(k.toString) - out.writeInt(keyBuff.length) - out.write(keyBuff) - val valBuff = Utils.serialize(v.toString) - out.writeInt(valBuff.length) - out.write(valBuff) - } - } -} - -private[streaming] object SparkFlumeEvent { - def fromAvroFlumeEvent(in : AvroFlumeEvent) : SparkFlumeEvent = { - val event = new SparkFlumeEvent - event.event = in - event - } -} - -/** A simple server that implements Flume's Avro protocol. */ -class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { - override def append(event : AvroFlumeEvent) : Status = { - receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event) - Status.OK - } - - override def appendBatch(events : java.util.List[AvroFlumeEvent]) : Status = { - events.foreach (event => - receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event)) - Status.OK - } -} - -/** A NetworkReceiver which listens for events using the - * Flume Avro interface.*/ -class FlumeReceiver( - streamId: Int, - host: String, - port: Int, - storageLevel: StorageLevel - ) extends NetworkReceiver[SparkFlumeEvent](streamId) { - - lazy val dataHandler = new DataHandler(this, storageLevel) - - protected override def onStart() { - val responder = new SpecificResponder( - classOf[AvroSourceProtocol], new FlumeEventServer(this)); - val server = new NettyServer(responder, new InetSocketAddress(host, port)); - dataHandler.start() - server.start() - logInfo("Flume receiver started") - } - - protected override def onStop() { - dataHandler.stop() - logInfo("Flume receiver stopped") - } - - override def getLocationPreference = Some(host) -} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala deleted file mode 100644 index 4bf13dd50c..0000000000 --- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala +++ /dev/null @@ -1,156 +0,0 @@ -package spark.streaming - -import scala.collection.mutable.ArrayBuffer - -import spark.{Logging, SparkEnv, RDD} -import spark.rdd.BlockRDD -import spark.streaming.util.{RecurringTimer, SystemClock} -import spark.storage.StorageLevel - -import java.nio.ByteBuffer - -import akka.actor.{Props, Actor} -import akka.pattern.ask -import akka.dispatch.Await -import akka.util.duration._ - -abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : StreamingContext) - extends InputDStream[T](ssc_) { - - // This is an unique identifier that is used to match the network receiver with the - // corresponding network input stream. - val id = ssc.getNewNetworkStreamId() - - /** - * This method creates the receiver object that will be sent to the workers - * to receive data. This method needs to defined by any specific implementation - * of a NetworkInputDStream. - */ - def createReceiver(): NetworkReceiver[T] - - // Nothing to start or stop as both taken care of by the NetworkInputTracker. - def start() {} - - def stop() {} - - override def compute(validTime: Time): Option[RDD[T]] = { - val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) - Some(new BlockRDD[T](ssc.sc, blockIds)) - } -} - - -private[streaming] sealed trait NetworkReceiverMessage -private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage -private[streaming] case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage -private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage - -abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Serializable with Logging { - - initLogging() - - lazy protected val env = SparkEnv.get - - lazy protected val actor = env.actorSystem.actorOf( - Props(new NetworkReceiverActor()), "NetworkReceiver-" + streamId) - - lazy protected val receivingThread = Thread.currentThread() - - /** This method will be called to start receiving data. */ - protected def onStart() - - /** This method will be called to stop receiving data. */ - protected def onStop() - - /** This method conveys a placement preference (hostname) for this receiver. */ - def getLocationPreference() : Option[String] = None - - /** - * This method starts the receiver. First is accesses all the lazy members to - * materialize them. Then it calls the user-defined onStart() method to start - * other threads, etc required to receiver the data. - */ - def start() { - try { - // Access the lazy vals to materialize them - env - actor - receivingThread - - // Call user-defined onStart() - onStart() - } catch { - case ie: InterruptedException => - logInfo("Receiving thread interrupted") - //println("Receiving thread interrupted") - case e: Exception => - stopOnError(e) - } - } - - /** - * This method stops the receiver. First it interrupts the main receiving thread, - * that is, the thread that called receiver.start(). Then it calls the user-defined - * onStop() method to stop other threads and/or do cleanup. - */ - def stop() { - receivingThread.interrupt() - onStop() - //TODO: terminate the actor - } - - /** - * This method stops the receiver and reports to exception to the tracker. - * This should be called whenever an exception has happened on any thread - * of the receiver. - */ - protected def stopOnError(e: Exception) { - logError("Error receiving data", e) - stop() - actor ! ReportError(e.toString) - } - - - /** - * This method pushes a block (as iterator of values) into the block manager. - */ - def pushBlock(blockId: String, iterator: Iterator[T], metadata: Any, level: StorageLevel) { - val buffer = new ArrayBuffer[T] ++ iterator - env.blockManager.put(blockId, buffer.asInstanceOf[ArrayBuffer[Any]], level) - - actor ! ReportBlock(blockId, metadata) - } - - /** - * This method pushes a block (as bytes) into the block manager. - */ - def pushBlock(blockId: String, bytes: ByteBuffer, metadata: Any, level: StorageLevel) { - env.blockManager.putBytes(blockId, bytes, level) - actor ! ReportBlock(blockId, metadata) - } - - /** A helper actor that communicates with the NetworkInputTracker */ - private class NetworkReceiverActor extends Actor { - logInfo("Attempting to register with tracker") - val ip = System.getProperty("spark.master.host", "localhost") - val port = System.getProperty("spark.master.port", "7077").toInt - val url = "akka://spark@%s:%s/user/NetworkInputTracker".format(ip, port) - val tracker = env.actorSystem.actorFor(url) - val timeout = 5.seconds - - override def preStart() { - val future = tracker.ask(RegisterReceiver(streamId, self))(timeout) - Await.result(future, timeout) - } - - override def receive() = { - case ReportBlock(blockId, metadata) => - tracker ! AddBlocks(streamId, Array(blockId), metadata) - case ReportError(msg) => - tracker ! DeregisterReceiver(streamId, msg) - case StopReceiver(msg) => - stop() - tracker ! DeregisterReceiver(streamId, msg) - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index 658498dfc1..a6ab44271f 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -1,5 +1,7 @@ package spark.streaming +import spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver} +import spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError} import spark.Logging import spark.SparkEnv diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index f9fef14196..b0a208e67f 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -1,6 +1,9 @@ package spark.streaming import spark.streaming.StreamingContext._ +import spark.streaming.dstream.{ReducedWindowedDStream, StateDStream} +import spark.streaming.dstream.{CoGroupedDStream, ShuffledDStream} +import spark.streaming.dstream.{MapValuedDStream, FlatMapValuedDStream} import spark.{Manifests, RDD, Partitioner, HashPartitioner} import spark.SparkContext._ @@ -218,13 +221,13 @@ extends Serializable { def mapValues[U: ClassManifest](mapValuesFunc: V => U): DStream[(K, U)] = { - new MapValuesDStream[K, V, U](self, mapValuesFunc) + new MapValuedDStream[K, V, U](self, mapValuesFunc) } def flatMapValues[U: ClassManifest]( flatMapValuesFunc: V => TraversableOnce[U] ): DStream[(K, U)] = { - new FlatMapValuesDStream[K, V, U](self, flatMapValuesFunc) + new FlatMapValuedDStream[K, V, U](self, flatMapValuesFunc) } def cogroup[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (Seq[V], Seq[W]))] = { diff --git a/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala b/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala deleted file mode 100644 index bb86e51932..0000000000 --- a/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala +++ /dev/null @@ -1,40 +0,0 @@ -package spark.streaming - -import spark.RDD -import spark.rdd.UnionRDD - -import scala.collection.mutable.Queue -import scala.collection.mutable.ArrayBuffer - -class QueueInputDStream[T: ClassManifest]( - @transient ssc: StreamingContext, - val queue: Queue[RDD[T]], - oneAtATime: Boolean, - defaultRDD: RDD[T] - ) extends InputDStream[T](ssc) { - - override def start() { } - - override def stop() { } - - override def compute(validTime: Time): Option[RDD[T]] = { - val buffer = new ArrayBuffer[RDD[T]]() - if (oneAtATime && queue.size > 0) { - buffer += queue.dequeue() - } else { - buffer ++= queue - } - if (buffer.size > 0) { - if (oneAtATime) { - Some(buffer.first) - } else { - Some(new UnionRDD(ssc.sc, buffer.toSeq)) - } - } else if (defaultRDD != null) { - Some(defaultRDD) - } else { - None - } - } - -} diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala deleted file mode 100644 index 6acaa9aab1..0000000000 --- a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala +++ /dev/null @@ -1,85 +0,0 @@ -package spark.streaming - -import java.net.InetSocketAddress -import java.nio.ByteBuffer -import java.nio.channels.{ReadableByteChannel, SocketChannel} -import java.io.EOFException -import java.util.concurrent.ArrayBlockingQueue -import spark._ -import spark.storage.StorageLevel - -/** - * An input stream that reads blocks of serialized objects from a given network address. - * The blocks will be inserted directly into the block store. This is the fastest way to get - * data into Spark Streaming, though it requires the sender to batch data and serialize it - * in the format that the system is configured with. - */ -class RawInputDStream[T: ClassManifest]( - @transient ssc_ : StreamingContext, - host: String, - port: Int, - storageLevel: StorageLevel - ) extends NetworkInputDStream[T](ssc_ ) with Logging { - - def createReceiver(): NetworkReceiver[T] = { - new RawNetworkReceiver(id, host, port, storageLevel).asInstanceOf[NetworkReceiver[T]] - } -} - -class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: StorageLevel) - extends NetworkReceiver[Any](streamId) { - - var blockPushingThread: Thread = null - - override def getLocationPreference = None - - def onStart() { - // Open a socket to the target address and keep reading from it - logInfo("Connecting to " + host + ":" + port) - val channel = SocketChannel.open() - channel.configureBlocking(true) - channel.connect(new InetSocketAddress(host, port)) - logInfo("Connected to " + host + ":" + port) - - val queue = new ArrayBlockingQueue[ByteBuffer](2) - - blockPushingThread = new DaemonThread { - override def run() { - var nextBlockNumber = 0 - while (true) { - val buffer = queue.take() - val blockId = "input-" + streamId + "-" + nextBlockNumber - nextBlockNumber += 1 - pushBlock(blockId, buffer, null, storageLevel) - } - } - } - blockPushingThread.start() - - val lengthBuffer = ByteBuffer.allocate(4) - while (true) { - lengthBuffer.clear() - readFully(channel, lengthBuffer) - lengthBuffer.flip() - val length = lengthBuffer.getInt() - val dataBuffer = ByteBuffer.allocate(length) - readFully(channel, dataBuffer) - dataBuffer.flip() - logInfo("Read a block with " + length + " bytes") - queue.put(dataBuffer) - } - } - - def onStop() { - if (blockPushingThread != null) blockPushingThread.interrupt() - } - - /** Read a buffer fully from a given Channel */ - private def readFully(channel: ReadableByteChannel, dest: ByteBuffer) { - while (dest.position < dest.limit) { - if (channel.read(dest) == -1) { - throw new EOFException("End of channel") - } - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala deleted file mode 100644 index f63a9e0011..0000000000 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ /dev/null @@ -1,149 +0,0 @@ -package spark.streaming - -import spark.streaming.StreamingContext._ - -import spark.RDD -import spark.rdd.UnionRDD -import spark.rdd.CoGroupedRDD -import spark.Partitioner -import spark.SparkContext._ -import spark.storage.StorageLevel - -import scala.collection.mutable.ArrayBuffer -import collection.SeqProxy - -class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( - parent: DStream[(K, V)], - reduceFunc: (V, V) => V, - invReduceFunc: (V, V) => V, - _windowTime: Time, - _slideTime: Time, - partitioner: Partitioner - ) extends DStream[(K,V)](parent.ssc) { - - assert(_windowTime.isMultipleOf(parent.slideTime), - "The window duration of ReducedWindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" - ) - - assert(_slideTime.isMultipleOf(parent.slideTime), - "The slide duration of ReducedWindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" - ) - - // Reduce each batch of data using reduceByKey which will be further reduced by window - // by ReducedWindowedDStream - val reducedStream = parent.reduceByKey(reduceFunc, partitioner) - - // Persist RDDs to memory by default as these RDDs are going to be reused. - super.persist(StorageLevel.MEMORY_ONLY_SER) - reducedStream.persist(StorageLevel.MEMORY_ONLY_SER) - - def windowTime: Time = _windowTime - - override def dependencies = List(reducedStream) - - override def slideTime: Time = _slideTime - - override val mustCheckpoint = true - - override def parentRememberDuration: Time = rememberDuration + windowTime - - override def persist(storageLevel: StorageLevel): DStream[(K,V)] = { - super.persist(storageLevel) - reducedStream.persist(storageLevel) - this - } - - override def checkpoint(interval: Time): DStream[(K, V)] = { - super.checkpoint(interval) - //reducedStream.checkpoint(interval) - this - } - - override def compute(validTime: Time): Option[RDD[(K, V)]] = { - val reduceF = reduceFunc - val invReduceF = invReduceFunc - - val currentTime = validTime - val currentWindow = Interval(currentTime - windowTime + parent.slideTime, currentTime) - val previousWindow = currentWindow - slideTime - - logDebug("Window time = " + windowTime) - logDebug("Slide time = " + slideTime) - logDebug("ZeroTime = " + zeroTime) - logDebug("Current window = " + currentWindow) - logDebug("Previous window = " + previousWindow) - - // _____________________________ - // | previous window _________|___________________ - // |___________________| current window | --------------> Time - // |_____________________________| - // - // |________ _________| |________ _________| - // | | - // V V - // old RDDs new RDDs - // - - // Get the RDDs of the reduced values in "old time steps" - val oldRDDs = reducedStream.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideTime) - logDebug("# old RDDs = " + oldRDDs.size) - - // Get the RDDs of the reduced values in "new time steps" - val newRDDs = reducedStream.slice(previousWindow.endTime + parent.slideTime, currentWindow.endTime) - logDebug("# new RDDs = " + newRDDs.size) - - // Get the RDD of the reduced value of the previous window - val previousWindowRDD = getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]())) - - // Make the list of RDDs that needs to cogrouped together for reducing their reduced values - val allRDDs = new ArrayBuffer[RDD[(K, V)]]() += previousWindowRDD ++= oldRDDs ++= newRDDs - - // Cogroup the reduced RDDs and merge the reduced values - val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(_, _)]]], partitioner) - //val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _ - - val numOldValues = oldRDDs.size - val numNewValues = newRDDs.size - - val mergeValues = (seqOfValues: Seq[Seq[V]]) => { - if (seqOfValues.size != 1 + numOldValues + numNewValues) { - throw new Exception("Unexpected number of sequences of reduced values") - } - // Getting reduced values "old time steps" that will be removed from current window - val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head) - // Getting reduced values "new time steps" - val newValues = (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) - if (seqOfValues(0).isEmpty) { - // If previous window's reduce value does not exist, then at least new values should exist - if (newValues.isEmpty) { - throw new Exception("Neither previous window has value for key, nor new values found. " + - "Are you sure your key class hashes consistently?") - } - // Reduce the new values - newValues.reduce(reduceF) // return - } else { - // Get the previous window's reduced value - var tempValue = seqOfValues(0).head - // If old values exists, then inverse reduce then from previous value - if (!oldValues.isEmpty) { - tempValue = invReduceF(tempValue, oldValues.reduce(reduceF)) - } - // If new values exists, then reduce them with previous value - if (!newValues.isEmpty) { - tempValue = reduceF(tempValue, newValues.reduce(reduceF)) - } - tempValue // return - } - } - - val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValues) - - Some(mergedValuesRDD) - } - - -} - - diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index fd1fa77a24..aeb7c3eb0e 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -4,9 +4,6 @@ import util.{ManualClock, RecurringTimer, Clock} import spark.SparkEnv import spark.Logging -import scala.collection.mutable.HashMap - - private[streaming] class Scheduler(ssc: StreamingContext) extends Logging { diff --git a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala deleted file mode 100644 index a9e37c0ff0..0000000000 --- a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala +++ /dev/null @@ -1,107 +0,0 @@ -package spark.streaming - -import spark.streaming.util.{RecurringTimer, SystemClock} -import spark.storage.StorageLevel - -import java.io._ -import java.net.Socket -import java.util.concurrent.ArrayBlockingQueue - -import scala.collection.mutable.ArrayBuffer -import scala.Serializable - -class SocketInputDStream[T: ClassManifest]( - @transient ssc_ : StreamingContext, - host: String, - port: Int, - bytesToObjects: InputStream => Iterator[T], - storageLevel: StorageLevel - ) extends NetworkInputDStream[T](ssc_) { - - def createReceiver(): NetworkReceiver[T] = { - new SocketReceiver(id, host, port, bytesToObjects, storageLevel) - } -} - - -class SocketReceiver[T: ClassManifest]( - streamId: Int, - host: String, - port: Int, - bytesToObjects: InputStream => Iterator[T], - storageLevel: StorageLevel - ) extends NetworkReceiver[T](streamId) { - - lazy protected val dataHandler = new DataHandler(this, storageLevel) - - override def getLocationPreference = None - - protected def onStart() { - logInfo("Connecting to " + host + ":" + port) - val socket = new Socket(host, port) - logInfo("Connected to " + host + ":" + port) - dataHandler.start() - val iterator = bytesToObjects(socket.getInputStream()) - while(iterator.hasNext) { - val obj = iterator.next - dataHandler += obj - } - } - - protected def onStop() { - dataHandler.stop() - } - -} - - -object SocketReceiver { - - /** - * This methods translates the data from an inputstream (say, from a socket) - * to '\n' delimited strings and returns an iterator to access the strings. - */ - def bytesToLines(inputStream: InputStream): Iterator[String] = { - val dataInputStream = new BufferedReader(new InputStreamReader(inputStream, "UTF-8")) - - val iterator = new Iterator[String] { - var gotNext = false - var finished = false - var nextValue: String = null - - private def getNext() { - try { - nextValue = dataInputStream.readLine() - if (nextValue == null) { - finished = true - } - } - gotNext = true - } - - override def hasNext: Boolean = { - if (!finished) { - if (!gotNext) { - getNext() - if (finished) { - dataInputStream.close() - } - } - } - !finished - } - - override def next(): String = { - if (finished) { - throw new NoSuchElementException("End of stream") - } - if (!gotNext) { - getNext() - } - gotNext = false - nextValue - } - } - iterator - } -} diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala deleted file mode 100644 index b7e4c1c30c..0000000000 --- a/streaming/src/main/scala/spark/streaming/StateDStream.scala +++ /dev/null @@ -1,84 +0,0 @@ -package spark.streaming - -import spark.RDD -import spark.rdd.BlockRDD -import spark.Partitioner -import spark.rdd.MapPartitionsRDD -import spark.SparkContext._ -import spark.storage.StorageLevel - -class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( - parent: DStream[(K, V)], - updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], - partitioner: Partitioner, - preservePartitioning: Boolean - ) extends DStream[(K, S)](parent.ssc) { - - super.persist(StorageLevel.MEMORY_ONLY_SER) - - override def dependencies = List(parent) - - override def slideTime = parent.slideTime - - override val mustCheckpoint = true - - override def compute(validTime: Time): Option[RDD[(K, S)]] = { - - // Try to get the previous state RDD - getOrCompute(validTime - slideTime) match { - - case Some(prevStateRDD) => { // If previous state RDD exists - - // Try to get the parent RDD - parent.getOrCompute(validTime) match { - case Some(parentRDD) => { // If parent RDD exists, then compute as usual - - // Define the function for the mapPartition operation on cogrouped RDD; - // first map the cogrouped tuple to tuples of required type, - // and then apply the update function - val updateFuncLocal = updateFunc - val finalFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { - val i = iterator.map(t => { - (t._1, t._2._1, t._2._2.headOption) - }) - updateFuncLocal(i) - } - val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) - val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning) - //logDebug("Generating state RDD for time " + validTime) - return Some(stateRDD) - } - case None => { // If parent RDD does not exist, then return old state RDD - return Some(prevStateRDD) - } - } - } - - case None => { // If previous session RDD does not exist (first input data) - - // Try to get the parent RDD - parent.getOrCompute(validTime) match { - case Some(parentRDD) => { // If parent RDD exists, then compute as usual - - // Define the function for the mapPartition operation on grouped RDD; - // first map the grouped tuple to tuples of required type, - // and then apply the update function - val updateFuncLocal = updateFunc - val finalFunc = (iterator: Iterator[(K, Seq[V])]) => { - updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, None))) - } - - val groupedRDD = parentRDD.groupByKey(partitioner) - val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning) - //logDebug("Generating state RDD for time " + validTime + " (first)") - return Some(sessionRDD) - } - case None => { // If parent RDD does not exist, then nothing to do! - //logDebug("Not generating state RDD (no previous state, no parent)") - return None - } - } - } - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 998fea849f..ef73049a81 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -1,10 +1,10 @@ package spark.streaming -import spark.RDD -import spark.Logging -import spark.SparkEnv -import spark.SparkContext +import spark.streaming.dstream._ + +import spark.{RDD, Logging, SparkEnv, SparkContext} import spark.storage.StorageLevel +import spark.util.MetadataCleaner import scala.collection.mutable.Queue @@ -18,7 +18,6 @@ import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.flume.source.avro.AvroFlumeEvent import org.apache.hadoop.fs.Path import java.util.UUID -import spark.util.MetadataCleaner /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -126,7 +125,7 @@ class StreamingContext private ( /** * Create an input stream that pulls messages form a Kafka Broker. * - * @param host Zookeper hostname. + * @param hostname Zookeper hostname. * @param port Zookeper port. * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed @@ -319,7 +318,7 @@ object StreamingContext { protected[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { if (prefix == null) { - time.millis.toString + time.milliseconds.toString } else if (suffix == null || suffix.length ==0) { prefix + "-" + time.milliseconds } else { diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index 480d292d7c..2976e5e87b 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -1,6 +1,11 @@ package spark.streaming -case class Time(millis: Long) { +/** + * This class is simple wrapper class that represents time in UTC. + * @param millis Time in UTC long + */ + +case class Time(private val millis: Long) { def < (that: Time): Boolean = (this.millis < that.millis) @@ -15,7 +20,9 @@ case class Time(millis: Long) { def - (that: Time): Time = Time(millis - that.millis) def * (times: Int): Time = Time(millis * times) - + + def / (that: Time): Long = millis / that.millis + def floor(that: Time): Time = { val t = that.millis val m = math.floor(this.millis / t).toLong diff --git a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/WindowedDStream.scala deleted file mode 100644 index e4d2a634f5..0000000000 --- a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala +++ /dev/null @@ -1,39 +0,0 @@ -package spark.streaming - -import spark.RDD -import spark.rdd.UnionRDD -import spark.storage.StorageLevel - - -class WindowedDStream[T: ClassManifest]( - parent: DStream[T], - _windowTime: Time, - _slideTime: Time) - extends DStream[T](parent.ssc) { - - if (!_windowTime.isMultipleOf(parent.slideTime)) - throw new Exception("The window duration of WindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") - - if (!_slideTime.isMultipleOf(parent.slideTime)) - throw new Exception("The slide duration of WindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") - - parent.persist(StorageLevel.MEMORY_ONLY_SER) - - def windowTime: Time = _windowTime - - override def dependencies = List(parent) - - override def slideTime: Time = _slideTime - - override def parentRememberDuration: Time = rememberDuration + windowTime - - override def compute(validTime: Time): Option[RDD[T]] = { - val currentWindow = Interval(validTime - windowTime + parent.slideTime, validTime) - Some(new UnionRDD(ssc.sc, parent.slice(currentWindow))) - } -} - - - diff --git a/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala new file mode 100644 index 0000000000..2e427dadf7 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala @@ -0,0 +1,39 @@ +package spark.streaming.dstream + +import spark.{RDD, Partitioner} +import spark.rdd.CoGroupedRDD +import spark.streaming.{Time, DStream} + +class CoGroupedDStream[K : ClassManifest]( + parents: Seq[DStream[(_, _)]], + partitioner: Partitioner + ) extends DStream[(K, Seq[Seq[_]])](parents.head.ssc) { + + if (parents.length == 0) { + throw new IllegalArgumentException("Empty array of parents") + } + + if (parents.map(_.ssc).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different StreamingContexts") + } + + if (parents.map(_.slideTime).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different slide times") + } + + override def dependencies = parents.toList + + override def slideTime = parents.head.slideTime + + override def compute(validTime: Time): Option[RDD[(K, Seq[Seq[_]])]] = { + val part = partitioner + val rdds = parents.flatMap(_.getOrCompute(validTime)) + if (rdds.size > 0) { + val q = new CoGroupedRDD[K](rdds, part) + Some(q) + } else { + None + } + } + +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala new file mode 100644 index 0000000000..41c3af4694 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala @@ -0,0 +1,19 @@ +package spark.streaming.dstream + +import spark.RDD +import spark.streaming.{Time, StreamingContext} + +/** + * An input stream that always returns the same RDD on each timestep. Useful for testing. + */ +class ConstantInputDStream[T: ClassManifest](ssc_ : StreamingContext, rdd: RDD[T]) + extends InputDStream[T](ssc_) { + + override def start() {} + + override def stop() {} + + override def compute(validTime: Time): Option[RDD[T]] = { + Some(rdd) + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/dstream/DataHandler.scala b/streaming/src/main/scala/spark/streaming/dstream/DataHandler.scala new file mode 100644 index 0000000000..d737ba1ecc --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/DataHandler.scala @@ -0,0 +1,83 @@ +package spark.streaming.dstream + +import java.util.concurrent.ArrayBlockingQueue +import scala.collection.mutable.ArrayBuffer +import spark.Logging +import spark.streaming.util.{RecurringTimer, SystemClock} +import spark.storage.StorageLevel + + +/** + * This is a helper object that manages the data received from the socket. It divides + * the object received into small batches of 100s of milliseconds, pushes them as + * blocks into the block manager and reports the block IDs to the network input + * tracker. It starts two threads, one to periodically start a new batch and prepare + * the previous batch of as a block, the other to push the blocks into the block + * manager. + */ + class DataHandler[T](receiver: NetworkReceiver[T], storageLevel: StorageLevel) + extends Serializable with Logging { + + case class Block(id: String, iterator: Iterator[T], metadata: Any = null) + + val clock = new SystemClock() + val blockInterval = 200L + val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) + val blockStorageLevel = storageLevel + val blocksForPushing = new ArrayBlockingQueue[Block](1000) + val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } + + var currentBuffer = new ArrayBuffer[T] + + def createBlock(blockId: String, iterator: Iterator[T]) : Block = { + new Block(blockId, iterator) + } + + def start() { + blockIntervalTimer.start() + blockPushingThread.start() + logInfo("Data handler started") + } + + def stop() { + blockIntervalTimer.stop() + blockPushingThread.interrupt() + logInfo("Data handler stopped") + } + + def += (obj: T) { + currentBuffer += obj + } + + def updateCurrentBuffer(time: Long) { + try { + val newBlockBuffer = currentBuffer + currentBuffer = new ArrayBuffer[T] + if (newBlockBuffer.size > 0) { + val blockId = "input-" + receiver.streamId + "- " + (time - blockInterval) + val newBlock = createBlock(blockId, newBlockBuffer.toIterator) + blocksForPushing.add(newBlock) + } + } catch { + case ie: InterruptedException => + logInfo("Block interval timer thread interrupted") + case e: Exception => + receiver.stop() + } + } + + def keepPushingBlocks() { + logInfo("Block pushing thread started") + try { + while(true) { + val block = blocksForPushing.take() + receiver.pushBlock(block.id, block.iterator, block.metadata, storageLevel) + } + } catch { + case ie: InterruptedException => + logInfo("Block pushing thread interrupted") + case e: Exception => + receiver.stop() + } + } + } \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala new file mode 100644 index 0000000000..8cdaff467b --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala @@ -0,0 +1,110 @@ +package spark.streaming.dstream + +import spark.RDD +import spark.rdd.UnionRDD +import spark.streaming.{StreamingContext, Time} + +import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} + +import scala.collection.mutable.HashSet + + +class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( + @transient ssc_ : StreamingContext, + directory: String, + filter: PathFilter = FileInputDStream.defaultPathFilter, + newFilesOnly: Boolean = true) + extends InputDStream[(K, V)](ssc_) { + + @transient private var path_ : Path = null + @transient private var fs_ : FileSystem = null + + var lastModTime = 0L + val lastModTimeFiles = new HashSet[String]() + + def path(): Path = { + if (path_ == null) path_ = new Path(directory) + path_ + } + + def fs(): FileSystem = { + if (fs_ == null) fs_ = path.getFileSystem(new Configuration()) + fs_ + } + + override def start() { + if (newFilesOnly) { + lastModTime = System.currentTimeMillis() + } else { + lastModTime = 0 + } + } + + override def stop() { } + + /** + * Finds the files that were modified since the last time this method was called and makes + * a union RDD out of them. Note that this maintains the list of files that were processed + * in the latest modification time in the previous call to this method. This is because the + * modification time returned by the FileStatus API seems to return times only at the + * granularity of seconds. Hence, new files may have the same modification time as the + * latest modification time in the previous call to this method and the list of files + * maintained is used to filter the one that have been processed. + */ + override def compute(validTime: Time): Option[RDD[(K, V)]] = { + // Create the filter for selecting new files + val newFilter = new PathFilter() { + var latestModTime = 0L + val latestModTimeFiles = new HashSet[String]() + + def accept(path: Path): Boolean = { + if (!filter.accept(path)) { + return false + } else { + val modTime = fs.getFileStatus(path).getModificationTime() + if (modTime < lastModTime){ + return false + } else if (modTime == lastModTime && lastModTimeFiles.contains(path.toString)) { + return false + } + if (modTime > latestModTime) { + latestModTime = modTime + latestModTimeFiles.clear() + } + latestModTimeFiles += path.toString + return true + } + } + } + + val newFiles = fs.listStatus(path, newFilter) + logInfo("New files: " + newFiles.map(_.getPath).mkString(", ")) + if (newFiles.length > 0) { + // Update the modification time and the files processed for that modification time + if (lastModTime != newFilter.latestModTime) { + lastModTime = newFilter.latestModTime + lastModTimeFiles.clear() + } + lastModTimeFiles ++= newFilter.latestModTimeFiles + } + val newRDD = new UnionRDD(ssc.sc, newFiles.map( + file => ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString))) + Some(newRDD) + } +} + +object FileInputDStream { + val defaultPathFilter = new PathFilter with Serializable { + def accept(path: Path): Boolean = { + val file = path.getName() + if (file.startsWith(".") || file.endsWith("_tmp")) { + return false + } else { + return true + } + } + } +} + diff --git a/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala new file mode 100644 index 0000000000..1cbb4d536e --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala @@ -0,0 +1,21 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD + +private[streaming] +class FilteredDStream[T: ClassManifest]( + parent: DStream[T], + filterFunc: T => Boolean + ) extends DStream[T](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[T]] = { + parent.getOrCompute(validTime).map(_.filter(filterFunc)) + } +} + + diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala new file mode 100644 index 0000000000..11ed8cf317 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala @@ -0,0 +1,20 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD +import spark.SparkContext._ + +private[streaming] +class FlatMapValuedDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( + parent: DStream[(K, V)], + flatMapValueFunc: V => TraversableOnce[U] + ) extends DStream[(K, U)](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[(K, U)]] = { + parent.getOrCompute(validTime).map(_.flatMapValues[U](flatMapValueFunc)) + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala new file mode 100644 index 0000000000..a13b4c9ff9 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala @@ -0,0 +1,20 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD + +private[streaming] +class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( + parent: DStream[T], + flatMapFunc: T => Traversable[U] + ) extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.flatMap(flatMapFunc)) + } +} + diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala new file mode 100644 index 0000000000..7e988cadf4 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala @@ -0,0 +1,135 @@ +package spark.streaming.dstream + +import spark.streaming.StreamingContext + +import spark.Utils +import spark.storage.StorageLevel + +import org.apache.flume.source.avro.AvroSourceProtocol +import org.apache.flume.source.avro.AvroFlumeEvent +import org.apache.flume.source.avro.Status +import org.apache.avro.ipc.specific.SpecificResponder +import org.apache.avro.ipc.NettyServer + +import scala.collection.JavaConversions._ + +import java.net.InetSocketAddress +import java.io.{ObjectInput, ObjectOutput, Externalizable} +import java.nio.ByteBuffer + +class FlumeInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + host: String, + port: Int, + storageLevel: StorageLevel +) extends NetworkInputDStream[SparkFlumeEvent](ssc_) { + + override def createReceiver(): NetworkReceiver[SparkFlumeEvent] = { + new FlumeReceiver(id, host, port, storageLevel) + } +} + +/** + * A wrapper class for AvroFlumeEvent's with a custom serialization format. + * + * This is necessary because AvroFlumeEvent uses inner data structures + * which are not serializable. + */ +class SparkFlumeEvent() extends Externalizable { + var event : AvroFlumeEvent = new AvroFlumeEvent() + + /* De-serialize from bytes. */ + def readExternal(in: ObjectInput) { + val bodyLength = in.readInt() + val bodyBuff = new Array[Byte](bodyLength) + in.read(bodyBuff) + + val numHeaders = in.readInt() + val headers = new java.util.HashMap[CharSequence, CharSequence] + + for (i <- 0 until numHeaders) { + val keyLength = in.readInt() + val keyBuff = new Array[Byte](keyLength) + in.read(keyBuff) + val key : String = Utils.deserialize(keyBuff) + + val valLength = in.readInt() + val valBuff = new Array[Byte](valLength) + in.read(valBuff) + val value : String = Utils.deserialize(valBuff) + + headers.put(key, value) + } + + event.setBody(ByteBuffer.wrap(bodyBuff)) + event.setHeaders(headers) + } + + /* Serialize to bytes. */ + def writeExternal(out: ObjectOutput) { + val body = event.getBody.array() + out.writeInt(body.length) + out.write(body) + + val numHeaders = event.getHeaders.size() + out.writeInt(numHeaders) + for ((k, v) <- event.getHeaders) { + val keyBuff = Utils.serialize(k.toString) + out.writeInt(keyBuff.length) + out.write(keyBuff) + val valBuff = Utils.serialize(v.toString) + out.writeInt(valBuff.length) + out.write(valBuff) + } + } +} + +private[streaming] object SparkFlumeEvent { + def fromAvroFlumeEvent(in : AvroFlumeEvent) : SparkFlumeEvent = { + val event = new SparkFlumeEvent + event.event = in + event + } +} + +/** A simple server that implements Flume's Avro protocol. */ +class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { + override def append(event : AvroFlumeEvent) : Status = { + receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event) + Status.OK + } + + override def appendBatch(events : java.util.List[AvroFlumeEvent]) : Status = { + events.foreach (event => + receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event)) + Status.OK + } +} + +/** A NetworkReceiver which listens for events using the + * Flume Avro interface.*/ +class FlumeReceiver( + streamId: Int, + host: String, + port: Int, + storageLevel: StorageLevel + ) extends NetworkReceiver[SparkFlumeEvent](streamId) { + + lazy val dataHandler = new DataHandler(this, storageLevel) + + protected override def onStart() { + val responder = new SpecificResponder( + classOf[AvroSourceProtocol], new FlumeEventServer(this)); + val server = new NettyServer(responder, new InetSocketAddress(host, port)); + dataHandler.start() + server.start() + logInfo("Flume receiver started") + } + + protected override def onStop() { + dataHandler.stop() + logInfo("Flume receiver stopped") + } + + override def getLocationPreference = Some(host) +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala new file mode 100644 index 0000000000..41c629a225 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala @@ -0,0 +1,28 @@ +package spark.streaming.dstream + +import spark.RDD +import spark.streaming.{DStream, Job, Time} + +private[streaming] +class ForEachDStream[T: ClassManifest] ( + parent: DStream[T], + foreachFunc: (RDD[T], Time) => Unit + ) extends DStream[Unit](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[Unit]] = None + + override def generateJob(time: Time): Option[Job] = { + parent.getOrCompute(time) match { + case Some(rdd) => + val jobFunc = () => { + foreachFunc(rdd, time) + } + Some(new Job(time, jobFunc)) + case None => None + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala new file mode 100644 index 0000000000..92ea503cae --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala @@ -0,0 +1,17 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD + +private[streaming] +class GlommedDStream[T: ClassManifest](parent: DStream[T]) + extends DStream[Array[T]](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[Array[T]]] = { + parent.getOrCompute(validTime).map(_.glom()) + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala new file mode 100644 index 0000000000..4959c66b06 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala @@ -0,0 +1,19 @@ +package spark.streaming.dstream + +import spark.streaming.{StreamingContext, DStream} + +abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext) + extends DStream[T](ssc_) { + + override def dependencies = List() + + override def slideTime = { + if (ssc == null) throw new Exception("ssc is null") + if (ssc.graph.batchDuration == null) throw new Exception("batchDuration is null") + ssc.graph.batchDuration + } + + def start() + + def stop() +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala new file mode 100644 index 0000000000..a46721af2f --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -0,0 +1,197 @@ +package spark.streaming.dstream + +import spark.Logging +import spark.storage.StorageLevel +import spark.streaming.{Time, DStreamCheckpointData, StreamingContext} + +import java.util.Properties +import java.util.concurrent.Executors + +import kafka.consumer._ +import kafka.message.{Message, MessageSet, MessageAndMetadata} +import kafka.serializer.StringDecoder +import kafka.utils.{Utils, ZKGroupTopicDirs} +import kafka.utils.ZkUtils._ + +import scala.collection.mutable.HashMap +import scala.collection.JavaConversions._ + + +// Key for a specific Kafka Partition: (broker, topic, group, part) +case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int) +// NOT USED - Originally intended for fault-tolerance +// Metadata for a Kafka Stream that it sent to the Master +case class KafkaInputDStreamMetadata(timestamp: Long, data: Map[KafkaPartitionKey, Long]) +// NOT USED - Originally intended for fault-tolerance +// Checkpoint data specific to a KafkaInputDstream +case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any], + savedOffsets: Map[KafkaPartitionKey, Long]) extends DStreamCheckpointData(kafkaRdds) + +/** + * Input stream that pulls messages from a Kafka Broker. + * + * @param host Zookeper hostname. + * @param port Zookeper port. + * @param groupId The group id for this consumer. + * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * in its own thread. + * @param initialOffsets Optional initial offsets for each of the partitions to consume. + * By default the value is pulled from zookeper. + * @param storageLevel RDD storage level. + */ +class KafkaInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + host: String, + port: Int, + groupId: String, + topics: Map[String, Int], + initialOffsets: Map[KafkaPartitionKey, Long], + storageLevel: StorageLevel + ) extends NetworkInputDStream[T](ssc_ ) with Logging { + + // Metadata that keeps track of which messages have already been consumed. + var savedOffsets = HashMap[Long, Map[KafkaPartitionKey, Long]]() + + /* NOT USED - Originally intended for fault-tolerance + + // In case of a failure, the offets for a particular timestamp will be restored. + @transient var restoredOffsets : Map[KafkaPartitionKey, Long] = null + + + override protected[streaming] def addMetadata(metadata: Any) { + metadata match { + case x : KafkaInputDStreamMetadata => + savedOffsets(x.timestamp) = x.data + // TOOD: Remove logging + logInfo("New saved Offsets: " + savedOffsets) + case _ => logInfo("Received unknown metadata: " + metadata.toString) + } + } + + override protected[streaming] def updateCheckpointData(currentTime: Time) { + super.updateCheckpointData(currentTime) + if(savedOffsets.size > 0) { + // Find the offets that were stored before the checkpoint was initiated + val key = savedOffsets.keys.toList.sortWith(_ < _).filter(_ < currentTime.millis).last + val latestOffsets = savedOffsets(key) + logInfo("Updating KafkaDStream checkpoint data: " + latestOffsets.toString) + checkpointData = KafkaDStreamCheckpointData(checkpointData.rdds, latestOffsets) + // TODO: This may throw out offsets that are created after the checkpoint, + // but it's unlikely we'll need them. + savedOffsets.clear() + } + } + + override protected[streaming] def restoreCheckpointData() { + super.restoreCheckpointData() + logInfo("Restoring KafkaDStream checkpoint data.") + checkpointData match { + case x : KafkaDStreamCheckpointData => + restoredOffsets = x.savedOffsets + logInfo("Restored KafkaDStream offsets: " + savedOffsets) + } + } */ + + def createReceiver(): NetworkReceiver[T] = { + new KafkaReceiver(id, host, port, groupId, topics, initialOffsets, storageLevel) + .asInstanceOf[NetworkReceiver[T]] + } +} + +class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, + topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], + storageLevel: StorageLevel) extends NetworkReceiver[Any](streamId) { + + // Timeout for establishing a connection to Zookeper in ms. + val ZK_TIMEOUT = 10000 + + // Handles pushing data into the BlockManager + lazy protected val dataHandler = new DataHandler(this, storageLevel) + // Keeps track of the current offsets. Maps from (broker, topic, group, part) -> Offset + lazy val offsets = HashMap[KafkaPartitionKey, Long]() + // Connection to Kafka + var consumerConnector : ZookeeperConsumerConnector = null + + def onStop() { + dataHandler.stop() + } + + def onStart() { + + // Starting the DataHandler that buffers blocks and pushes them into them BlockManager + dataHandler.start() + + // In case we are using multiple Threads to handle Kafka Messages + val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) + + val zooKeeperEndPoint = host + ":" + port + logInfo("Starting Kafka Consumer Stream with group: " + groupId) + logInfo("Initial offsets: " + initialOffsets.toString) + + // Zookeper connection properties + val props = new Properties() + props.put("zk.connect", zooKeeperEndPoint) + props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString) + props.put("groupid", groupId) + + // Create the connection to the cluster + logInfo("Connecting to Zookeper: " + zooKeeperEndPoint) + val consumerConfig = new ConsumerConfig(props) + consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector] + logInfo("Connected to " + zooKeeperEndPoint) + + // Reset the Kafka offsets in case we are recovering from a failure + resetOffsets(initialOffsets) + + // Create Threads for each Topic/Message Stream we are listening + val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder()) + + // Start the messages handler for each partition + topicMessageStreams.values.foreach { streams => + streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) } + } + + } + + // Overwrites the offets in Zookeper. + private def resetOffsets(offsets: Map[KafkaPartitionKey, Long]) { + offsets.foreach { case(key, offset) => + val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic) + val partitionName = key.brokerId + "-" + key.partId + updatePersistentPath(consumerConnector.zkClient, + topicDirs.consumerOffsetDir + "/" + partitionName, offset.toString) + } + } + + // Handles Kafka Messages + private class MessageHandler(stream: KafkaStream[String]) extends Runnable { + def run() { + logInfo("Starting MessageHandler.") + stream.takeWhile { msgAndMetadata => + dataHandler += msgAndMetadata.message + + // Updating the offet. The key is (broker, topic, group, partition). + val key = KafkaPartitionKey(msgAndMetadata.topicInfo.brokerId, msgAndMetadata.topic, + groupId, msgAndMetadata.topicInfo.partition.partId) + val offset = msgAndMetadata.topicInfo.getConsumeOffset + offsets.put(key, offset) + // logInfo("Handled message: " + (key, offset).toString) + + // Keep on handling messages + true + } + } + } + + // NOT USED - Originally intended for fault-tolerance + // class KafkaDataHandler(receiver: KafkaReceiver, storageLevel: StorageLevel) + // extends DataHandler[Any](receiver, storageLevel) { + + // override def createBlock(blockId: String, iterator: Iterator[Any]) : Block = { + // // Creates a new Block with Kafka-specific Metadata + // new Block(blockId, iterator, KafkaInputDStreamMetadata(System.currentTimeMillis, offsets.toMap)) + // } + + // } + +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala new file mode 100644 index 0000000000..daf78c6893 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala @@ -0,0 +1,21 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD + +private[streaming] +class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( + parent: DStream[T], + mapPartFunc: Iterator[T] => Iterator[U], + preservePartitioning: Boolean + ) extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc, preservePartitioning)) + } +} + diff --git a/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala new file mode 100644 index 0000000000..689caeef0e --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala @@ -0,0 +1,21 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD +import spark.SparkContext._ + +private[streaming] +class MapValuedDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( + parent: DStream[(K, V)], + mapValueFunc: V => U + ) extends DStream[(K, U)](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[(K, U)]] = { + parent.getOrCompute(validTime).map(_.mapValues[U](mapValueFunc)) + } +} + diff --git a/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala new file mode 100644 index 0000000000..786b9966f2 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala @@ -0,0 +1,20 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD + +private[streaming] +class MappedDStream[T: ClassManifest, U: ClassManifest] ( + parent: DStream[T], + mapFunc: T => U + ) extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(_.map[U](mapFunc)) + } +} + diff --git a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala new file mode 100644 index 0000000000..41276da8bb --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala @@ -0,0 +1,157 @@ +package spark.streaming.dstream + +import spark.streaming.{Time, StreamingContext, AddBlocks, RegisterReceiver, DeregisterReceiver} + +import spark.{Logging, SparkEnv, RDD} +import spark.rdd.BlockRDD +import spark.storage.StorageLevel + +import scala.collection.mutable.ArrayBuffer + +import java.nio.ByteBuffer + +import akka.actor.{Props, Actor} +import akka.pattern.ask +import akka.dispatch.Await +import akka.util.duration._ + +abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : StreamingContext) + extends InputDStream[T](ssc_) { + + // This is an unique identifier that is used to match the network receiver with the + // corresponding network input stream. + val id = ssc.getNewNetworkStreamId() + + /** + * This method creates the receiver object that will be sent to the workers + * to receive data. This method needs to defined by any specific implementation + * of a NetworkInputDStream. + */ + def createReceiver(): NetworkReceiver[T] + + // Nothing to start or stop as both taken care of by the NetworkInputTracker. + def start() {} + + def stop() {} + + override def compute(validTime: Time): Option[RDD[T]] = { + val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) + Some(new BlockRDD[T](ssc.sc, blockIds)) + } +} + + +private[streaming] sealed trait NetworkReceiverMessage +private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage +private[streaming] case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage +private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage + +abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Serializable with Logging { + + initLogging() + + lazy protected val env = SparkEnv.get + + lazy protected val actor = env.actorSystem.actorOf( + Props(new NetworkReceiverActor()), "NetworkReceiver-" + streamId) + + lazy protected val receivingThread = Thread.currentThread() + + /** This method will be called to start receiving data. */ + protected def onStart() + + /** This method will be called to stop receiving data. */ + protected def onStop() + + /** This method conveys a placement preference (hostname) for this receiver. */ + def getLocationPreference() : Option[String] = None + + /** + * This method starts the receiver. First is accesses all the lazy members to + * materialize them. Then it calls the user-defined onStart() method to start + * other threads, etc required to receiver the data. + */ + def start() { + try { + // Access the lazy vals to materialize them + env + actor + receivingThread + + // Call user-defined onStart() + onStart() + } catch { + case ie: InterruptedException => + logInfo("Receiving thread interrupted") + //println("Receiving thread interrupted") + case e: Exception => + stopOnError(e) + } + } + + /** + * This method stops the receiver. First it interrupts the main receiving thread, + * that is, the thread that called receiver.start(). Then it calls the user-defined + * onStop() method to stop other threads and/or do cleanup. + */ + def stop() { + receivingThread.interrupt() + onStop() + //TODO: terminate the actor + } + + /** + * This method stops the receiver and reports to exception to the tracker. + * This should be called whenever an exception has happened on any thread + * of the receiver. + */ + protected def stopOnError(e: Exception) { + logError("Error receiving data", e) + stop() + actor ! ReportError(e.toString) + } + + + /** + * This method pushes a block (as iterator of values) into the block manager. + */ + def pushBlock(blockId: String, iterator: Iterator[T], metadata: Any, level: StorageLevel) { + val buffer = new ArrayBuffer[T] ++ iterator + env.blockManager.put(blockId, buffer.asInstanceOf[ArrayBuffer[Any]], level) + + actor ! ReportBlock(blockId, metadata) + } + + /** + * This method pushes a block (as bytes) into the block manager. + */ + def pushBlock(blockId: String, bytes: ByteBuffer, metadata: Any, level: StorageLevel) { + env.blockManager.putBytes(blockId, bytes, level) + actor ! ReportBlock(blockId, metadata) + } + + /** A helper actor that communicates with the NetworkInputTracker */ + private class NetworkReceiverActor extends Actor { + logInfo("Attempting to register with tracker") + val ip = System.getProperty("spark.master.host", "localhost") + val port = System.getProperty("spark.master.port", "7077").toInt + val url = "akka://spark@%s:%s/user/NetworkInputTracker".format(ip, port) + val tracker = env.actorSystem.actorFor(url) + val timeout = 5.seconds + + override def preStart() { + val future = tracker.ask(RegisterReceiver(streamId, self))(timeout) + Await.result(future, timeout) + } + + override def receive() = { + case ReportBlock(blockId, metadata) => + tracker ! AddBlocks(streamId, Array(blockId), metadata) + case ReportError(msg) => + tracker ! DeregisterReceiver(streamId, msg) + case StopReceiver(msg) => + stop() + tracker ! DeregisterReceiver(streamId, msg) + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala new file mode 100644 index 0000000000..024bf3bea4 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala @@ -0,0 +1,41 @@ +package spark.streaming.dstream + +import spark.RDD +import spark.rdd.UnionRDD + +import scala.collection.mutable.Queue +import scala.collection.mutable.ArrayBuffer +import spark.streaming.{Time, StreamingContext} + +class QueueInputDStream[T: ClassManifest]( + @transient ssc: StreamingContext, + val queue: Queue[RDD[T]], + oneAtATime: Boolean, + defaultRDD: RDD[T] + ) extends InputDStream[T](ssc) { + + override def start() { } + + override def stop() { } + + override def compute(validTime: Time): Option[RDD[T]] = { + val buffer = new ArrayBuffer[RDD[T]]() + if (oneAtATime && queue.size > 0) { + buffer += queue.dequeue() + } else { + buffer ++= queue + } + if (buffer.size > 0) { + if (oneAtATime) { + Some(buffer.first) + } else { + Some(new UnionRDD(ssc.sc, buffer.toSeq)) + } + } else if (defaultRDD != null) { + Some(defaultRDD) + } else { + None + } + } + +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala new file mode 100644 index 0000000000..996cc7dea8 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala @@ -0,0 +1,88 @@ +package spark.streaming.dstream + +import spark.{DaemonThread, Logging} +import spark.storage.StorageLevel +import spark.streaming.StreamingContext + +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.nio.channels.{ReadableByteChannel, SocketChannel} +import java.io.EOFException +import java.util.concurrent.ArrayBlockingQueue + + +/** + * An input stream that reads blocks of serialized objects from a given network address. + * The blocks will be inserted directly into the block store. This is the fastest way to get + * data into Spark Streaming, though it requires the sender to batch data and serialize it + * in the format that the system is configured with. + */ +class RawInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + host: String, + port: Int, + storageLevel: StorageLevel + ) extends NetworkInputDStream[T](ssc_ ) with Logging { + + def createReceiver(): NetworkReceiver[T] = { + new RawNetworkReceiver(id, host, port, storageLevel).asInstanceOf[NetworkReceiver[T]] + } +} + +class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: StorageLevel) + extends NetworkReceiver[Any](streamId) { + + var blockPushingThread: Thread = null + + override def getLocationPreference = None + + def onStart() { + // Open a socket to the target address and keep reading from it + logInfo("Connecting to " + host + ":" + port) + val channel = SocketChannel.open() + channel.configureBlocking(true) + channel.connect(new InetSocketAddress(host, port)) + logInfo("Connected to " + host + ":" + port) + + val queue = new ArrayBlockingQueue[ByteBuffer](2) + + blockPushingThread = new DaemonThread { + override def run() { + var nextBlockNumber = 0 + while (true) { + val buffer = queue.take() + val blockId = "input-" + streamId + "-" + nextBlockNumber + nextBlockNumber += 1 + pushBlock(blockId, buffer, null, storageLevel) + } + } + } + blockPushingThread.start() + + val lengthBuffer = ByteBuffer.allocate(4) + while (true) { + lengthBuffer.clear() + readFully(channel, lengthBuffer) + lengthBuffer.flip() + val length = lengthBuffer.getInt() + val dataBuffer = ByteBuffer.allocate(length) + readFully(channel, dataBuffer) + dataBuffer.flip() + logInfo("Read a block with " + length + " bytes") + queue.put(dataBuffer) + } + } + + def onStop() { + if (blockPushingThread != null) blockPushingThread.interrupt() + } + + /** Read a buffer fully from a given Channel */ + private def readFully(channel: ReadableByteChannel, dest: ByteBuffer) { + while (dest.position < dest.limit) { + if (channel.read(dest) == -1) { + throw new EOFException("End of channel") + } + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala new file mode 100644 index 0000000000..2686de14d2 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -0,0 +1,148 @@ +package spark.streaming.dstream + +import spark.streaming.StreamingContext._ + +import spark.RDD +import spark.rdd.CoGroupedRDD +import spark.Partitioner +import spark.SparkContext._ +import spark.storage.StorageLevel + +import scala.collection.mutable.ArrayBuffer +import spark.streaming.{Interval, Time, DStream} + +class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( + parent: DStream[(K, V)], + reduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + _windowTime: Time, + _slideTime: Time, + partitioner: Partitioner + ) extends DStream[(K,V)](parent.ssc) { + + assert(_windowTime.isMultipleOf(parent.slideTime), + "The window duration of ReducedWindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" + ) + + assert(_slideTime.isMultipleOf(parent.slideTime), + "The slide duration of ReducedWindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" + ) + + // Reduce each batch of data using reduceByKey which will be further reduced by window + // by ReducedWindowedDStream + val reducedStream = parent.reduceByKey(reduceFunc, partitioner) + + // Persist RDDs to memory by default as these RDDs are going to be reused. + super.persist(StorageLevel.MEMORY_ONLY_SER) + reducedStream.persist(StorageLevel.MEMORY_ONLY_SER) + + def windowTime: Time = _windowTime + + override def dependencies = List(reducedStream) + + override def slideTime: Time = _slideTime + + override val mustCheckpoint = true + + override def parentRememberDuration: Time = rememberDuration + windowTime + + override def persist(storageLevel: StorageLevel): DStream[(K,V)] = { + super.persist(storageLevel) + reducedStream.persist(storageLevel) + this + } + + override def checkpoint(interval: Time): DStream[(K, V)] = { + super.checkpoint(interval) + //reducedStream.checkpoint(interval) + this + } + + override def compute(validTime: Time): Option[RDD[(K, V)]] = { + val reduceF = reduceFunc + val invReduceF = invReduceFunc + + val currentTime = validTime + val currentWindow = Interval(currentTime - windowTime + parent.slideTime, currentTime) + val previousWindow = currentWindow - slideTime + + logDebug("Window time = " + windowTime) + logDebug("Slide time = " + slideTime) + logDebug("ZeroTime = " + zeroTime) + logDebug("Current window = " + currentWindow) + logDebug("Previous window = " + previousWindow) + + // _____________________________ + // | previous window _________|___________________ + // |___________________| current window | --------------> Time + // |_____________________________| + // + // |________ _________| |________ _________| + // | | + // V V + // old RDDs new RDDs + // + + // Get the RDDs of the reduced values in "old time steps" + val oldRDDs = reducedStream.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideTime) + logDebug("# old RDDs = " + oldRDDs.size) + + // Get the RDDs of the reduced values in "new time steps" + val newRDDs = reducedStream.slice(previousWindow.endTime + parent.slideTime, currentWindow.endTime) + logDebug("# new RDDs = " + newRDDs.size) + + // Get the RDD of the reduced value of the previous window + val previousWindowRDD = getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]())) + + // Make the list of RDDs that needs to cogrouped together for reducing their reduced values + val allRDDs = new ArrayBuffer[RDD[(K, V)]]() += previousWindowRDD ++= oldRDDs ++= newRDDs + + // Cogroup the reduced RDDs and merge the reduced values + val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(_, _)]]], partitioner) + //val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _ + + val numOldValues = oldRDDs.size + val numNewValues = newRDDs.size + + val mergeValues = (seqOfValues: Seq[Seq[V]]) => { + if (seqOfValues.size != 1 + numOldValues + numNewValues) { + throw new Exception("Unexpected number of sequences of reduced values") + } + // Getting reduced values "old time steps" that will be removed from current window + val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head) + // Getting reduced values "new time steps" + val newValues = (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) + if (seqOfValues(0).isEmpty) { + // If previous window's reduce value does not exist, then at least new values should exist + if (newValues.isEmpty) { + throw new Exception("Neither previous window has value for key, nor new values found. " + + "Are you sure your key class hashes consistently?") + } + // Reduce the new values + newValues.reduce(reduceF) // return + } else { + // Get the previous window's reduced value + var tempValue = seqOfValues(0).head + // If old values exists, then inverse reduce then from previous value + if (!oldValues.isEmpty) { + tempValue = invReduceF(tempValue, oldValues.reduce(reduceF)) + } + // If new values exists, then reduce them with previous value + if (!newValues.isEmpty) { + tempValue = reduceF(tempValue, newValues.reduce(reduceF)) + } + tempValue // return + } + } + + val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValues) + + Some(mergedValuesRDD) + } + + +} + + diff --git a/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala new file mode 100644 index 0000000000..6854bbe665 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala @@ -0,0 +1,27 @@ +package spark.streaming.dstream + +import spark.{RDD, Partitioner} +import spark.SparkContext._ +import spark.streaming.{DStream, Time} + +private[streaming] +class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( + parent: DStream[(K,V)], + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiner: (C, C) => C, + partitioner: Partitioner + ) extends DStream [(K,C)] (parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[(K,C)]] = { + parent.getOrCompute(validTime) match { + case Some(rdd) => + Some(rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner, partitioner)) + case None => None + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala new file mode 100644 index 0000000000..af5b73ae8d --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala @@ -0,0 +1,103 @@ +package spark.streaming.dstream + +import spark.streaming.StreamingContext +import spark.storage.StorageLevel + +import java.io._ +import java.net.Socket + +class SocketInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + host: String, + port: Int, + bytesToObjects: InputStream => Iterator[T], + storageLevel: StorageLevel + ) extends NetworkInputDStream[T](ssc_) { + + def createReceiver(): NetworkReceiver[T] = { + new SocketReceiver(id, host, port, bytesToObjects, storageLevel) + } +} + + +class SocketReceiver[T: ClassManifest]( + streamId: Int, + host: String, + port: Int, + bytesToObjects: InputStream => Iterator[T], + storageLevel: StorageLevel + ) extends NetworkReceiver[T](streamId) { + + lazy protected val dataHandler = new DataHandler(this, storageLevel) + + override def getLocationPreference = None + + protected def onStart() { + logInfo("Connecting to " + host + ":" + port) + val socket = new Socket(host, port) + logInfo("Connected to " + host + ":" + port) + dataHandler.start() + val iterator = bytesToObjects(socket.getInputStream()) + while(iterator.hasNext) { + val obj = iterator.next + dataHandler += obj + } + } + + protected def onStop() { + dataHandler.stop() + } + +} + + +object SocketReceiver { + + /** + * This methods translates the data from an inputstream (say, from a socket) + * to '\n' delimited strings and returns an iterator to access the strings. + */ + def bytesToLines(inputStream: InputStream): Iterator[String] = { + val dataInputStream = new BufferedReader(new InputStreamReader(inputStream, "UTF-8")) + + val iterator = new Iterator[String] { + var gotNext = false + var finished = false + var nextValue: String = null + + private def getNext() { + try { + nextValue = dataInputStream.readLine() + if (nextValue == null) { + finished = true + } + } + gotNext = true + } + + override def hasNext: Boolean = { + if (!finished) { + if (!gotNext) { + getNext() + if (finished) { + dataInputStream.close() + } + } + } + !finished + } + + override def next(): String = { + if (finished) { + throw new NoSuchElementException("End of stream") + } + if (!gotNext) { + getNext() + } + gotNext = false + nextValue + } + } + iterator + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala new file mode 100644 index 0000000000..6e190b5564 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala @@ -0,0 +1,83 @@ +package spark.streaming.dstream + +import spark.RDD +import spark.Partitioner +import spark.SparkContext._ +import spark.storage.StorageLevel +import spark.streaming.{Time, DStream} + +class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( + parent: DStream[(K, V)], + updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], + partitioner: Partitioner, + preservePartitioning: Boolean + ) extends DStream[(K, S)](parent.ssc) { + + super.persist(StorageLevel.MEMORY_ONLY_SER) + + override def dependencies = List(parent) + + override def slideTime = parent.slideTime + + override val mustCheckpoint = true + + override def compute(validTime: Time): Option[RDD[(K, S)]] = { + + // Try to get the previous state RDD + getOrCompute(validTime - slideTime) match { + + case Some(prevStateRDD) => { // If previous state RDD exists + + // Try to get the parent RDD + parent.getOrCompute(validTime) match { + case Some(parentRDD) => { // If parent RDD exists, then compute as usual + + // Define the function for the mapPartition operation on cogrouped RDD; + // first map the cogrouped tuple to tuples of required type, + // and then apply the update function + val updateFuncLocal = updateFunc + val finalFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { + val i = iterator.map(t => { + (t._1, t._2._1, t._2._2.headOption) + }) + updateFuncLocal(i) + } + val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) + val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning) + //logDebug("Generating state RDD for time " + validTime) + return Some(stateRDD) + } + case None => { // If parent RDD does not exist, then return old state RDD + return Some(prevStateRDD) + } + } + } + + case None => { // If previous session RDD does not exist (first input data) + + // Try to get the parent RDD + parent.getOrCompute(validTime) match { + case Some(parentRDD) => { // If parent RDD exists, then compute as usual + + // Define the function for the mapPartition operation on grouped RDD; + // first map the grouped tuple to tuples of required type, + // and then apply the update function + val updateFuncLocal = updateFunc + val finalFunc = (iterator: Iterator[(K, Seq[V])]) => { + updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, None))) + } + + val groupedRDD = parentRDD.groupByKey(partitioner) + val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning) + //logDebug("Generating state RDD for time " + validTime + " (first)") + return Some(sessionRDD) + } + case None => { // If parent RDD does not exist, then nothing to do! + //logDebug("Not generating state RDD (no previous state, no parent)") + return None + } + } + } + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala new file mode 100644 index 0000000000..0337579514 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala @@ -0,0 +1,19 @@ +package spark.streaming.dstream + +import spark.RDD +import spark.streaming.{DStream, Time} + +private[streaming] +class TransformedDStream[T: ClassManifest, U: ClassManifest] ( + parent: DStream[T], + transformFunc: (RDD[T], Time) => RDD[U] + ) extends DStream[U](parent.ssc) { + + override def dependencies = List(parent) + + override def slideTime: Time = parent.slideTime + + override def compute(validTime: Time): Option[RDD[U]] = { + parent.getOrCompute(validTime).map(transformFunc(_, validTime)) + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala new file mode 100644 index 0000000000..f1efb2ae72 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala @@ -0,0 +1,39 @@ +package spark.streaming.dstream + +import spark.streaming.{DStream, Time} +import spark.RDD +import collection.mutable.ArrayBuffer +import spark.rdd.UnionRDD + +class UnionDStream[T: ClassManifest](parents: Array[DStream[T]]) + extends DStream[T](parents.head.ssc) { + + if (parents.length == 0) { + throw new IllegalArgumentException("Empty array of parents") + } + + if (parents.map(_.ssc).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different StreamingContexts") + } + + if (parents.map(_.slideTime).distinct.size > 1) { + throw new IllegalArgumentException("Array of parents have different slide times") + } + + override def dependencies = parents.toList + + override def slideTime: Time = parents.head.slideTime + + override def compute(validTime: Time): Option[RDD[T]] = { + val rdds = new ArrayBuffer[RDD[T]]() + parents.map(_.getOrCompute(validTime)).foreach(_ match { + case Some(rdd) => rdds += rdd + case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime) + }) + if (rdds.size > 0) { + Some(new UnionRDD(ssc.sc, rdds)) + } else { + None + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala new file mode 100644 index 0000000000..4b2621c497 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala @@ -0,0 +1,40 @@ +package spark.streaming.dstream + +import spark.RDD +import spark.rdd.UnionRDD +import spark.storage.StorageLevel +import spark.streaming.{Interval, Time, DStream} + + +class WindowedDStream[T: ClassManifest]( + parent: DStream[T], + _windowTime: Time, + _slideTime: Time) + extends DStream[T](parent.ssc) { + + if (!_windowTime.isMultipleOf(parent.slideTime)) + throw new Exception("The window duration of WindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + + if (!_slideTime.isMultipleOf(parent.slideTime)) + throw new Exception("The slide duration of WindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + + parent.persist(StorageLevel.MEMORY_ONLY_SER) + + def windowTime: Time = _windowTime + + override def dependencies = List(parent) + + override def slideTime: Time = _slideTime + + override def parentRememberDuration: Time = rememberDuration + windowTime + + override def compute(validTime: Time): Option[RDD[T]] = { + val currentWindow = Interval(validTime - windowTime + parent.slideTime, validTime) + Some(new UnionRDD(ssc.sc, parent.slice(currentWindow))) + } +} + + + diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala index 7c4ee3b34c..dfaaf03f03 100644 --- a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala @@ -25,7 +25,7 @@ object GrepRaw { val rawStreams = (1 to numStreams).map(_ => ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray - val union = new UnionDStream(rawStreams) + val union = ssc.union(rawStreams) union.filter(_.contains("Alice")).count().foreach(r => println("Grep count: " + r.collect().mkString)) ssc.start() diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala index 182dfd8a52..338834bc3c 100644 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala @@ -34,7 +34,7 @@ object TopKWordCountRaw { val lines = (1 to numStreams).map(_ => { ssc.rawNetworkStream[String]("localhost", port, StorageLevel.MEMORY_ONLY_SER_2) }) - val union = new UnionDStream(lines.toArray) + val union = ssc.union(lines) val counts = union.mapPartitions(splitAndCountPartitions) val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala index 9bcd30f4d7..d93335a8ce 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -33,7 +33,7 @@ object WordCountRaw { val lines = (1 to numStreams).map(_ => { ssc.rawNetworkStream[String]("localhost", port, StorageLevel.MEMORY_ONLY_SER_2) }) - val union = new UnionDStream(lines.toArray) + val union = ssc.union(lines) val counts = union.mapPartitions(splitAndCountPartitions) val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) windowedCounts.foreach(r => println("# unique words = " + r.count())) diff --git a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala deleted file mode 100644 index 7c642d4802..0000000000 --- a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala +++ /dev/null @@ -1,193 +0,0 @@ -package spark.streaming - -import java.util.Properties -import java.util.concurrent.Executors -import kafka.consumer._ -import kafka.message.{Message, MessageSet, MessageAndMetadata} -import kafka.serializer.StringDecoder -import kafka.utils.{Utils, ZKGroupTopicDirs} -import kafka.utils.ZkUtils._ -import scala.collection.mutable.HashMap -import scala.collection.JavaConversions._ -import spark._ -import spark.RDD -import spark.storage.StorageLevel - -// Key for a specific Kafka Partition: (broker, topic, group, part) -case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int) -// NOT USED - Originally intended for fault-tolerance -// Metadata for a Kafka Stream that it sent to the Master -case class KafkaInputDStreamMetadata(timestamp: Long, data: Map[KafkaPartitionKey, Long]) -// NOT USED - Originally intended for fault-tolerance -// Checkpoint data specific to a KafkaInputDstream -case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any], - savedOffsets: Map[KafkaPartitionKey, Long]) extends DStreamCheckpointData(kafkaRdds) - -/** - * Input stream that pulls messages form a Kafka Broker. - * - * @param host Zookeper hostname. - * @param port Zookeper port. - * @param groupId The group id for this consumer. - * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed - * in its own thread. - * @param initialOffsets Optional initial offsets for each of the partitions to consume. - * By default the value is pulled from zookeper. - * @param storageLevel RDD storage level. - */ -class KafkaInputDStream[T: ClassManifest]( - @transient ssc_ : StreamingContext, - host: String, - port: Int, - groupId: String, - topics: Map[String, Int], - initialOffsets: Map[KafkaPartitionKey, Long], - storageLevel: StorageLevel - ) extends NetworkInputDStream[T](ssc_ ) with Logging { - - // Metadata that keeps track of which messages have already been consumed. - var savedOffsets = HashMap[Long, Map[KafkaPartitionKey, Long]]() - - /* NOT USED - Originally intended for fault-tolerance - - // In case of a failure, the offets for a particular timestamp will be restored. - @transient var restoredOffsets : Map[KafkaPartitionKey, Long] = null - - - override protected[streaming] def addMetadata(metadata: Any) { - metadata match { - case x : KafkaInputDStreamMetadata => - savedOffsets(x.timestamp) = x.data - // TOOD: Remove logging - logInfo("New saved Offsets: " + savedOffsets) - case _ => logInfo("Received unknown metadata: " + metadata.toString) - } - } - - override protected[streaming] def updateCheckpointData(currentTime: Time) { - super.updateCheckpointData(currentTime) - if(savedOffsets.size > 0) { - // Find the offets that were stored before the checkpoint was initiated - val key = savedOffsets.keys.toList.sortWith(_ < _).filter(_ < currentTime.millis).last - val latestOffsets = savedOffsets(key) - logInfo("Updating KafkaDStream checkpoint data: " + latestOffsets.toString) - checkpointData = KafkaDStreamCheckpointData(checkpointData.rdds, latestOffsets) - // TODO: This may throw out offsets that are created after the checkpoint, - // but it's unlikely we'll need them. - savedOffsets.clear() - } - } - - override protected[streaming] def restoreCheckpointData() { - super.restoreCheckpointData() - logInfo("Restoring KafkaDStream checkpoint data.") - checkpointData match { - case x : KafkaDStreamCheckpointData => - restoredOffsets = x.savedOffsets - logInfo("Restored KafkaDStream offsets: " + savedOffsets) - } - } */ - - def createReceiver(): NetworkReceiver[T] = { - new KafkaReceiver(id, host, port, groupId, topics, initialOffsets, storageLevel) - .asInstanceOf[NetworkReceiver[T]] - } -} - -class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, - topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], - storageLevel: StorageLevel) extends NetworkReceiver[Any](streamId) { - - // Timeout for establishing a connection to Zookeper in ms. - val ZK_TIMEOUT = 10000 - - // Handles pushing data into the BlockManager - lazy protected val dataHandler = new DataHandler(this, storageLevel) - // Keeps track of the current offsets. Maps from (broker, topic, group, part) -> Offset - lazy val offsets = HashMap[KafkaPartitionKey, Long]() - // Connection to Kafka - var consumerConnector : ZookeeperConsumerConnector = null - - def onStop() { - dataHandler.stop() - } - - def onStart() { - - // Starting the DataHandler that buffers blocks and pushes them into them BlockManager - dataHandler.start() - - // In case we are using multiple Threads to handle Kafka Messages - val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) - - val zooKeeperEndPoint = host + ":" + port - logInfo("Starting Kafka Consumer Stream with group: " + groupId) - logInfo("Initial offsets: " + initialOffsets.toString) - - // Zookeper connection properties - val props = new Properties() - props.put("zk.connect", zooKeeperEndPoint) - props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString) - props.put("groupid", groupId) - - // Create the connection to the cluster - logInfo("Connecting to Zookeper: " + zooKeeperEndPoint) - val consumerConfig = new ConsumerConfig(props) - consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector] - logInfo("Connected to " + zooKeeperEndPoint) - - // Reset the Kafka offsets in case we are recovering from a failure - resetOffsets(initialOffsets) - - // Create Threads for each Topic/Message Stream we are listening - val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder()) - - // Start the messages handler for each partition - topicMessageStreams.values.foreach { streams => - streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) } - } - - } - - // Overwrites the offets in Zookeper. - private def resetOffsets(offsets: Map[KafkaPartitionKey, Long]) { - offsets.foreach { case(key, offset) => - val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic) - val partitionName = key.brokerId + "-" + key.partId - updatePersistentPath(consumerConnector.zkClient, - topicDirs.consumerOffsetDir + "/" + partitionName, offset.toString) - } - } - - // Handles Kafka Messages - private class MessageHandler(stream: KafkaStream[String]) extends Runnable { - def run() { - logInfo("Starting MessageHandler.") - stream.takeWhile { msgAndMetadata => - dataHandler += msgAndMetadata.message - - // Updating the offet. The key is (broker, topic, group, partition). - val key = KafkaPartitionKey(msgAndMetadata.topicInfo.brokerId, msgAndMetadata.topic, - groupId, msgAndMetadata.topicInfo.partition.partId) - val offset = msgAndMetadata.topicInfo.getConsumeOffset - offsets.put(key, offset) - // logInfo("Handled message: " + (key, offset).toString) - - // Keep on handling messages - true - } - } - } - - // NOT USED - Originally intended for fault-tolerance - // class KafkaDataHandler(receiver: KafkaReceiver, storageLevel: StorageLevel) - // extends DataHandler[Any](receiver, storageLevel) { - - // override def createBlock(blockId: String, iterator: Iterator[Any]) : Block = { - // // Creates a new Block with Kafka-specific Metadata - // new Block(blockId, iterator, KafkaInputDStreamMetadata(System.currentTimeMillis, offsets.toMap)) - // } - - // } - -} diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index 0d82b2f1ea..920388bba9 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -42,7 +42,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { val stateStreamCheckpointInterval = Seconds(1) // this ensure checkpointing occurs at least once - val firstNumBatches = (stateStreamCheckpointInterval.millis / batchDuration.millis) * 2 + val firstNumBatches = (stateStreamCheckpointInterval / batchDuration) * 2 val secondNumBatches = firstNumBatches // Setup the streams diff --git a/streaming/src/test/scala/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/spark/streaming/FailureSuite.scala index 5b414117fc..4aa428bf64 100644 --- a/streaming/src/test/scala/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/spark/streaming/FailureSuite.scala @@ -133,7 +133,7 @@ class FailureSuite extends TestSuiteBase with BeforeAndAfter { // Get the output buffer val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]] val output = outputStream.output - val waitTime = (batchDuration.millis * (numBatches.toDouble + 0.5)).toLong + val waitTime = (batchDuration.milliseconds * (numBatches.toDouble + 0.5)).toLong val startTime = System.currentTimeMillis() try { diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index ed9a659092..76b528bec3 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -1,5 +1,6 @@ package spark.streaming +import dstream.SparkFlumeEvent import java.net.{InetSocketAddress, SocketException, Socket, ServerSocket} import java.io.{File, BufferedWriter, OutputStreamWriter} import java.util.concurrent.{TimeUnit, ArrayBlockingQueue} diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index a44f738957..28bdd53c3c 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -1,12 +1,16 @@ package spark.streaming +import spark.streaming.dstream.{InputDStream, ForEachDStream} +import spark.streaming.util.ManualClock + import spark.{RDD, Logging} -import util.ManualClock + import collection.mutable.ArrayBuffer -import org.scalatest.FunSuite import collection.mutable.SynchronizedBuffer + import java.io.{ObjectInputStream, IOException} +import org.scalatest.FunSuite /** * This is a input stream just for the testsuites. This is equivalent to a checkpointable, @@ -70,6 +74,10 @@ trait TestSuiteBase extends FunSuite with Logging { def actuallyWait = false + /** + * Set up required DStreams to test the DStream operation using the two sequences + * of input collections. + */ def setupStreams[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], operation: DStream[U] => DStream[V] @@ -90,6 +98,10 @@ trait TestSuiteBase extends FunSuite with Logging { ssc } + /** + * Set up required DStreams to test the binary operation using the sequence + * of input collections. + */ def setupStreams[U: ClassManifest, V: ClassManifest, W: ClassManifest]( input1: Seq[Seq[U]], input2: Seq[Seq[V]], @@ -173,6 +185,11 @@ trait TestSuiteBase extends FunSuite with Logging { output } + /** + * Verify whether the output values after running a DStream operation + * is same as the expected output values, by comparing the output + * collections either as lists (order matters) or sets (order does not matter) + */ def verifyOutput[V: ClassManifest]( output: Seq[Seq[V]], expectedOutput: Seq[Seq[V]], @@ -199,6 +216,10 @@ trait TestSuiteBase extends FunSuite with Logging { logInfo("Output verified successfully") } + /** + * Test unary DStream operation with a list of inputs, with number of + * batches to run same as the number of expected output values + */ def testOperation[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], operation: DStream[U] => DStream[V], @@ -208,6 +229,15 @@ trait TestSuiteBase extends FunSuite with Logging { testOperation[U, V](input, operation, expectedOutput, -1, useSet) } + /** + * Test unary DStream operation with a list of inputs + * @param input Sequence of input collections + * @param operation Binary DStream operation to be applied to the 2 inputs + * @param expectedOutput Sequence of expected output collections + * @param numBatches Number of batches to run the operation for + * @param useSet Compare the output values with the expected output values + * as sets (order matters) or as lists (order does not matter) + */ def testOperation[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], operation: DStream[U] => DStream[V], @@ -221,6 +251,10 @@ trait TestSuiteBase extends FunSuite with Logging { verifyOutput[V](output, expectedOutput, useSet) } + /** + * Test binary DStream operation with two lists of inputs, with number of + * batches to run same as the number of expected output values + */ def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest]( input1: Seq[Seq[U]], input2: Seq[Seq[V]], @@ -231,6 +265,16 @@ trait TestSuiteBase extends FunSuite with Logging { testOperation[U, V, W](input1, input2, operation, expectedOutput, -1, useSet) } + /** + * Test binary DStream operation with two lists of inputs + * @param input1 First sequence of input collections + * @param input2 Second sequence of input collections + * @param operation Binary DStream operation to be applied to the 2 inputs + * @param expectedOutput Sequence of expected output collections + * @param numBatches Number of batches to run the operation for + * @param useSet Compare the output values with the expected output values + * as sets (order matters) or as lists (order does not matter) + */ def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest]( input1: Seq[Seq[U]], input2: Seq[Seq[V]], diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index 3e20e16708..4bc5229465 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -209,7 +209,7 @@ class WindowOperationsSuite extends TestSuiteBase { val expectedOutput = bigGroupByOutput.map(_.map(x => (x._1, x._2.toSet))) val windowTime = Seconds(2) val slideTime = Seconds(1) - val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt val operation = (s: DStream[(String, Int)]) => { s.groupByKeyAndWindow(windowTime, slideTime) .map(x => (x._1, x._2.toSet)) @@ -223,7 +223,7 @@ class WindowOperationsSuite extends TestSuiteBase { val expectedOutput = Seq( Seq(1), Seq(2), Seq(3), Seq(3), Seq(1), Seq(0)) val windowTime = Seconds(2) val slideTime = Seconds(1) - val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt val operation = (s: DStream[Int]) => s.countByWindow(windowTime, slideTime) testOperation(input, operation, expectedOutput, numBatches, true) } @@ -233,7 +233,7 @@ class WindowOperationsSuite extends TestSuiteBase { val expectedOutput = Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 2)), Seq(("a", 1), ("b", 3))) val windowTime = Seconds(2) val slideTime = Seconds(1) - val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt val operation = (s: DStream[(String, Int)]) => { s.countByKeyAndWindow(windowTime, slideTime).map(x => (x._1, x._2.toInt)) } @@ -251,7 +251,7 @@ class WindowOperationsSuite extends TestSuiteBase { slideTime: Time = Seconds(1) ) { test("window - " + name) { - val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt val operation = (s: DStream[Int]) => s.window(windowTime, slideTime) testOperation(input, operation, expectedOutput, numBatches, true) } @@ -265,7 +265,7 @@ class WindowOperationsSuite extends TestSuiteBase { slideTime: Time = Seconds(1) ) { test("reduceByKeyAndWindow - " + name) { - val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt val operation = (s: DStream[(String, Int)]) => { s.reduceByKeyAndWindow(_ + _, windowTime, slideTime).persist() } @@ -281,7 +281,7 @@ class WindowOperationsSuite extends TestSuiteBase { slideTime: Time = Seconds(1) ) { test("reduceByKeyAndWindowInv - " + name) { - val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt + val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt val operation = (s: DStream[(String, Int)]) => { s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime) .persist() -- cgit v1.2.3 From 18b9b3b99fd753899d19bd10f0dbef7d5c4ae8d7 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 30 Dec 2012 20:00:42 -0800 Subject: More classes made private[streaming] to hide from scala docs. --- .gitignore | 2 + .../src/main/scala/spark/streaming/DStream.scala | 2 +- .../main/scala/spark/streaming/JobManager.scala | 2 +- .../src/main/scala/spark/streaming/Scheduler.scala | 2 +- .../scala/spark/streaming/StreamingContext.scala | 108 ++++++++++++++++----- .../src/main/scala/spark/streaming/Time.scala | 30 ++++-- .../spark/streaming/dstream/CoGroupedDStream.scala | 1 + .../spark/streaming/dstream/FileInputDStream.scala | 2 +- .../streaming/dstream/FlumeInputDStream.scala | 13 ++- .../streaming/dstream/KafkaInputDStream.scala | 6 +- .../spark/streaming/dstream/RawInputDStream.scala | 2 + .../streaming/dstream/ReducedWindowedDStream.scala | 1 + .../streaming/dstream/SocketInputDStream.scala | 5 +- .../spark/streaming/dstream/StateDStream.scala | 1 + .../spark/streaming/dstream/UnionDStream.scala | 1 + .../spark/streaming/dstream/WindowedDStream.scala | 2 +- .../main/scala/spark/streaming/util/Clock.scala | 8 +- .../spark/streaming/util/RecurringTimer.scala | 1 + 18 files changed, 137 insertions(+), 52 deletions(-) diff --git a/.gitignore b/.gitignore index c207409e3c..88d7b56181 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ third_party/libmesos.so third_party/libmesos.dylib conf/java-opts conf/spark-env.sh +conf/streaming-env.sh conf/log4j.properties docs/_site docs/api @@ -31,4 +32,5 @@ project/plugins/src_managed/ logs/ log/ spark-tests.log +streaming-tests.log dependency-reduced-pom.xml diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 292ad3b9f9..beba9cfd4f 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -189,7 +189,7 @@ abstract class DStream[T: ClassManifest] ( val metadataCleanerDelay = spark.util.MetadataCleaner.getDelaySeconds logInfo("metadataCleanupDelay = " + metadataCleanerDelay) assert( - metadataCleanerDelay < 0 || rememberDuration < metadataCleanerDelay * 1000, + metadataCleanerDelay < 0 || rememberDuration.milliseconds < metadataCleanerDelay * 1000, "It seems you are doing some DStream window operation or setting a checkpoint interval " + "which requires " + this.getClass.getSimpleName + " to remember generated RDDs for more " + "than " + rememberDuration.milliseconds + " milliseconds. But the Spark's metadata cleanup" + diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index fda7264a27..3b910538e0 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -14,7 +14,7 @@ class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { try { val timeTaken = job.run() logInfo("Total delay: %.5f s for job %s (execution: %.5f s)".format( - (System.currentTimeMillis() - job.time) / 1000.0, job.id, timeTaken / 1000.0)) + (System.currentTimeMillis() - job.time.milliseconds) / 1000.0, job.id, timeTaken / 1000.0)) } catch { case e: Exception => logError("Running " + job + " failed", e) diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index aeb7c3eb0e..eb40affe6d 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -22,7 +22,7 @@ class Scheduler(ssc: StreamingContext) extends Logging { val clockClass = System.getProperty("spark.streaming.clock", "spark.streaming.util.SystemClock") val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock] - val timer = new RecurringTimer(clock, ssc.graph.batchDuration, generateRDDs(_)) + val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, generateRDDs(_)) def start() { // If context was started from checkpoint, then restart timer such that diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index ef73049a81..7256e41af9 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -15,7 +15,6 @@ import org.apache.hadoop.io.LongWritable import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat -import org.apache.flume.source.avro.AvroFlumeEvent import org.apache.hadoop.fs.Path import java.util.UUID @@ -101,14 +100,27 @@ class StreamingContext private ( protected[streaming] var receiverJobThread: Thread = null protected[streaming] var scheduler: Scheduler = null + /** + * Sets each DStreams in this context to remember RDDs it generated in the last given duration. + * DStreams remember RDDs only for a limited duration of time and releases them for garbage + * collection. This method allows the developer to specify how to long to remember the RDDs ( + * if the developer wishes to query old data outside the DStream computation). + * @param duration Minimum duration that each DStream should remember its RDDs + */ def remember(duration: Time) { graph.remember(duration) } - def checkpoint(dir: String, interval: Time = null) { - if (dir != null) { - sc.setCheckpointDir(StreamingContext.getSparkCheckpointDir(dir)) - checkpointDir = dir + /** + * Sets the context to periodically checkpoint the DStream operations for master + * fault-tolerance. By default, the graph will be checkpointed every batch interval. + * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored + * @param interval checkpoint interval + */ + def checkpoint(directory: String, interval: Time = null) { + if (directory != null) { + sc.setCheckpointDir(StreamingContext.getSparkCheckpointDir(directory)) + checkpointDir = directory checkpointInterval = interval } else { checkpointDir = null @@ -122,9 +134,8 @@ class StreamingContext private ( protected[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() - /** + /** * Create an input stream that pulls messages form a Kafka Broker. - * * @param hostname Zookeper hostname. * @param port Zookeper port. * @param groupId The group id for this consumer. @@ -147,6 +158,15 @@ class StreamingContext private ( inputStream } + /** + * Create a input stream from network source hostname:port. Data is received using + * a TCP socket and the receive bytes is interpreted as UTF8 encoded \n delimited + * lines. + * @param hostname Hostname to connect to for receiving data + * @param port Port to connect to for receiving data + * @param storageLevel Storage level to use for storing the received objects + * (default: StorageLevel.MEMORY_AND_DISK_SER_2) + */ def networkTextStream( hostname: String, port: Int, @@ -155,6 +175,16 @@ class StreamingContext private ( networkStream[String](hostname, port, SocketReceiver.bytesToLines, storageLevel) } + /** + * Create a input stream from network source hostname:port. Data is received using + * a TCP socket and the receive bytes it interepreted as object using the given + * converter. + * @param hostname Hostname to connect to for receiving data + * @param port Port to connect to for receiving data + * @param converter Function to convert the byte stream to objects + * @param storageLevel Storage level to use for storing the received objects + * @tparam T Type of the objects received (after converting bytes to objects) + */ def networkStream[T: ClassManifest]( hostname: String, port: Int, @@ -166,16 +196,32 @@ class StreamingContext private ( inputStream } + /** + * Creates a input stream from a Flume source. + * @param hostname Hostname of the slave machine to which the flume data will be sent + * @param port Port of the slave machine to which the flume data will be sent + * @param storageLevel Storage level to use for storing the received objects + */ def flumeStream ( - hostname: String, - port: Int, - storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2): DStream[SparkFlumeEvent] = { + hostname: String, + port: Int, + storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 + ): DStream[SparkFlumeEvent] = { val inputStream = new FlumeInputDStream(this, hostname, port, storageLevel) graph.addInputStream(inputStream) inputStream } - + /** + * Create a input stream from network source hostname:port, where data is received + * as serialized blocks (serialized using the Spark's serializer) that can be directly + * pushed into the block manager without deserializing them. This is the most efficient + * way to receive data. + * @param hostname Hostname to connect to for receiving data + * @param port Port to connect to for receiving data + * @param storageLevel Storage level to use for storing the received objects + * @tparam T Type of the objects in the received blocks + */ def rawNetworkStream[T: ClassManifest]( hostname: String, port: Int, @@ -188,7 +234,11 @@ class StreamingContext private ( /** * Creates a input stream that monitors a Hadoop-compatible filesystem - * for new files and executes the necessary processing on them. + * for new files and reads them using the given key-value types and input format. + * @param directory HDFS directory to monitor for new file + * @tparam K Key type for reading HDFS file + * @tparam V Value type for reading HDFS file + * @tparam F Input format for reading HDFS file */ def fileStream[ K: ClassManifest, @@ -200,13 +250,23 @@ class StreamingContext private ( inputStream } + /** + * Creates a input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them as text files (using key as LongWritable, value + * as Text and input format as TextInputFormat). + * @param directory HDFS directory to monitor for new file + */ def textFileStream(directory: String): DStream[String] = { fileStream[LongWritable, Text, TextInputFormat](directory).map(_._2.toString) } /** * Creates a input stream from an queue of RDDs. In each batch, - * it will process either one or all of the RDDs returned by the queue + * it will process either one or all of the RDDs returned by the queue. + * @param queue Queue of RDDs + * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval + * @param defaultRDD Default RDD is returned by the DStream when the queue is empty + * @tparam T Type of objects in the RDD */ def queueStream[T: ClassManifest]( queue: Queue[RDD[T]], @@ -218,13 +278,9 @@ class StreamingContext private ( inputStream } - def queueStream[T: ClassManifest](array: Array[RDD[T]]): DStream[T] = { - val queue = new Queue[RDD[T]] - val inputStream = queueStream(queue, true, null) - queue ++= array - inputStream - } - + /** + * Create a unified DStream from multiple DStreams of the same type and same interval + */ def union[T: ClassManifest](streams: Seq[DStream[T]]): DStream[T] = { new UnionDStream[T](streams.toArray) } @@ -256,7 +312,7 @@ class StreamingContext private ( } /** - * This function starts the execution of the streams. + * Starts the execution of the streams. */ def start() { if (checkpointDir != null && checkpointInterval == null && graph != null) { @@ -284,7 +340,7 @@ class StreamingContext private ( } /** - * This function stops the execution of the streams. + * Sstops the execution of the streams. */ def stop() { try { @@ -302,6 +358,10 @@ class StreamingContext private ( object StreamingContext { + implicit def toPairDStreamFunctions[K: ClassManifest, V: ClassManifest](stream: DStream[(K,V)]) = { + new PairDStreamFunctions[K, V](stream) + } + protected[streaming] def createNewSparkContext(master: String, frameworkName: String): SparkContext = { // Set the default cleaner delay to an hour if not already set. @@ -312,10 +372,6 @@ object StreamingContext { new SparkContext(master, frameworkName) } - implicit def toPairDStreamFunctions[K: ClassManifest, V: ClassManifest](stream: DStream[(K,V)]) = { - new PairDStreamFunctions[K, V](stream) - } - protected[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { if (prefix == null) { time.milliseconds.toString diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index 2976e5e87b..3c6fd5d967 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -1,16 +1,18 @@ package spark.streaming /** - * This class is simple wrapper class that represents time in UTC. - * @param millis Time in UTC long + * This is a simple class that represents time. Internally, it represents time as UTC. + * The recommended way to create instances of Time is to use helper objects + * [[spark.streaming.Milliseconds]], [[spark.streaming.Seconds]], and [[spark.streaming.Minutes]]. + * @param millis Time in UTC. */ case class Time(private val millis: Long) { def < (that: Time): Boolean = (this.millis < that.millis) - + def <= (that: Time): Boolean = (this.millis <= that.millis) - + def > (that: Time): Boolean = (this.millis > that.millis) def >= (that: Time): Boolean = (this.millis >= that.millis) @@ -45,23 +47,33 @@ case class Time(private val millis: Long) { def milliseconds: Long = millis } -object Time { +private[streaming] object Time { val zero = Time(0) implicit def toTime(long: Long) = Time(long) - - implicit def toLong(time: Time) = time.milliseconds } +/** + * Helper object that creates instance of [[spark.streaming.Time]] representing + * a given number of milliseconds. + */ object Milliseconds { def apply(milliseconds: Long) = Time(milliseconds) } +/** + * Helper object that creates instance of [[spark.streaming.Time]] representing + * a given number of seconds. + */ object Seconds { def apply(seconds: Long) = Time(seconds * 1000) -} +} -object Minutes { +/** + * Helper object that creates instance of [[spark.streaming.Time]] representing + * a given number of minutes. + */ +object Minutes { def apply(minutes: Long) = Time(minutes * 60000) } diff --git a/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala index 2e427dadf7..bc23d423d3 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala @@ -4,6 +4,7 @@ import spark.{RDD, Partitioner} import spark.rdd.CoGroupedRDD import spark.streaming.{Time, DStream} +private[streaming] class CoGroupedDStream[K : ClassManifest]( parents: Seq[DStream[(_, _)]], partitioner: Partitioner diff --git a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala index 8cdaff467b..cf72095324 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala @@ -10,7 +10,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import scala.collection.mutable.HashSet - +private[streaming] class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( @transient ssc_ : StreamingContext, directory: String, diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala index 7e988cadf4..ff73225e0f 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala @@ -17,6 +17,7 @@ import java.net.InetSocketAddress import java.io.{ObjectInput, ObjectOutput, Externalizable} import java.nio.ByteBuffer +private[streaming] class FlumeInputDStream[T: ClassManifest]( @transient ssc_ : StreamingContext, host: String, @@ -93,6 +94,7 @@ private[streaming] object SparkFlumeEvent { } /** A simple server that implements Flume's Avro protocol. */ +private[streaming] class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { override def append(event : AvroFlumeEvent) : Status = { receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event) @@ -108,12 +110,13 @@ class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { /** A NetworkReceiver which listens for events using the * Flume Avro interface.*/ +private[streaming] class FlumeReceiver( - streamId: Int, - host: String, - port: Int, - storageLevel: StorageLevel - ) extends NetworkReceiver[SparkFlumeEvent](streamId) { + streamId: Int, + host: String, + port: Int, + storageLevel: StorageLevel + ) extends NetworkReceiver[SparkFlumeEvent](streamId) { lazy val dataHandler = new DataHandler(this, storageLevel) diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index a46721af2f..175c75bcb9 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -21,10 +21,12 @@ import scala.collection.JavaConversions._ case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int) // NOT USED - Originally intended for fault-tolerance // Metadata for a Kafka Stream that it sent to the Master +private[streaming] case class KafkaInputDStreamMetadata(timestamp: Long, data: Map[KafkaPartitionKey, Long]) // NOT USED - Originally intended for fault-tolerance // Checkpoint data specific to a KafkaInputDstream -case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any], +private[streaming] +case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any], savedOffsets: Map[KafkaPartitionKey, Long]) extends DStreamCheckpointData(kafkaRdds) /** @@ -39,6 +41,7 @@ case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any], * By default the value is pulled from zookeper. * @param storageLevel RDD storage level. */ +private[streaming] class KafkaInputDStream[T: ClassManifest]( @transient ssc_ : StreamingContext, host: String, @@ -98,6 +101,7 @@ class KafkaInputDStream[T: ClassManifest]( } } +private[streaming] class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], storageLevel: StorageLevel) extends NetworkReceiver[Any](streamId) { diff --git a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala index 996cc7dea8..aa2f31cea8 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala @@ -17,6 +17,7 @@ import java.util.concurrent.ArrayBlockingQueue * data into Spark Streaming, though it requires the sender to batch data and serialize it * in the format that the system is configured with. */ +private[streaming] class RawInputDStream[T: ClassManifest]( @transient ssc_ : StreamingContext, host: String, @@ -29,6 +30,7 @@ class RawInputDStream[T: ClassManifest]( } } +private[streaming] class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: StorageLevel) extends NetworkReceiver[Any](streamId) { diff --git a/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala index 2686de14d2..d289ed2a3f 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -11,6 +11,7 @@ import spark.storage.StorageLevel import scala.collection.mutable.ArrayBuffer import spark.streaming.{Interval, Time, DStream} +private[streaming] class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( parent: DStream[(K, V)], reduceFunc: (V, V) => V, diff --git a/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala index af5b73ae8d..cbe4372299 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala @@ -6,6 +6,7 @@ import spark.storage.StorageLevel import java.io._ import java.net.Socket +private[streaming] class SocketInputDStream[T: ClassManifest]( @transient ssc_ : StreamingContext, host: String, @@ -19,7 +20,7 @@ class SocketInputDStream[T: ClassManifest]( } } - +private[streaming] class SocketReceiver[T: ClassManifest]( streamId: Int, host: String, @@ -50,7 +51,7 @@ class SocketReceiver[T: ClassManifest]( } - +private[streaming] object SocketReceiver { /** diff --git a/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala index 6e190b5564..175b3060c1 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala @@ -6,6 +6,7 @@ import spark.SparkContext._ import spark.storage.StorageLevel import spark.streaming.{Time, DStream} +private[streaming] class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( parent: DStream[(K, V)], updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], diff --git a/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala index f1efb2ae72..3bf4c2ecea 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala @@ -5,6 +5,7 @@ import spark.RDD import collection.mutable.ArrayBuffer import spark.rdd.UnionRDD +private[streaming] class UnionDStream[T: ClassManifest](parents: Array[DStream[T]]) extends DStream[T](parents.head.ssc) { diff --git a/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala index 4b2621c497..7718794cbf 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala @@ -5,7 +5,7 @@ import spark.rdd.UnionRDD import spark.storage.StorageLevel import spark.streaming.{Interval, Time, DStream} - +private[streaming] class WindowedDStream[T: ClassManifest]( parent: DStream[T], _windowTime: Time, diff --git a/streaming/src/main/scala/spark/streaming/util/Clock.scala b/streaming/src/main/scala/spark/streaming/util/Clock.scala index ed087e4ea8..974651f9f6 100644 --- a/streaming/src/main/scala/spark/streaming/util/Clock.scala +++ b/streaming/src/main/scala/spark/streaming/util/Clock.scala @@ -1,13 +1,12 @@ package spark.streaming.util -import spark.streaming._ - -trait Clock { +private[streaming] +trait Clock { def currentTime(): Long def waitTillTime(targetTime: Long): Long } - +private[streaming] class SystemClock() extends Clock { val minPollTime = 25L @@ -54,6 +53,7 @@ class SystemClock() extends Clock { } } +private[streaming] class ManualClock() extends Clock { var time = 0L diff --git a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala index dc55fd902b..2e7f4169c9 100644 --- a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala @@ -1,5 +1,6 @@ package spark.streaming.util +private[streaming] class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => Unit) { val minPollTime = 25L -- cgit v1.2.3 From 6f6a6b79c4c3f3555f8ff427c91e714d02afe8fa Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 31 Dec 2012 14:56:23 -0800 Subject: Launch with `scala` by default in run-pyspark --- pyspark/run-pyspark | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pyspark/run-pyspark b/pyspark/run-pyspark index f8039b8038..4d10fbea8b 100755 --- a/pyspark/run-pyspark +++ b/pyspark/run-pyspark @@ -20,4 +20,9 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH=$SPARK_HOME/pyspark/:$PYTHONPATH +# Launch with `scala` by default: +if [[ "$SPARK_LAUNCH_WITH_SCALA" != "0" ]] ; then + export SPARK_LAUNCH_WITH_SCALA=1 +fi + exec "$PYSPARK_PYTHON" "$@" -- cgit v1.2.3 From f803953998d6b931b266c69acab97b3ece628713 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 30 Dec 2012 12:43:06 -0800 Subject: Raise exception when hashing Java arrays (SPARK-597) --- core/src/main/scala/spark/PairRDDFunctions.scala | 27 +++++++++++++++++++++++ core/src/main/scala/spark/Partitioner.scala | 4 ++++ core/src/main/scala/spark/RDD.scala | 6 +++++ core/src/test/scala/spark/PartitioningSuite.scala | 21 ++++++++++++++++++ 4 files changed, 58 insertions(+) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index d3e206b353..413c944a66 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -52,6 +52,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( mergeCombiners: (C, C) => C, partitioner: Partitioner, mapSideCombine: Boolean = true): RDD[(K, C)] = { + if (getKeyClass().isArray) { + if (mapSideCombine) { + throw new SparkException("Cannot use map-side combining with array keys.") + } + if (partitioner.isInstanceOf[HashPartitioner]) { + throw new SparkException("Default partitioner cannot partition array keys.") + } + } val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) if (mapSideCombine) { @@ -92,6 +100,11 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * before sending results to a reducer, similarly to a "combiner" in MapReduce. */ def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = { + + if (getKeyClass().isArray) { + throw new SparkException("reduceByKeyLocally() does not support array keys") + } + def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = { val map = new JHashMap[K, V] for ((k, v) <- iter) { @@ -165,6 +178,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * be set to true. */ def partitionBy(partitioner: Partitioner, mapSideCombine: Boolean = false): RDD[(K, V)] = { + if (getKeyClass().isArray) { + if (mapSideCombine) { + throw new SparkException("Cannot use map-side combining with array keys.") + } + if (partitioner.isInstanceOf[HashPartitioner]) { + throw new SparkException("Default partitioner cannot partition array keys.") + } + } if (mapSideCombine) { def createCombiner(v: V) = ArrayBuffer(v) def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v @@ -336,6 +357,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * list of values for that key in `this` as well as `other`. */ def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (Seq[V], Seq[W]))] = { + if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) { + throw new SparkException("Default partitioner cannot partition array keys.") + } val cg = new CoGroupedRDD[K]( Seq(self.asInstanceOf[RDD[(_, _)]], other.asInstanceOf[RDD[(_, _)]]), partitioner) @@ -352,6 +376,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( */ def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner) : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { + if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) { + throw new SparkException("Default partitioner cannot partition array keys.") + } val cg = new CoGroupedRDD[K]( Seq(self.asInstanceOf[RDD[(_, _)]], other1.asInstanceOf[RDD[(_, _)]], diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala index b71021a082..9d5b966e1e 100644 --- a/core/src/main/scala/spark/Partitioner.scala +++ b/core/src/main/scala/spark/Partitioner.scala @@ -11,6 +11,10 @@ abstract class Partitioner extends Serializable { /** * A [[spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`. + * + * Java arrays have hashCodes that are based on the arrays' identities rather than their contents, + * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will + * produce an unexpected or incorrect result. */ class HashPartitioner(partitions: Int) extends Partitioner { def numPartitions = partitions diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index d15c6f7396..7e38583391 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -417,6 +417,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial * combine step happens locally on the master, equivalent to running a single reduce task. */ def countByValue(): Map[T, Long] = { + if (elementClassManifest.erasure.isArray) { + throw new SparkException("countByValue() does not support arrays") + } // TODO: This should perhaps be distributed by default. def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = { val map = new OLMap[T] @@ -445,6 +448,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial timeout: Long, confidence: Double = 0.95 ): PartialResult[Map[T, BoundedDouble]] = { + if (elementClassManifest.erasure.isArray) { + throw new SparkException("countByValueApprox() does not support arrays") + } val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) => val map = new OLMap[T] while (iter.hasNext) { diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala index 3dadc7acec..f09b602a7b 100644 --- a/core/src/test/scala/spark/PartitioningSuite.scala +++ b/core/src/test/scala/spark/PartitioningSuite.scala @@ -107,4 +107,25 @@ class PartitioningSuite extends FunSuite with BeforeAndAfter { assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner) assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner) } + + test("partitioning Java arrays should fail") { + sc = new SparkContext("local", "test") + val arrs: RDD[Array[Int]] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => Array(x)) + val arrPairs: RDD[(Array[Int], Int)] = + sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x)) + + assert(intercept[SparkException]{ arrs.distinct() }.getMessage.contains("array")) + // We can't catch all usages of arrays, since they might occur inside other collections: + //assert(fails { arrPairs.distinct() }) + assert(intercept[SparkException]{ arrPairs.partitionBy(new HashPartitioner(2)) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.join(arrPairs) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.leftOuterJoin(arrPairs) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.rightOuterJoin(arrPairs) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.groupByKey() }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.countByKey() }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.countByKeyApprox(1) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.cogroup(arrPairs) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array")) + } } -- cgit v1.2.3 From feadaf72f44e7c66521c03171592671d4c441bda Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Mon, 31 Dec 2012 14:05:11 -0800 Subject: Mark key as not loading in CacheTracker even when compute() fails --- core/src/main/scala/spark/CacheTracker.scala | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index 3d79078733..c8c4063cad 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -202,26 +202,26 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b loading.add(key) } } - // If we got here, we have to load the split - // Tell the master that we're doing so - //val host = System.getProperty("spark.hostname", Utils.localHostName) - //val future = trackerActor !! AddedToCache(rdd.id, split.index, host) - // TODO: fetch any remote copy of the split that may be available - // TODO: also register a listener for when it unloads - logInfo("Computing partition " + split) - val elements = new ArrayBuffer[Any] - elements ++= rdd.compute(split, context) try { + // If we got here, we have to load the split + // Tell the master that we're doing so + //val host = System.getProperty("spark.hostname", Utils.localHostName) + //val future = trackerActor !! AddedToCache(rdd.id, split.index, host) + // TODO: fetch any remote copy of the split that may be available + // TODO: also register a listener for when it unloads + val elements = new ArrayBuffer[Any] + logInfo("Computing partition " + split) + elements ++= rdd.compute(split, context) // Try to put this block in the blockManager blockManager.put(key, elements, storageLevel, true) //future.apply() // Wait for the reply from the cache tracker + return elements.iterator.asInstanceOf[Iterator[T]] } finally { loading.synchronized { loading.remove(key) loading.notifyAll() } } - return elements.iterator.asInstanceOf[Iterator[T]] } } -- cgit v1.2.3 From 21636ee4faf30126b36ad568753788327e634857 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 1 Jan 2013 07:52:31 -0800 Subject: Test with exception while computing cached RDD. --- core/src/test/scala/spark/RDDSuite.scala | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 08da9a1c4d..45e6c5f840 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -88,6 +88,29 @@ class RDDSuite extends FunSuite with BeforeAndAfter { assert(rdd.collect().toList === List(1, 2, 3, 4)) } + test("caching with failures") { + sc = new SparkContext("local", "test") + val onlySplit = new Split { override def index: Int = 0 } + var shouldFail = true + val rdd = new RDD[Int](sc) { + override def splits: Array[Split] = Array(onlySplit) + override val dependencies = List[Dependency[_]]() + override def compute(split: Split, context: TaskContext): Iterator[Int] = { + if (shouldFail) { + throw new Exception("injected failure") + } else { + return Array(1, 2, 3, 4).iterator + } + } + }.cache() + val thrown = intercept[Exception]{ + rdd.collect() + } + assert(thrown.getMessage.contains("injected failure")) + shouldFail = false + assert(rdd.collect().toList === List(1, 2, 3, 4)) + } + test("coalesced RDDs") { sc = new SparkContext("local", "test") val data = sc.parallelize(1 to 10, 10) -- cgit v1.2.3 From 58072a7340e20251ed810457bc67a79f106bae42 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 1 Jan 2013 07:59:16 -0800 Subject: Remove some dead comments --- core/src/main/scala/spark/CacheTracker.scala | 6 ------ 1 file changed, 6 deletions(-) diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index c8c4063cad..04c26b2e40 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -204,17 +204,11 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b } try { // If we got here, we have to load the split - // Tell the master that we're doing so - //val host = System.getProperty("spark.hostname", Utils.localHostName) - //val future = trackerActor !! AddedToCache(rdd.id, split.index, host) - // TODO: fetch any remote copy of the split that may be available - // TODO: also register a listener for when it unloads val elements = new ArrayBuffer[Any] logInfo("Computing partition " + split) elements ++= rdd.compute(split, context) // Try to put this block in the blockManager blockManager.put(key, elements, storageLevel, true) - //future.apply() // Wait for the reply from the cache tracker return elements.iterator.asInstanceOf[Iterator[T]] } finally { loading.synchronized { -- cgit v1.2.3 From 02497f0cd49a24ebc8b92d3471de250319fe56cd Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 1 Jan 2013 12:21:32 -0800 Subject: Updated Streaming Programming Guide. --- docs/_layouts/global.html | 12 +- docs/api.md | 5 +- docs/configuration.md | 11 ++ docs/streaming-programming-guide.md | 167 +++++++++++++++++++-- .../spark/streaming/util/RecurringTimer.scala | 1 + 5 files changed, 179 insertions(+), 17 deletions(-) diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index d656b3e3de..a8be52f23e 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -47,11 +47,19 @@
  • Quick Start
  • Scala
  • Java
  • -
  • Spark Streaming (Alpha)
  • +
  • Spark Streaming
  • -
  • API (Scaladoc)
  • +
    + + + + +
    print() Prints the contents of this DStream on the driver. At each interval, this will take at most ten elements from the DStream's RDD and print them. Prints first ten elements of every batch of data in a DStream on the driver.
    saveAsObjectFile(prefix, [suffix]) saveAsObjectFiles(prefix, [suffix]) Save this DStream's contents as a SequenceFile of serialized objects. The file name at each batch interval is calculated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
    saveAsTextFile(prefix, suffix) saveAsTextFiles(prefix, [suffix]) Save this DStream's contents as a text files. The file name at each batch interval is calculated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
    saveAsHadoopFiles(prefix, suffix) saveAsHadoopFiles(prefix, [suffix]) Save this DStream's contents as a Hadoop file. The file name at each batch interval is calculated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
    spark.cleaner.delay(disable) + Duration (minutes) of how long Spark will remember any metadata (stages generated, tasks generated, etc.). + Periodic cleanups will ensure that metadata older than this duration will be forgetten. This is + useful for running Spark for many hours / days (for example, running 24/7 in case of Spark Streaming + applications). Note that any RDD that persists in memory for more than this duration will be cleared as well. +
    # Configuring Logging diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 7c421ac70f..fc2ea2ef79 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1,8 +1,9 @@ --- layout: global -title: Streaming (Alpha) Programming Guide +title: Spark Streaming Programming Guide --- +* This will become a table of contents (this text will be scraped). {:toc} # Overview @@ -13,33 +14,38 @@ A Spark Streaming application is very similar to a Spark application; it consist This guide shows some how to start programming with DStreams. # Initializing Spark Streaming -The first thing a Spark Streaming program must do is create a `StreamingContext` object, which tells Spark how to access a cluster. A `StreamingContext` can be created from an existing `SparkContext`, or directly: +The first thing a Spark Streaming program must do is create a `StreamingContext` object, which tells Spark how to access a cluster. A `StreamingContext` can be created by using {% highlight scala %} -import spark.SparkContext -import SparkContext._ +new StreamingContext(master, jobName, batchDuration) +{% endhighlight %} + +The `master` parameter is either the [Mesos master URL](running-on-mesos.html) (for running on a cluster)or the special "local" string (for local mode) that is used to create a Spark Context. For more information about this please refer to the [Spark programming guide](scala-programming-guide.html). The `jobName` is the name of the streaming job, which is the same as the jobName used in SparkContext. It is used to identify this job in the Mesos web UI. The `batchDuration` is the size of the batches (as explained earlier). This must be set carefully such the cluster can keep up with the processing of the data streams. Starting with something conservative like 5 seconds maybe a good start. See [Performance Tuning](#setting-the-right-batch-size) section for a detailed discussion. -new StreamingContext(master, frameworkName, batchDuration) +This constructor creates a SparkContext object using the given `master` and `jobName` parameters. However, if you already have a SparkContext or you need to create a custom SparkContext by specifying list of JARs, then a StreamingContext can be created from the existing SparkContext, by using +{% highlight scala %} new StreamingContext(sparkContext, batchDuration) {% endhighlight %} -The `master` parameter is either the [Mesos master URL](running-on-mesos.html) (for running on a cluster)or the special "local" string (for local mode) that is used to create a Spark Context. For more information about this please refer to the [Spark programming guide](scala-programming-guide.html). -# Creating Input Sources - InputDStreams +# Attaching Input Sources - InputDStreams The StreamingContext is used to creating InputDStreams from input sources: {% highlight scala %} -context.neworkStream(host, port) // Creates a stream that uses a TCP socket to read data from : -context.flumeStream(host, port) // Creates a stream populated by a Flume flow +// Assuming ssc is the StreamingContext +ssc.networkStream(hostname, port) // Creates a stream that uses a TCP socket to read data from hostname:port +ssc.textFileStream(directory) // Creates a stream by monitoring and processing new files in a HDFS directory {% endhighlight %} -A complete list of input sources is available in the [DStream API doc](api/streaming/index.html#spark.streaming.StreamingContext). +A complete list of input sources is available in the [StreamingContext API documentation](api/streaming/index.html#spark.streaming.StreamingContext). Data received from these sources can be processed using DStream operations, which are explained next. + -## DStream Operations + +# DStream Operations Once an input stream has been created, you can transform it using _stream operators_. Most of these operators return new DStreams which you can further transform. Eventually, you'll need to call an _output operator_, which forces evaluation of the stream by writing data out to an external source. -### Transformations +## Transformations DStreams support many of the transformations available on normal Spark RDD's: @@ -132,7 +138,7 @@ Spark Streaming features windowed computations, which allow you to report statis -### Output Operators +## Output Operations When an output operator is called, it triggers the computation of a stream. Currently the following output operators are defined: @@ -165,3 +171,138 @@ When an output operator is called, it triggers the computation of a stream. Curr
    +## DStream Persistence +Similar to RDDs, DStreams also allow developers to persist the stream's data in memory. That is, using `persist()` method on a DStream would automatically persist every RDD of that DStream in memory. This is useful if the data in the DStream will be computed multiple times (e.g., multiple DStream operations on the same data). For window-based operations like `reduceByWindow` and `reduceByKeyAndWindow` and state-based operations like `updateStateByKey`, this is implicitly true. Hence, DStreams generated by window-based operations are automatically persisted in memory, without the developer calling `persist()`. + +Note that, unlike RDDs, the default persistence level of DStreams keeps the data serialized in memory. This is further discussed in the [Performance Tuning](#memory-tuning) section. More information on different persistence levels can be found in [Spark Programming Guide](scala-programming-guide.html#rdd-persistence). + +# Starting the Streaming computation +All the above DStream operations are completely lazy, that is, the operations will start executing only after the context is started by using +{% highlight scala %} +ssc.start() +{% endhighlight %} + +Conversely, the computation can be stopped by using +{% highlight scala %} +ssc.stop() +{% endhighlight %} + +# Example - WordCountNetwork.scala +A good example to start off is the spark.streaming.examples.WordCountNetwork. This example counts the words received from a network server every second. Given below is the relevant sections of the source code. You can find the full source code in /streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala. + +{% highlight scala %} +import spark.streaming.{Seconds, StreamingContext} +import spark.streaming.StreamingContext._ +... + +// Create the context and set up a network input stream to receive from a host:port +val ssc = new StreamingContext(args(0), "WordCountNetwork", Seconds(1)) +val lines = ssc.networkTextStream(args(1), args(2).toInt) + +// Split the lines into words, count them, and print some of the counts on the master +val words = lines.flatMap(_.split(" ")) +val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) +wordCounts.print() + +// Start the computation +ssc.start() +{% endhighlight %} + +To run this example on your local machine, you need to first run a Netcat server by using + +{% highlight bash %} +$ nc -lk 9999 +{% endhighlight %} + +Then, in a different terminal, you can start WordCountNetwork by using + +{% highlight bash %} +$ ./run spark.streaming.examples.WordCountNetwork local[2] localhost 9999 +{% endhighlight %} + +This will make WordCountNetwork connect to the netcat server. Any lines typed in the terminal running the netcat server will be counted and printed on screen. + + + + +
    +{% highlight bash %} +# TERMINAL 1 +# RUNNING NETCAT + +$ nc -lk 9999 +hello world + + + + + +... +{% endhighlight %} + +{% highlight bash %} +# TERMINAL 2: RUNNING WordCountNetwork +... +2012-12-31 18:47:10,446 INFO SparkContext: Job finished: run at ThreadPoolExecutor.java:886, took 0.038817 s +------------------------------------------- +Time: 1357008430000 ms +------------------------------------------- +(hello,1) +(world,1) + +2012-12-31 18:47:10,447 INFO JobManager: Total delay: 0.44700 s for job 8 (execution: 0.44000 s) +... +{% endhighlight %} +
    + + + +# Performance Tuning +Getting the best performance of a Spark Streaming application on a cluster requires a bit of tuning. This section explains a number of the parameters and configurations that can tuned to improve the performance of you application. At a high level, you need to consider two things: +
      +
    1. Reducing the processing time of each batch of data by efficiently using cluster resources.
    2. +
    3. Setting the right batch size such that the data processing can keep up with the data ingestion.
    4. +
    + +## Reducing the Processing Time of each Batch +There are a number of optimizations that can be done in Spark to minimize the processing time of each batch. These have been discussed in detail in [Tuning Guide](tuning.html). This section highlights some of the most important ones. + +### Level of Parallelism +Cluster resources maybe underutilized if the number of parallel tasks used in any stage of the computation is not high enough. For example, for distributed reduce operations like `reduceByKey` and `reduceByKeyAndWindow`, the default number of parallel tasks is 8. You can pass the level of parallelism as an argument (see the [`spark.PairDStreamFunctions`](api/streaming/index.html#spark.PairDStreamFunctions) documentation), or set the system property `spark.default.parallelism` to change the default. + +### Data Serialization +The overhead of data serialization can be significant, especially when sub-second batch sizes are to be achieved. There are two aspects to it. +* Serialization of RDD data in Spark: Please refer to the detailed discussion on data serialization in the [Tuning Guide](tuning.html). However, note that unlike Spark, by default RDDs are persisted as serialized byte arrays to minimize pauses related to GC. +* Serialization of input data: To ingest external data into Spark, data received as bytes (say, from the network) needs to deserialized from bytes and re-serialized into Spark's serialization format. Hence, the deserialization overhead of input data may be a bottleneck. + +### Task Launching Overheads +If the number of tasks launched per second is high (say, 50 or more per second), then the overhead of sending out tasks to the slaves maybe significant and will make it hard to achieve sub-second latencies. The overhead can be reduced by the following changes: +* Task Serialization: Using Kryo serialization for serializing tasks can reduced the task sizes, and therefore reduce the time taken to send them to the slaves. +* Execution mode: Running Spark in Standalone mode or coarse-grained Mesos mode leads to better task launch times than the fine-grained Mesos mode. Please refer to the [Running on Mesos guide](running-on-mesos.html) for more details. +These changes may reduce batch processing time by 100s of milliseconds, thus allowing sub-second batch size to be viable. + +## Setting the Right Batch Size +For a Spark Streaming application running on a cluster to be stable, the processing of the data streams must keep up with the rate of ingestion of the data streams. Depending on the type of computation, the batch size used may have significant impact on the rate of ingestion that can be sustained by the Spark Streaming application on a fixed cluster resources. For example, let us consider the earlier WordCountNetwork example. For a particular data rate, the system may be able to keep up with reporting word counts every 2 seconds (i.e., batch size of 2 seconds), but not every 500 milliseconds. + +A good approach to figure out the right batch size for your application is to test it with a conservative batch size (say, 5-10 seconds) and a low data rate. To verify whether the system is able to keep up with data rate, you can check the value of the end-to-end delay experienced by each processed batch (in the Spark master logs, find the line having the phrase "Total delay"). If the delay is maintained to be less than the batch size, then system is stable. Otherwise, if the delay is continuously increasing, it means that the system is unable to keep up and it therefore unstable. Once you have an idea of a stable configuration, you can try increasing the data rate and/or reducing the batch size. Note that momentary increase in the delay due to temporary data rate increases maybe fine as long as the delay reduces back to a low value (i.e., less than batch size). + +## 24/7 Operation +By default, Spark does not forget any of the metadata (RDDs generated, stages processed, etc.). But for a Spark Streaming application to operate 24/7, it is necessary for Spark to do periodic cleanup of it metadata. This can be enabled by setting the Java system property `spark.cleaner.delay` to the number of minutes you want any metadata to persist. For example, setting `spark.cleaner.delay` to 10 would cause Spark periodically cleanup all metadata and persisted RDDs that are older than 10 minutes. Note, that this property needs to be set before the SparkContext is created. + +This value is closely tied with any window operation that is being used. Any window operation would require the input data to be persisted in memory for at least the duration of the window. Hence it is necessary to set the delay to at least the value of the largest window operation used in the Spark Streaming application. If this delay is set too low, the application will throw an exception saying so. + +## Memory Tuning +Tuning the memory usage and GC behavior of Spark applications have been discussed in great detail in the [Tuning Guide](tuning.html). It is recommended that you read that. In this section, we highlight a few customizations that are strongly recommended to minimize GC related pauses in Spark Streaming applications and achieving more consistent batch processing times. + +* Default persistence level of DStreams: Unlike RDDs, the default persistence level of DStreams serializes the data in memory (that is, [StorageLevel.MEMORY_ONLY_SER](api/core/index.html#spark.storage.StorageLevel$) for DStream compared to [StorageLevel.MEMORY_ONLY](api/core/index.html#spark.storage.StorageLevel$) for RDDs). Even though keeping the data serialized incurs a higher serialization overheads, it significantly reduces GC pauses. + +* Concurrent garbage collector: Using the concurrent mark-and-sweep GC further minimizes the variability of GC pauses. Even though concurrent GC is known to reduce the overall processing throughput of the system, its use is still recommended to achieve more consistent batch processing times. + +# Master Fault-tolerance (Alpha) +TODO + +* Checkpointing of DStream graph + +* Recovery from master faults + +* Current state and future directions \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala index 2e7f4169c9..db715cc295 100644 --- a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala @@ -54,6 +54,7 @@ class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => } } +private[streaming] object RecurringTimer { def main(args: Array[String]) { -- cgit v1.2.3 From 170e451fbdd308ae77065bd9c0f2bd278abf0cb7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 1 Jan 2013 13:52:14 -0800 Subject: Minor documentation and style fixes for PySpark. --- .../scala/spark/api/python/PythonPartitioner.scala | 4 +- .../main/scala/spark/api/python/PythonRDD.scala | 43 +++++++++++----- docs/index.md | 8 ++- docs/python-programming-guide.md | 3 +- pyspark/examples/kmeans.py | 13 +++-- pyspark/examples/logistic_regression.py | 57 ++++++++++++++++++++++ pyspark/examples/lr.py | 57 ---------------------- pyspark/examples/pi.py | 5 +- pyspark/examples/tc.py | 49 ------------------- pyspark/examples/transitive_closure.py | 50 +++++++++++++++++++ pyspark/examples/wordcount.py | 4 +- pyspark/pyspark/__init__.py | 13 ++++- 12 files changed, 172 insertions(+), 134 deletions(-) create mode 100755 pyspark/examples/logistic_regression.py delete mode 100755 pyspark/examples/lr.py delete mode 100644 pyspark/examples/tc.py create mode 100644 pyspark/examples/transitive_closure.py diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala index 2c829508e5..648d9402b0 100644 --- a/core/src/main/scala/spark/api/python/PythonPartitioner.scala +++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala @@ -17,9 +17,9 @@ private[spark] class PythonPartitioner(override val numPartitions: Int) extends val hashCode = { if (key.isInstanceOf[Array[Byte]]) { Arrays.hashCode(key.asInstanceOf[Array[Byte]]) - } - else + } else { key.hashCode() + } } val mod = hashCode % numPartitions if (mod < 0) { diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index dc48378fdc..19a039e330 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -13,8 +13,12 @@ import spark.rdd.PipedRDD private[spark] class PythonRDD[T: ClassManifest]( - parent: RDD[T], command: Seq[String], envVars: java.util.Map[String, String], - preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) + parent: RDD[T], + command: Seq[String], + envVars: java.util.Map[String, String], + preservePartitoning: Boolean, + pythonExec: String, + broadcastVars: java.util.List[Broadcast[Array[Byte]]]) extends RDD[Array[Byte]](parent.context) { // Similar to Runtime.exec(), if we are given a single string, split it into words @@ -38,8 +42,8 @@ private[spark] class PythonRDD[T: ClassManifest]( // Add the environmental variables to the process. val currentEnvVars = pb.environment() - envVars.foreach { - case (variable, value) => currentEnvVars.put(variable, value) + for ((variable, value) <- envVars) { + currentEnvVars.put(variable, value) } val proc = pb.start() @@ -116,6 +120,10 @@ private[spark] class PythonRDD[T: ClassManifest]( val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } +/** + * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. + * This is used by PySpark's shuffle operations. + */ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Array[Byte], Array[Byte])](prev.context) { override def splits = prev.splits @@ -139,6 +147,16 @@ private[spark] object PythonRDD { * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream. * The data format is a 32-bit integer representing the pickled object's length (in bytes), * followed by the pickled data. + * + * Pickle module: + * + * http://docs.python.org/2/library/pickle.html + * + * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules: + * + * http://hg.python.org/cpython/file/2.6/Lib/pickle.py + * http://hg.python.org/cpython/file/2.6/Lib/pickletools.py + * * @param elem the object to write * @param dOut a data output stream */ @@ -201,15 +219,14 @@ private[spark] object PythonRDD { } private object Pickle { - def b(x: Int): Byte = x.asInstanceOf[Byte] - val PROTO: Byte = b(0x80) - val TWO: Byte = b(0x02) - val BINUNICODE : Byte = 'X' - val STOP : Byte = '.' - val TUPLE2 : Byte = b(0x86) - val EMPTY_LIST : Byte = ']' - val MARK : Byte = '(' - val APPENDS : Byte = 'e' + val PROTO: Byte = 0x80.toByte + val TWO: Byte = 0x02.toByte + val BINUNICODE: Byte = 'X' + val STOP: Byte = '.' + val TUPLE2: Byte = 0x86.toByte + val EMPTY_LIST: Byte = ']' + val MARK: Byte = '(' + val APPENDS: Byte = 'e' } private class ExtractValue extends spark.api.java.function.Function[(Array[Byte], diff --git a/docs/index.md b/docs/index.md index 33ab58a962..848b585333 100644 --- a/docs/index.md +++ b/docs/index.md @@ -8,7 +8,7 @@ TODO(andyk): Rewrite to make the Java API a first class part of the story. {% endcomment %} Spark is a MapReduce-like cluster computing framework designed for low-latency iterative jobs and interactive use from an interpreter. -It provides clean, language-integrated APIs in Scala, Java, and Python, with a rich array of parallel operators. +It provides clean, language-integrated APIs in [Scala](scala-programming-guide.html), [Java](java-programming-guide.html), and [Python](python-programming-guide.html), with a rich array of parallel operators. Spark can run on top of the [Apache Mesos](http://incubator.apache.org/mesos/) cluster manager, [Hadoop YARN](http://hadoop.apache.org/docs/r2.0.1-alpha/hadoop-yarn/hadoop-yarn-site/YARN.html), Amazon EC2, or without an independent resource manager ("standalone mode"). @@ -61,6 +61,11 @@ of `project/SparkBuild.scala`, then rebuilding Spark (`sbt/sbt clean compile`). * [Java Programming Guide](java-programming-guide.html): using Spark from Java * [Python Programming Guide](python-programming-guide.html): using Spark from Python +**API Docs:** + +* [Java/Scala (Scaladoc)](api/core/index.html) +* [Python (Epydoc)](api/pyspark/index.html) + **Deployment guides:** * [Running Spark on Amazon EC2](ec2-scripts.html): scripts that let you launch a cluster on EC2 in about 5 minutes @@ -73,7 +78,6 @@ of `project/SparkBuild.scala`, then rebuilding Spark (`sbt/sbt clean compile`). * [Configuration](configuration.html): customize Spark via its configuration system * [Tuning Guide](tuning.html): best practices to optimize performance and memory use -* API Docs: [Java/Scala (Scaladoc)](api/core/index.html) and [Python (Epydoc)](api/pyspark/index.html) * [Bagel](bagel-programming-guide.html): an implementation of Google's Pregel on Spark * [Contributing to Spark](contributing-to-spark.html) diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index b7c747f905..d88d4eb42d 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -17,8 +17,7 @@ There are a few key differences between the Python and Scala APIs: * Python is dynamically typed, so RDDs can hold objects of different types. * PySpark does not currently support the following Spark features: - Accumulators - - Special functions on RRDs of doubles, such as `mean` and `stdev` - - Approximate jobs / functions, such as `countApprox` and `sumApprox`. + - Special functions on RDDs of doubles, such as `mean` and `stdev` - `lookup` - `mapPartitionsWithSplit` - `persist` at storage levels other than `MEMORY_ONLY` diff --git a/pyspark/examples/kmeans.py b/pyspark/examples/kmeans.py index 9cc366f03c..ad2be21178 100644 --- a/pyspark/examples/kmeans.py +++ b/pyspark/examples/kmeans.py @@ -1,18 +1,21 @@ +""" +This example requires numpy (http://www.numpy.org/) +""" import sys -from pyspark.context import SparkContext -from numpy import array, sum as np_sum +import numpy as np +from pyspark import SparkContext def parseVector(line): - return array([float(x) for x in line.split(' ')]) + return np.array([float(x) for x in line.split(' ')]) def closestPoint(p, centers): bestIndex = 0 closest = float("+inf") for i in range(len(centers)): - tempDist = np_sum((p - centers[i]) ** 2) + tempDist = np.sum((p - centers[i]) ** 2) if tempDist < closest: closest = tempDist bestIndex = i @@ -41,7 +44,7 @@ if __name__ == "__main__": newPoints = pointStats.map( lambda (x, (y, z)): (x, y / z)).collect() - tempDist = sum(np_sum((kPoints[x] - y) ** 2) for (x, y) in newPoints) + tempDist = sum(np.sum((kPoints[x] - y) ** 2) for (x, y) in newPoints) for (x, y) in newPoints: kPoints[x] = y diff --git a/pyspark/examples/logistic_regression.py b/pyspark/examples/logistic_regression.py new file mode 100755 index 0000000000..f13698a86f --- /dev/null +++ b/pyspark/examples/logistic_regression.py @@ -0,0 +1,57 @@ +""" +This example requires numpy (http://www.numpy.org/) +""" +from collections import namedtuple +from math import exp +from os.path import realpath +import sys + +import numpy as np +from pyspark import SparkContext + + +N = 100000 # Number of data points +D = 10 # Number of dimensions +R = 0.7 # Scaling factor +ITERATIONS = 5 +np.random.seed(42) + + +DataPoint = namedtuple("DataPoint", ['x', 'y']) +from lr import DataPoint # So that DataPoint is properly serialized + + +def generateData(): + def generatePoint(i): + y = -1 if i % 2 == 0 else 1 + x = np.random.normal(size=D) + (y * R) + return DataPoint(x, y) + return [generatePoint(i) for i in range(N)] + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonLR []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonLR", pyFiles=[realpath(__file__)]) + slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 + points = sc.parallelize(generateData(), slices).cache() + + # Initialize w to a random value + w = 2 * np.random.ranf(size=D) - 1 + print "Initial w: " + str(w) + + def add(x, y): + x += y + return x + + for i in range(1, ITERATIONS + 1): + print "On iteration %i" % i + + gradient = points.map(lambda p: + (1.0 / (1.0 + exp(-p.y * np.dot(w, p.x)))) * p.y * p.x + ).reduce(add) + w -= gradient + + print "Final w: " + str(w) diff --git a/pyspark/examples/lr.py b/pyspark/examples/lr.py deleted file mode 100755 index 5fca0266b8..0000000000 --- a/pyspark/examples/lr.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -This example requires numpy (http://www.numpy.org/) -""" -from collections import namedtuple -from math import exp -from os.path import realpath -import sys - -import numpy as np -from pyspark.context import SparkContext - - -N = 100000 # Number of data points -D = 10 # Number of dimensions -R = 0.7 # Scaling factor -ITERATIONS = 5 -np.random.seed(42) - - -DataPoint = namedtuple("DataPoint", ['x', 'y']) -from lr import DataPoint # So that DataPoint is properly serialized - - -def generateData(): - def generatePoint(i): - y = -1 if i % 2 == 0 else 1 - x = np.random.normal(size=D) + (y * R) - return DataPoint(x, y) - return [generatePoint(i) for i in range(N)] - - -if __name__ == "__main__": - if len(sys.argv) == 1: - print >> sys.stderr, \ - "Usage: PythonLR []" - exit(-1) - sc = SparkContext(sys.argv[1], "PythonLR", pyFiles=[realpath(__file__)]) - slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 - points = sc.parallelize(generateData(), slices).cache() - - # Initialize w to a random value - w = 2 * np.random.ranf(size=D) - 1 - print "Initial w: " + str(w) - - def add(x, y): - x += y - return x - - for i in range(1, ITERATIONS + 1): - print "On iteration %i" % i - - gradient = points.map(lambda p: - (1.0 / (1.0 + exp(-p.y * np.dot(w, p.x)))) * p.y * p.x - ).reduce(add) - w -= gradient - - print "Final w: " + str(w) diff --git a/pyspark/examples/pi.py b/pyspark/examples/pi.py index 348bbc5dce..127cba029b 100644 --- a/pyspark/examples/pi.py +++ b/pyspark/examples/pi.py @@ -1,13 +1,14 @@ import sys from random import random from operator import add -from pyspark.context import SparkContext + +from pyspark import SparkContext if __name__ == "__main__": if len(sys.argv) == 1: print >> sys.stderr, \ - "Usage: PythonPi []" + "Usage: PythonPi []" exit(-1) sc = SparkContext(sys.argv[1], "PythonPi") slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 diff --git a/pyspark/examples/tc.py b/pyspark/examples/tc.py deleted file mode 100644 index 9630e72b47..0000000000 --- a/pyspark/examples/tc.py +++ /dev/null @@ -1,49 +0,0 @@ -import sys -from random import Random -from pyspark.context import SparkContext - -numEdges = 200 -numVertices = 100 -rand = Random(42) - - -def generateGraph(): - edges = set() - while len(edges) < numEdges: - src = rand.randrange(0, numEdges) - dst = rand.randrange(0, numEdges) - if src != dst: - edges.add((src, dst)) - return edges - - -if __name__ == "__main__": - if len(sys.argv) == 1: - print >> sys.stderr, \ - "Usage: PythonTC []" - exit(-1) - sc = SparkContext(sys.argv[1], "PythonTC") - slices = sys.argv[2] if len(sys.argv) > 2 else 2 - tc = sc.parallelize(generateGraph(), slices).cache() - - # Linear transitive closure: each round grows paths by one edge, - # by joining the graph's edges with the already-discovered paths. - # e.g. join the path (y, z) from the TC with the edge (x, y) from - # the graph to obtain the path (x, z). - - # Because join() joins on keys, the edges are stored in reversed order. - edges = tc.map(lambda (x, y): (y, x)) - - oldCount = 0L - nextCount = tc.count() - while True: - oldCount = nextCount - # Perform the join, obtaining an RDD of (y, (z, x)) pairs, - # then project the result to obtain the new (x, z) paths. - new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a)) - tc = tc.union(new_edges).distinct().cache() - nextCount = tc.count() - if nextCount == oldCount: - break - - print "TC has %i edges" % tc.count() diff --git a/pyspark/examples/transitive_closure.py b/pyspark/examples/transitive_closure.py new file mode 100644 index 0000000000..73f7f8fbaf --- /dev/null +++ b/pyspark/examples/transitive_closure.py @@ -0,0 +1,50 @@ +import sys +from random import Random + +from pyspark import SparkContext + +numEdges = 200 +numVertices = 100 +rand = Random(42) + + +def generateGraph(): + edges = set() + while len(edges) < numEdges: + src = rand.randrange(0, numEdges) + dst = rand.randrange(0, numEdges) + if src != dst: + edges.add((src, dst)) + return edges + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonTC []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonTC") + slices = sys.argv[2] if len(sys.argv) > 2 else 2 + tc = sc.parallelize(generateGraph(), slices).cache() + + # Linear transitive closure: each round grows paths by one edge, + # by joining the graph's edges with the already-discovered paths. + # e.g. join the path (y, z) from the TC with the edge (x, y) from + # the graph to obtain the path (x, z). + + # Because join() joins on keys, the edges are stored in reversed order. + edges = tc.map(lambda (x, y): (y, x)) + + oldCount = 0L + nextCount = tc.count() + while True: + oldCount = nextCount + # Perform the join, obtaining an RDD of (y, (z, x)) pairs, + # then project the result to obtain the new (x, z) paths. + new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a)) + tc = tc.union(new_edges).distinct().cache() + nextCount = tc.count() + if nextCount == oldCount: + break + + print "TC has %i edges" % tc.count() diff --git a/pyspark/examples/wordcount.py b/pyspark/examples/wordcount.py index 8365c070e8..857160624b 100644 --- a/pyspark/examples/wordcount.py +++ b/pyspark/examples/wordcount.py @@ -1,6 +1,8 @@ import sys from operator import add -from pyspark.context import SparkContext + +from pyspark import SparkContext + if __name__ == "__main__": if len(sys.argv) < 3: diff --git a/pyspark/pyspark/__init__.py b/pyspark/pyspark/__init__.py index 8f8402b62b..1ab360a666 100644 --- a/pyspark/pyspark/__init__.py +++ b/pyspark/pyspark/__init__.py @@ -1,9 +1,20 @@ +""" +PySpark is a Python API for Spark. + +Public classes: + + - L{SparkContext} + Main entry point for Spark functionality. + - L{RDD} + A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. +""" import sys import os sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "pyspark/lib/py4j0.7.egg")) from pyspark.context import SparkContext +from pyspark.rdd import RDD -__all__ = ["SparkContext"] +__all__ = ["SparkContext", "RDD"] -- cgit v1.2.3 From b58340dbd9a741331fc4c3829b08c093560056c2 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 1 Jan 2013 14:48:45 -0800 Subject: Rename top-level 'pyspark' directory to 'python' --- .../main/scala/spark/api/python/PythonRDD.scala | 2 +- docs/_plugins/copy_api_dirs.rb | 8 +- pyspark-shell | 3 + pyspark/.gitignore | 2 - pyspark/epydoc.conf | 19 - pyspark/examples/kmeans.py | 52 -- pyspark/examples/logistic_regression.py | 57 -- pyspark/examples/pi.py | 21 - pyspark/examples/transitive_closure.py | 50 -- pyspark/examples/wordcount.py | 19 - pyspark/lib/PY4J_LICENSE.txt | 27 - pyspark/lib/PY4J_VERSION.txt | 1 - pyspark/lib/py4j0.7.egg | Bin 191756 -> 0 bytes pyspark/lib/py4j0.7.jar | Bin 103286 -> 0 bytes pyspark/pyspark-shell | 3 - pyspark/pyspark/__init__.py | 20 - pyspark/pyspark/broadcast.py | 48 - pyspark/pyspark/cloudpickle.py | 974 --------------------- pyspark/pyspark/context.py | 158 ---- pyspark/pyspark/java_gateway.py | 38 - pyspark/pyspark/join.py | 92 -- pyspark/pyspark/rdd.py | 713 --------------- pyspark/pyspark/serializers.py | 78 -- pyspark/pyspark/shell.py | 33 - pyspark/pyspark/worker.py | 40 - pyspark/run-pyspark | 28 - python/.gitignore | 2 + python/epydoc.conf | 19 + python/examples/kmeans.py | 52 ++ python/examples/logistic_regression.py | 57 ++ python/examples/pi.py | 21 + python/examples/transitive_closure.py | 50 ++ python/examples/wordcount.py | 19 + python/lib/PY4J_LICENSE.txt | 27 + python/lib/PY4J_VERSION.txt | 1 + python/lib/py4j0.7.egg | Bin 0 -> 191756 bytes python/lib/py4j0.7.jar | Bin 0 -> 103286 bytes python/pyspark/__init__.py | 20 + python/pyspark/broadcast.py | 48 + python/pyspark/cloudpickle.py | 974 +++++++++++++++++++++ python/pyspark/context.py | 158 ++++ python/pyspark/java_gateway.py | 38 + python/pyspark/join.py | 92 ++ python/pyspark/rdd.py | 713 +++++++++++++++ python/pyspark/serializers.py | 78 ++ python/pyspark/shell.py | 33 + python/pyspark/worker.py | 40 + run | 2 +- run-pyspark | 28 + run2.cmd | 2 +- 50 files changed, 2480 insertions(+), 2480 deletions(-) create mode 100755 pyspark-shell delete mode 100644 pyspark/.gitignore delete mode 100644 pyspark/epydoc.conf delete mode 100644 pyspark/examples/kmeans.py delete mode 100755 pyspark/examples/logistic_regression.py delete mode 100644 pyspark/examples/pi.py delete mode 100644 pyspark/examples/transitive_closure.py delete mode 100644 pyspark/examples/wordcount.py delete mode 100644 pyspark/lib/PY4J_LICENSE.txt delete mode 100644 pyspark/lib/PY4J_VERSION.txt delete mode 100644 pyspark/lib/py4j0.7.egg delete mode 100644 pyspark/lib/py4j0.7.jar delete mode 100755 pyspark/pyspark-shell delete mode 100644 pyspark/pyspark/__init__.py delete mode 100644 pyspark/pyspark/broadcast.py delete mode 100644 pyspark/pyspark/cloudpickle.py delete mode 100644 pyspark/pyspark/context.py delete mode 100644 pyspark/pyspark/java_gateway.py delete mode 100644 pyspark/pyspark/join.py delete mode 100644 pyspark/pyspark/rdd.py delete mode 100644 pyspark/pyspark/serializers.py delete mode 100644 pyspark/pyspark/shell.py delete mode 100644 pyspark/pyspark/worker.py delete mode 100755 pyspark/run-pyspark create mode 100644 python/.gitignore create mode 100644 python/epydoc.conf create mode 100644 python/examples/kmeans.py create mode 100755 python/examples/logistic_regression.py create mode 100644 python/examples/pi.py create mode 100644 python/examples/transitive_closure.py create mode 100644 python/examples/wordcount.py create mode 100644 python/lib/PY4J_LICENSE.txt create mode 100644 python/lib/PY4J_VERSION.txt create mode 100644 python/lib/py4j0.7.egg create mode 100644 python/lib/py4j0.7.jar create mode 100644 python/pyspark/__init__.py create mode 100644 python/pyspark/broadcast.py create mode 100644 python/pyspark/cloudpickle.py create mode 100644 python/pyspark/context.py create mode 100644 python/pyspark/java_gateway.py create mode 100644 python/pyspark/join.py create mode 100644 python/pyspark/rdd.py create mode 100644 python/pyspark/serializers.py create mode 100644 python/pyspark/shell.py create mode 100644 python/pyspark/worker.py create mode 100755 run-pyspark diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 19a039e330..cf60d14f03 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -38,7 +38,7 @@ private[spark] class PythonRDD[T: ClassManifest]( override def compute(split: Split, context: TaskContext): Iterator[Array[Byte]] = { val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME") - val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py")) + val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py")) // Add the environmental variables to the process. val currentEnvVars = pb.environment() diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 577f3ebe70..c9ce589c1b 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -30,8 +30,8 @@ if ENV['SKIP_SCALADOC'] != '1' end if ENV['SKIP_EPYDOC'] != '1' - puts "Moving to pyspark directory and building epydoc." - cd("../pyspark") + puts "Moving to python directory and building epydoc." + cd("../python") puts `epydoc --config epydoc.conf` puts "Moving back into docs dir." @@ -40,8 +40,8 @@ if ENV['SKIP_EPYDOC'] != '1' puts "echo making directory pyspark" mkdir_p "pyspark" - puts "cp -r ../pyspark/docs/. api/pyspark" - cp_r("../pyspark/docs/.", "api/pyspark") + puts "cp -r ../python/docs/. api/pyspark" + cp_r("../python/docs/.", "api/pyspark") cd("..") end diff --git a/pyspark-shell b/pyspark-shell new file mode 100755 index 0000000000..27aaac3a26 --- /dev/null +++ b/pyspark-shell @@ -0,0 +1,3 @@ +#!/usr/bin/env bash +FWDIR="`dirname $0`" +exec $FWDIR/run-pyspark $FWDIR/python/pyspark/shell.py "$@" diff --git a/pyspark/.gitignore b/pyspark/.gitignore deleted file mode 100644 index 5c56e638f9..0000000000 --- a/pyspark/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -*.pyc -docs/ diff --git a/pyspark/epydoc.conf b/pyspark/epydoc.conf deleted file mode 100644 index 91ac984ba2..0000000000 --- a/pyspark/epydoc.conf +++ /dev/null @@ -1,19 +0,0 @@ -[epydoc] # Epydoc section marker (required by ConfigParser) - -# Information about the project. -name: PySpark -url: http://spark-project.org - -# The list of modules to document. Modules can be named using -# dotted names, module filenames, or package directory names. -# This option may be repeated. -modules: pyspark - -# Write html output to the directory "apidocs" -output: html -target: docs/ - -private: no - -exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers - pyspark.java_gateway pyspark.examples pyspark.shell diff --git a/pyspark/examples/kmeans.py b/pyspark/examples/kmeans.py deleted file mode 100644 index ad2be21178..0000000000 --- a/pyspark/examples/kmeans.py +++ /dev/null @@ -1,52 +0,0 @@ -""" -This example requires numpy (http://www.numpy.org/) -""" -import sys - -import numpy as np -from pyspark import SparkContext - - -def parseVector(line): - return np.array([float(x) for x in line.split(' ')]) - - -def closestPoint(p, centers): - bestIndex = 0 - closest = float("+inf") - for i in range(len(centers)): - tempDist = np.sum((p - centers[i]) ** 2) - if tempDist < closest: - closest = tempDist - bestIndex = i - return bestIndex - - -if __name__ == "__main__": - if len(sys.argv) < 5: - print >> sys.stderr, \ - "Usage: PythonKMeans " - exit(-1) - sc = SparkContext(sys.argv[1], "PythonKMeans") - lines = sc.textFile(sys.argv[2]) - data = lines.map(parseVector).cache() - K = int(sys.argv[3]) - convergeDist = float(sys.argv[4]) - - kPoints = data.takeSample(False, K, 34) - tempDist = 1.0 - - while tempDist > convergeDist: - closest = data.map( - lambda p : (closestPoint(p, kPoints), (p, 1))) - pointStats = closest.reduceByKey( - lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2)) - newPoints = pointStats.map( - lambda (x, (y, z)): (x, y / z)).collect() - - tempDist = sum(np.sum((kPoints[x] - y) ** 2) for (x, y) in newPoints) - - for (x, y) in newPoints: - kPoints[x] = y - - print "Final centers: " + str(kPoints) diff --git a/pyspark/examples/logistic_regression.py b/pyspark/examples/logistic_regression.py deleted file mode 100755 index f13698a86f..0000000000 --- a/pyspark/examples/logistic_regression.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -This example requires numpy (http://www.numpy.org/) -""" -from collections import namedtuple -from math import exp -from os.path import realpath -import sys - -import numpy as np -from pyspark import SparkContext - - -N = 100000 # Number of data points -D = 10 # Number of dimensions -R = 0.7 # Scaling factor -ITERATIONS = 5 -np.random.seed(42) - - -DataPoint = namedtuple("DataPoint", ['x', 'y']) -from lr import DataPoint # So that DataPoint is properly serialized - - -def generateData(): - def generatePoint(i): - y = -1 if i % 2 == 0 else 1 - x = np.random.normal(size=D) + (y * R) - return DataPoint(x, y) - return [generatePoint(i) for i in range(N)] - - -if __name__ == "__main__": - if len(sys.argv) == 1: - print >> sys.stderr, \ - "Usage: PythonLR []" - exit(-1) - sc = SparkContext(sys.argv[1], "PythonLR", pyFiles=[realpath(__file__)]) - slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 - points = sc.parallelize(generateData(), slices).cache() - - # Initialize w to a random value - w = 2 * np.random.ranf(size=D) - 1 - print "Initial w: " + str(w) - - def add(x, y): - x += y - return x - - for i in range(1, ITERATIONS + 1): - print "On iteration %i" % i - - gradient = points.map(lambda p: - (1.0 / (1.0 + exp(-p.y * np.dot(w, p.x)))) * p.y * p.x - ).reduce(add) - w -= gradient - - print "Final w: " + str(w) diff --git a/pyspark/examples/pi.py b/pyspark/examples/pi.py deleted file mode 100644 index 127cba029b..0000000000 --- a/pyspark/examples/pi.py +++ /dev/null @@ -1,21 +0,0 @@ -import sys -from random import random -from operator import add - -from pyspark import SparkContext - - -if __name__ == "__main__": - if len(sys.argv) == 1: - print >> sys.stderr, \ - "Usage: PythonPi []" - exit(-1) - sc = SparkContext(sys.argv[1], "PythonPi") - slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 - n = 100000 * slices - def f(_): - x = random() * 2 - 1 - y = random() * 2 - 1 - return 1 if x ** 2 + y ** 2 < 1 else 0 - count = sc.parallelize(xrange(1, n+1), slices).map(f).reduce(add) - print "Pi is roughly %f" % (4.0 * count / n) diff --git a/pyspark/examples/transitive_closure.py b/pyspark/examples/transitive_closure.py deleted file mode 100644 index 73f7f8fbaf..0000000000 --- a/pyspark/examples/transitive_closure.py +++ /dev/null @@ -1,50 +0,0 @@ -import sys -from random import Random - -from pyspark import SparkContext - -numEdges = 200 -numVertices = 100 -rand = Random(42) - - -def generateGraph(): - edges = set() - while len(edges) < numEdges: - src = rand.randrange(0, numEdges) - dst = rand.randrange(0, numEdges) - if src != dst: - edges.add((src, dst)) - return edges - - -if __name__ == "__main__": - if len(sys.argv) == 1: - print >> sys.stderr, \ - "Usage: PythonTC []" - exit(-1) - sc = SparkContext(sys.argv[1], "PythonTC") - slices = sys.argv[2] if len(sys.argv) > 2 else 2 - tc = sc.parallelize(generateGraph(), slices).cache() - - # Linear transitive closure: each round grows paths by one edge, - # by joining the graph's edges with the already-discovered paths. - # e.g. join the path (y, z) from the TC with the edge (x, y) from - # the graph to obtain the path (x, z). - - # Because join() joins on keys, the edges are stored in reversed order. - edges = tc.map(lambda (x, y): (y, x)) - - oldCount = 0L - nextCount = tc.count() - while True: - oldCount = nextCount - # Perform the join, obtaining an RDD of (y, (z, x)) pairs, - # then project the result to obtain the new (x, z) paths. - new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a)) - tc = tc.union(new_edges).distinct().cache() - nextCount = tc.count() - if nextCount == oldCount: - break - - print "TC has %i edges" % tc.count() diff --git a/pyspark/examples/wordcount.py b/pyspark/examples/wordcount.py deleted file mode 100644 index 857160624b..0000000000 --- a/pyspark/examples/wordcount.py +++ /dev/null @@ -1,19 +0,0 @@ -import sys -from operator import add - -from pyspark import SparkContext - - -if __name__ == "__main__": - if len(sys.argv) < 3: - print >> sys.stderr, \ - "Usage: PythonWordCount " - exit(-1) - sc = SparkContext(sys.argv[1], "PythonWordCount") - lines = sc.textFile(sys.argv[2], 1) - counts = lines.flatMap(lambda x: x.split(' ')) \ - .map(lambda x: (x, 1)) \ - .reduceByKey(add) - output = counts.collect() - for (word, count) in output: - print "%s : %i" % (word, count) diff --git a/pyspark/lib/PY4J_LICENSE.txt b/pyspark/lib/PY4J_LICENSE.txt deleted file mode 100644 index a70279ca14..0000000000 --- a/pyspark/lib/PY4J_LICENSE.txt +++ /dev/null @@ -1,27 +0,0 @@ - -Copyright (c) 2009-2011, Barthelemy Dagenais All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -- Redistributions of source code must retain the above copyright notice, this -list of conditions and the following disclaimer. - -- Redistributions in binary form must reproduce the above copyright notice, -this list of conditions and the following disclaimer in the documentation -and/or other materials provided with the distribution. - -- The name of the author may not be used to endorse or promote products -derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. diff --git a/pyspark/lib/PY4J_VERSION.txt b/pyspark/lib/PY4J_VERSION.txt deleted file mode 100644 index 04a0cd52a8..0000000000 --- a/pyspark/lib/PY4J_VERSION.txt +++ /dev/null @@ -1 +0,0 @@ -b7924aabe9c5e63f0a4d8bbd17019534c7ec014e diff --git a/pyspark/lib/py4j0.7.egg b/pyspark/lib/py4j0.7.egg deleted file mode 100644 index f8a339d8ee..0000000000 Binary files a/pyspark/lib/py4j0.7.egg and /dev/null differ diff --git a/pyspark/lib/py4j0.7.jar b/pyspark/lib/py4j0.7.jar deleted file mode 100644 index 73b7ddb7d1..0000000000 Binary files a/pyspark/lib/py4j0.7.jar and /dev/null differ diff --git a/pyspark/pyspark-shell b/pyspark/pyspark-shell deleted file mode 100755 index e3736826e8..0000000000 --- a/pyspark/pyspark-shell +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/env bash -FWDIR="`dirname $0`" -exec $FWDIR/run-pyspark $FWDIR/pyspark/shell.py "$@" diff --git a/pyspark/pyspark/__init__.py b/pyspark/pyspark/__init__.py deleted file mode 100644 index 1ab360a666..0000000000 --- a/pyspark/pyspark/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -""" -PySpark is a Python API for Spark. - -Public classes: - - - L{SparkContext} - Main entry point for Spark functionality. - - L{RDD} - A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. -""" -import sys -import os -sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "pyspark/lib/py4j0.7.egg")) - - -from pyspark.context import SparkContext -from pyspark.rdd import RDD - - -__all__ = ["SparkContext", "RDD"] diff --git a/pyspark/pyspark/broadcast.py b/pyspark/pyspark/broadcast.py deleted file mode 100644 index 93876fa738..0000000000 --- a/pyspark/pyspark/broadcast.py +++ /dev/null @@ -1,48 +0,0 @@ -""" ->>> from pyspark.context import SparkContext ->>> sc = SparkContext('local', 'test') ->>> b = sc.broadcast([1, 2, 3, 4, 5]) ->>> b.value -[1, 2, 3, 4, 5] - ->>> from pyspark.broadcast import _broadcastRegistry ->>> _broadcastRegistry[b.bid] = b ->>> from cPickle import dumps, loads ->>> loads(dumps(b)).value -[1, 2, 3, 4, 5] - ->>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() -[1, 2, 3, 4, 5, 1, 2, 3, 4, 5] - ->>> large_broadcast = sc.broadcast(list(range(10000))) -""" -# Holds broadcasted data received from Java, keyed by its id. -_broadcastRegistry = {} - - -def _from_id(bid): - from pyspark.broadcast import _broadcastRegistry - if bid not in _broadcastRegistry: - raise Exception("Broadcast variable '%s' not loaded!" % bid) - return _broadcastRegistry[bid] - - -class Broadcast(object): - def __init__(self, bid, value, java_broadcast=None, pickle_registry=None): - self.value = value - self.bid = bid - self._jbroadcast = java_broadcast - self._pickle_registry = pickle_registry - - def __reduce__(self): - self._pickle_registry.add(self) - return (_from_id, (self.bid, )) - - -def _test(): - import doctest - doctest.testmod() - - -if __name__ == "__main__": - _test() diff --git a/pyspark/pyspark/cloudpickle.py b/pyspark/pyspark/cloudpickle.py deleted file mode 100644 index 6a7c23a069..0000000000 --- a/pyspark/pyspark/cloudpickle.py +++ /dev/null @@ -1,974 +0,0 @@ -""" -This class is defined to override standard pickle functionality - -The goals of it follow: --Serialize lambdas and nested functions to compiled byte code --Deal with main module correctly --Deal with other non-serializable objects - -It does not include an unpickler, as standard python unpickling suffices. - -This module was extracted from the `cloud` package, developed by `PiCloud, Inc. -`_. - -Copyright (c) 2012, Regents of the University of California. -Copyright (c) 2009 `PiCloud, Inc. `_. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions -are met: - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - * Neither the name of the University of California, Berkeley nor the - names of its contributors may be used to endorse or promote - products derived from this software without specific prior written - permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED -TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -""" - - -import operator -import os -import pickle -import struct -import sys -import types -from functools import partial -import itertools -from copy_reg import _extension_registry, _inverted_registry, _extension_cache -import new -import dis -import traceback - -#relevant opcodes -STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL')) -DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL')) -LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL')) -GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL] - -HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT) -EXTENDED_ARG = chr(dis.EXTENDED_ARG) - -import logging -cloudLog = logging.getLogger("Cloud.Transport") - -try: - import ctypes -except (MemoryError, ImportError): - logging.warning('Exception raised on importing ctypes. Likely python bug.. some functionality will be disabled', exc_info = True) - ctypes = None - PyObject_HEAD = None -else: - - # for reading internal structures - PyObject_HEAD = [ - ('ob_refcnt', ctypes.c_size_t), - ('ob_type', ctypes.c_void_p), - ] - - -try: - from cStringIO import StringIO -except ImportError: - from StringIO import StringIO - -# These helper functions were copied from PiCloud's util module. -def islambda(func): - return getattr(func,'func_name') == '' - -def xrange_params(xrangeobj): - """Returns a 3 element tuple describing the xrange start, step, and len - respectively - - Note: Only guarentees that elements of xrange are the same. parameters may - be different. - e.g. xrange(1,1) is interpretted as xrange(0,0); both behave the same - though w/ iteration - """ - - xrange_len = len(xrangeobj) - if not xrange_len: #empty - return (0,1,0) - start = xrangeobj[0] - if xrange_len == 1: #one element - return start, 1, 1 - return (start, xrangeobj[1] - xrangeobj[0], xrange_len) - -#debug variables intended for developer use: -printSerialization = False -printMemoization = False - -useForcedImports = True #Should I use forced imports for tracking? - - - -class CloudPickler(pickle.Pickler): - - dispatch = pickle.Pickler.dispatch.copy() - savedForceImports = False - savedDjangoEnv = False #hack tro transport django environment - - def __init__(self, file, protocol=None, min_size_to_save= 0): - pickle.Pickler.__init__(self,file,protocol) - self.modules = set() #set of modules needed to depickle - self.globals_ref = {} # map ids to dictionary. used to ensure that functions can share global env - - def dump(self, obj): - # note: not thread safe - # minimal side-effects, so not fixing - recurse_limit = 3000 - base_recurse = sys.getrecursionlimit() - if base_recurse < recurse_limit: - sys.setrecursionlimit(recurse_limit) - self.inject_addons() - try: - return pickle.Pickler.dump(self, obj) - except RuntimeError, e: - if 'recursion' in e.args[0]: - msg = """Could not pickle object as excessively deep recursion required. - Try _fast_serialization=2 or contact PiCloud support""" - raise pickle.PicklingError(msg) - finally: - new_recurse = sys.getrecursionlimit() - if new_recurse == recurse_limit: - sys.setrecursionlimit(base_recurse) - - def save_buffer(self, obj): - """Fallback to save_string""" - pickle.Pickler.save_string(self,str(obj)) - dispatch[buffer] = save_buffer - - #block broken objects - def save_unsupported(self, obj, pack=None): - raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj)) - dispatch[types.GeneratorType] = save_unsupported - - #python2.6+ supports slice pickling. some py2.5 extensions might as well. We just test it - try: - slice(0,1).__reduce__() - except TypeError: #can't pickle - - dispatch[slice] = save_unsupported - - #itertools objects do not pickle! - for v in itertools.__dict__.values(): - if type(v) is type: - dispatch[v] = save_unsupported - - - def save_dict(self, obj): - """hack fix - If the dict is a global, deal with it in a special way - """ - #print 'saving', obj - if obj is __builtins__: - self.save_reduce(_get_module_builtins, (), obj=obj) - else: - pickle.Pickler.save_dict(self, obj) - dispatch[pickle.DictionaryType] = save_dict - - - def save_module(self, obj, pack=struct.pack): - """ - Save a module as an import - """ - #print 'try save import', obj.__name__ - self.modules.add(obj) - self.save_reduce(subimport,(obj.__name__,), obj=obj) - dispatch[types.ModuleType] = save_module #new type - - def save_codeobject(self, obj, pack=struct.pack): - """ - Save a code object - """ - #print 'try to save codeobj: ', obj - args = ( - obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code, - obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name, - obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars - ) - self.save_reduce(types.CodeType, args, obj=obj) - dispatch[types.CodeType] = save_codeobject #new type - - def save_function(self, obj, name=None, pack=struct.pack): - """ Registered with the dispatch to handle all function types. - - Determines what kind of function obj is (e.g. lambda, defined at - interactive prompt, etc) and handles the pickling appropriately. - """ - write = self.write - - name = obj.__name__ - modname = pickle.whichmodule(obj, name) - #print 'which gives %s %s %s' % (modname, obj, name) - try: - themodule = sys.modules[modname] - except KeyError: # eval'd items such as namedtuple give invalid items for their function __module__ - modname = '__main__' - - if modname == '__main__': - themodule = None - - if themodule: - self.modules.add(themodule) - - if not self.savedDjangoEnv: - #hack for django - if we detect the settings module, we transport it - django_settings = os.environ.get('DJANGO_SETTINGS_MODULE', '') - if django_settings: - django_mod = sys.modules.get(django_settings) - if django_mod: - cloudLog.debug('Transporting django settings %s during save of %s', django_mod, name) - self.savedDjangoEnv = True - self.modules.add(django_mod) - write(pickle.MARK) - self.save_reduce(django_settings_load, (django_mod.__name__,), obj=django_mod) - write(pickle.POP_MARK) - - - # if func is lambda, def'ed at prompt, is in main, or is nested, then - # we'll pickle the actual function object rather than simply saving a - # reference (as is done in default pickler), via save_function_tuple. - if islambda(obj) or obj.func_code.co_filename == '' or themodule == None: - #Force server to import modules that have been imported in main - modList = None - if themodule == None and not self.savedForceImports: - mainmod = sys.modules['__main__'] - if useForcedImports and hasattr(mainmod,'___pyc_forcedImports__'): - modList = list(mainmod.___pyc_forcedImports__) - self.savedForceImports = True - self.save_function_tuple(obj, modList) - return - else: # func is nested - klass = getattr(themodule, name, None) - if klass is None or klass is not obj: - self.save_function_tuple(obj, [themodule]) - return - - if obj.__dict__: - # essentially save_reduce, but workaround needed to avoid recursion - self.save(_restore_attr) - write(pickle.MARK + pickle.GLOBAL + modname + '\n' + name + '\n') - self.memoize(obj) - self.save(obj.__dict__) - write(pickle.TUPLE + pickle.REDUCE) - else: - write(pickle.GLOBAL + modname + '\n' + name + '\n') - self.memoize(obj) - dispatch[types.FunctionType] = save_function - - def save_function_tuple(self, func, forced_imports): - """ Pickles an actual func object. - - A func comprises: code, globals, defaults, closure, and dict. We - extract and save these, injecting reducing functions at certain points - to recreate the func object. Keep in mind that some of these pieces - can contain a ref to the func itself. Thus, a naive save on these - pieces could trigger an infinite loop of save's. To get around that, - we first create a skeleton func object using just the code (this is - safe, since this won't contain a ref to the func), and memoize it as - soon as it's created. The other stuff can then be filled in later. - """ - save = self.save - write = self.write - - # save the modules (if any) - if forced_imports: - write(pickle.MARK) - save(_modules_to_main) - #print 'forced imports are', forced_imports - - forced_names = map(lambda m: m.__name__, forced_imports) - save((forced_names,)) - - #save((forced_imports,)) - write(pickle.REDUCE) - write(pickle.POP_MARK) - - code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func) - - save(_fill_function) # skeleton function updater - write(pickle.MARK) # beginning of tuple that _fill_function expects - - # create a skeleton function object and memoize it - save(_make_skel_func) - save((code, len(closure), base_globals)) - write(pickle.REDUCE) - self.memoize(func) - - # save the rest of the func data needed by _fill_function - save(f_globals) - save(defaults) - save(closure) - save(dct) - write(pickle.TUPLE) - write(pickle.REDUCE) # applies _fill_function on the tuple - - @staticmethod - def extract_code_globals(co): - """ - Find all globals names read or written to by codeblock co - """ - code = co.co_code - names = co.co_names - out_names = set() - - n = len(code) - i = 0 - extended_arg = 0 - while i < n: - op = code[i] - - i = i+1 - if op >= HAVE_ARGUMENT: - oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg - extended_arg = 0 - i = i+2 - if op == EXTENDED_ARG: - extended_arg = oparg*65536L - if op in GLOBAL_OPS: - out_names.add(names[oparg]) - #print 'extracted', out_names, ' from ', names - return out_names - - def extract_func_data(self, func): - """ - Turn the function into a tuple of data necessary to recreate it: - code, globals, defaults, closure, dict - """ - code = func.func_code - - # extract all global ref's - func_global_refs = CloudPickler.extract_code_globals(code) - if code.co_consts: # see if nested function have any global refs - for const in code.co_consts: - if type(const) is types.CodeType and const.co_names: - func_global_refs = func_global_refs.union( CloudPickler.extract_code_globals(const)) - # process all variables referenced by global environment - f_globals = {} - for var in func_global_refs: - #Some names, such as class functions are not global - we don't need them - if func.func_globals.has_key(var): - f_globals[var] = func.func_globals[var] - - # defaults requires no processing - defaults = func.func_defaults - - def get_contents(cell): - try: - return cell.cell_contents - except ValueError, e: #cell is empty error on not yet assigned - raise pickle.PicklingError('Function to be pickled has free variables that are referenced before assignment in enclosing scope') - - - # process closure - if func.func_closure: - closure = map(get_contents, func.func_closure) - else: - closure = [] - - # save the dict - dct = func.func_dict - - if printSerialization: - outvars = ['code: ' + str(code) ] - outvars.append('globals: ' + str(f_globals)) - outvars.append('defaults: ' + str(defaults)) - outvars.append('closure: ' + str(closure)) - print 'function ', func, 'is extracted to: ', ', '.join(outvars) - - base_globals = self.globals_ref.get(id(func.func_globals), {}) - self.globals_ref[id(func.func_globals)] = base_globals - - return (code, f_globals, defaults, closure, dct, base_globals) - - def save_global(self, obj, name=None, pack=struct.pack): - write = self.write - memo = self.memo - - if name is None: - name = obj.__name__ - - modname = getattr(obj, "__module__", None) - if modname is None: - modname = pickle.whichmodule(obj, name) - - try: - __import__(modname) - themodule = sys.modules[modname] - except (ImportError, KeyError, AttributeError): #should never occur - raise pickle.PicklingError( - "Can't pickle %r: Module %s cannot be found" % - (obj, modname)) - - if modname == '__main__': - themodule = None - - if themodule: - self.modules.add(themodule) - - sendRef = True - typ = type(obj) - #print 'saving', obj, typ - try: - try: #Deal with case when getattribute fails with exceptions - klass = getattr(themodule, name) - except (AttributeError): - if modname == '__builtin__': #new.* are misrepeported - modname = 'new' - __import__(modname) - themodule = sys.modules[modname] - try: - klass = getattr(themodule, name) - except AttributeError, a: - #print themodule, name, obj, type(obj) - raise pickle.PicklingError("Can't pickle builtin %s" % obj) - else: - raise - - except (ImportError, KeyError, AttributeError): - if typ == types.TypeType or typ == types.ClassType: - sendRef = False - else: #we can't deal with this - raise - else: - if klass is not obj and (typ == types.TypeType or typ == types.ClassType): - sendRef = False - if not sendRef: - #note: Third party types might crash this - add better checks! - d = dict(obj.__dict__) #copy dict proxy to a dict - if not isinstance(d.get('__dict__', None), property): # don't extract dict that are properties - d.pop('__dict__',None) - d.pop('__weakref__',None) - - # hack as __new__ is stored differently in the __dict__ - new_override = d.get('__new__', None) - if new_override: - d['__new__'] = obj.__new__ - - self.save_reduce(type(obj),(obj.__name__,obj.__bases__, - d),obj=obj) - #print 'internal reduce dask %s %s' % (obj, d) - return - - if self.proto >= 2: - code = _extension_registry.get((modname, name)) - if code: - assert code > 0 - if code <= 0xff: - write(pickle.EXT1 + chr(code)) - elif code <= 0xffff: - write("%c%c%c" % (pickle.EXT2, code&0xff, code>>8)) - else: - write(pickle.EXT4 + pack("= 2 and getattr(func, "__name__", "") == "__newobj__": - #Added fix to allow transient - cls = args[0] - if not hasattr(cls, "__new__"): - raise pickle.PicklingError( - "args[0] from __newobj__ args has no __new__") - if obj is not None and cls is not obj.__class__: - raise pickle.PicklingError( - "args[0] from __newobj__ args has the wrong class") - args = args[1:] - save(cls) - - #Don't pickle transient entries - if hasattr(obj, '__transient__'): - transient = obj.__transient__ - state = state.copy() - - for k in list(state.keys()): - if k in transient: - del state[k] - - save(args) - write(pickle.NEWOBJ) - else: - save(func) - save(args) - write(pickle.REDUCE) - - if obj is not None: - self.memoize(obj) - - # More new special cases (that work with older protocols as - # well): when __reduce__ returns a tuple with 4 or 5 items, - # the 4th and 5th item should be iterators that provide list - # items and dict items (as (key, value) tuples), or None. - - if listitems is not None: - self._batch_appends(listitems) - - if dictitems is not None: - self._batch_setitems(dictitems) - - if state is not None: - #print 'obj %s has state %s' % (obj, state) - save(state) - write(pickle.BUILD) - - - def save_xrange(self, obj): - """Save an xrange object in python 2.5 - Python 2.6 supports this natively - """ - range_params = xrange_params(obj) - self.save_reduce(_build_xrange,range_params) - - #python2.6+ supports xrange pickling. some py2.5 extensions might as well. We just test it - try: - xrange(0).__reduce__() - except TypeError: #can't pickle -- use PiCloud pickler - dispatch[xrange] = save_xrange - - def save_partial(self, obj): - """Partial objects do not serialize correctly in python2.x -- this fixes the bugs""" - self.save_reduce(_genpartial, (obj.func, obj.args, obj.keywords)) - - if sys.version_info < (2,7): #2.7 supports partial pickling - dispatch[partial] = save_partial - - - def save_file(self, obj): - """Save a file""" - import StringIO as pystringIO #we can't use cStringIO as it lacks the name attribute - from ..transport.adapter import SerializingAdapter - - if not hasattr(obj, 'name') or not hasattr(obj, 'mode'): - raise pickle.PicklingError("Cannot pickle files that do not map to an actual file") - if obj.name == '': - return self.save_reduce(getattr, (sys,'stdout'), obj=obj) - if obj.name == '': - return self.save_reduce(getattr, (sys,'stderr'), obj=obj) - if obj.name == '': - raise pickle.PicklingError("Cannot pickle standard input") - if hasattr(obj, 'isatty') and obj.isatty(): - raise pickle.PicklingError("Cannot pickle files that map to tty objects") - if 'r' not in obj.mode: - raise pickle.PicklingError("Cannot pickle files that are not opened for reading") - name = obj.name - try: - fsize = os.stat(name).st_size - except OSError: - raise pickle.PicklingError("Cannot pickle file %s as it cannot be stat" % name) - - if obj.closed: - #create an empty closed string io - retval = pystringIO.StringIO("") - retval.close() - elif not fsize: #empty file - retval = pystringIO.StringIO("") - try: - tmpfile = file(name) - tst = tmpfile.read(1) - except IOError: - raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) - tmpfile.close() - if tst != '': - raise pickle.PicklingError("Cannot pickle file %s as it does not appear to map to a physical, real file" % name) - elif fsize > SerializingAdapter.max_transmit_data: - raise pickle.PicklingError("Cannot pickle file %s as it exceeds cloudconf.py's max_transmit_data of %d" % - (name,SerializingAdapter.max_transmit_data)) - else: - try: - tmpfile = file(name) - contents = tmpfile.read(SerializingAdapter.max_transmit_data) - tmpfile.close() - except IOError: - raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) - retval = pystringIO.StringIO(contents) - curloc = obj.tell() - retval.seek(curloc) - - retval.name = name - self.save(retval) #save stringIO - self.memoize(obj) - - dispatch[file] = save_file - """Special functions for Add-on libraries""" - - def inject_numpy(self): - numpy = sys.modules.get('numpy') - if not numpy or not hasattr(numpy, 'ufunc'): - return - self.dispatch[numpy.ufunc] = self.__class__.save_ufunc - - numpy_tst_mods = ['numpy', 'scipy.special'] - def save_ufunc(self, obj): - """Hack function for saving numpy ufunc objects""" - name = obj.__name__ - for tst_mod_name in self.numpy_tst_mods: - tst_mod = sys.modules.get(tst_mod_name, None) - if tst_mod: - if name in tst_mod.__dict__: - self.save_reduce(_getobject, (tst_mod_name, name)) - return - raise pickle.PicklingError('cannot save %s. Cannot resolve what module it is defined in' % str(obj)) - - def inject_timeseries(self): - """Handle bugs with pickling scikits timeseries""" - tseries = sys.modules.get('scikits.timeseries.tseries') - if not tseries or not hasattr(tseries, 'Timeseries'): - return - self.dispatch[tseries.Timeseries] = self.__class__.save_timeseries - - def save_timeseries(self, obj): - import scikits.timeseries.tseries as ts - - func, reduce_args, state = obj.__reduce__() - if func != ts._tsreconstruct: - raise pickle.PicklingError('timeseries using unexpected reconstruction function %s' % str(func)) - state = (1, - obj.shape, - obj.dtype, - obj.flags.fnc, - obj._data.tostring(), - ts.getmaskarray(obj).tostring(), - obj._fill_value, - obj._dates.shape, - obj._dates.__array__().tostring(), - obj._dates.dtype, #added -- preserve type - obj.freq, - obj._optinfo, - ) - return self.save_reduce(_genTimeSeries, (reduce_args, state)) - - def inject_email(self): - """Block email LazyImporters from being saved""" - email = sys.modules.get('email') - if not email: - return - self.dispatch[email.LazyImporter] = self.__class__.save_unsupported - - def inject_addons(self): - """Plug in system. Register additional pickling functions if modules already loaded""" - self.inject_numpy() - self.inject_timeseries() - self.inject_email() - - """Python Imaging Library""" - def save_image(self, obj): - if not obj.im and obj.fp and 'r' in obj.fp.mode and obj.fp.name \ - and not obj.fp.closed and (not hasattr(obj, 'isatty') or not obj.isatty()): - #if image not loaded yet -- lazy load - self.save_reduce(_lazyloadImage,(obj.fp,), obj=obj) - else: - #image is loaded - just transmit it over - self.save_reduce(_generateImage, (obj.size, obj.mode, obj.tostring()), obj=obj) - - """ - def memoize(self, obj): - pickle.Pickler.memoize(self, obj) - if printMemoization: - print 'memoizing ' + str(obj) - """ - - - -# Shorthands for legacy support - -def dump(obj, file, protocol=2): - CloudPickler(file, protocol).dump(obj) - -def dumps(obj, protocol=2): - file = StringIO() - - cp = CloudPickler(file,protocol) - cp.dump(obj) - - #print 'cloud dumped', str(obj), str(cp.modules) - - return file.getvalue() - - -#hack for __import__ not working as desired -def subimport(name): - __import__(name) - return sys.modules[name] - -#hack to load django settings: -def django_settings_load(name): - modified_env = False - - if 'DJANGO_SETTINGS_MODULE' not in os.environ: - os.environ['DJANGO_SETTINGS_MODULE'] = name # must set name first due to circular deps - modified_env = True - try: - module = subimport(name) - except Exception, i: - print >> sys.stderr, 'Cloud not import django settings %s:' % (name) - print_exec(sys.stderr) - if modified_env: - del os.environ['DJANGO_SETTINGS_MODULE'] - else: - #add project directory to sys,path: - if hasattr(module,'__file__'): - dirname = os.path.split(module.__file__)[0] + '/' - sys.path.append(dirname) - -# restores function attributes -def _restore_attr(obj, attr): - for key, val in attr.items(): - setattr(obj, key, val) - return obj - -def _get_module_builtins(): - return pickle.__builtins__ - -def print_exec(stream): - ei = sys.exc_info() - traceback.print_exception(ei[0], ei[1], ei[2], None, stream) - -def _modules_to_main(modList): - """Force every module in modList to be placed into main""" - if not modList: - return - - main = sys.modules['__main__'] - for modname in modList: - if type(modname) is str: - try: - mod = __import__(modname) - except Exception, i: #catch all... - sys.stderr.write('warning: could not import %s\n. Your function may unexpectedly error due to this import failing; \ -A version mismatch is likely. Specific error was:\n' % modname) - print_exec(sys.stderr) - else: - setattr(main,mod.__name__, mod) - else: - #REVERSE COMPATIBILITY FOR CLOUD CLIENT 1.5 (WITH EPD) - #In old version actual module was sent - setattr(main,modname.__name__, modname) - -#object generators: -def _build_xrange(start, step, len): - """Built xrange explicitly""" - return xrange(start, start + step*len, step) - -def _genpartial(func, args, kwds): - if not args: - args = () - if not kwds: - kwds = {} - return partial(func, *args, **kwds) - - -def _fill_function(func, globals, defaults, closure, dict): - """ Fills in the rest of function data into the skeleton function object - that were created via _make_skel_func(). - """ - func.func_globals.update(globals) - func.func_defaults = defaults - func.func_dict = dict - - if len(closure) != len(func.func_closure): - raise pickle.UnpicklingError("closure lengths don't match up") - for i in range(len(closure)): - _change_cell_value(func.func_closure[i], closure[i]) - - return func - -def _make_skel_func(code, num_closures, base_globals = None): - """ Creates a skeleton function object that contains just the provided - code and the correct number of cells in func_closure. All other - func attributes (e.g. func_globals) are empty. - """ - #build closure (cells): - if not ctypes: - raise Exception('ctypes failed to import; cannot build function') - - cellnew = ctypes.pythonapi.PyCell_New - cellnew.restype = ctypes.py_object - cellnew.argtypes = (ctypes.py_object,) - dummy_closure = tuple(map(lambda i: cellnew(None), range(num_closures))) - - if base_globals is None: - base_globals = {} - base_globals['__builtins__'] = __builtins__ - - return types.FunctionType(code, base_globals, - None, None, dummy_closure) - -# this piece of opaque code is needed below to modify 'cell' contents -cell_changer_code = new.code( - 1, 1, 2, 0, - ''.join([ - chr(dis.opmap['LOAD_FAST']), '\x00\x00', - chr(dis.opmap['DUP_TOP']), - chr(dis.opmap['STORE_DEREF']), '\x00\x00', - chr(dis.opmap['RETURN_VALUE']) - ]), - (), (), ('newval',), '', 'cell_changer', 1, '', ('c',), () -) - -def _change_cell_value(cell, newval): - """ Changes the contents of 'cell' object to newval """ - return new.function(cell_changer_code, {}, None, (), (cell,))(newval) - -"""Constructors for 3rd party libraries -Note: These can never be renamed due to client compatibility issues""" - -def _getobject(modname, attribute): - mod = __import__(modname) - return mod.__dict__[attribute] - -def _generateImage(size, mode, str_rep): - """Generate image from string representation""" - import Image - i = Image.new(mode, size) - i.fromstring(str_rep) - return i - -def _lazyloadImage(fp): - import Image - fp.seek(0) #works in almost any case - return Image.open(fp) - -"""Timeseries""" -def _genTimeSeries(reduce_args, state): - import scikits.timeseries.tseries as ts - from numpy import ndarray - from numpy.ma import MaskedArray - - - time_series = ts._tsreconstruct(*reduce_args) - - #from setstate modified - (ver, shp, typ, isf, raw, msk, flv, dsh, dtm, dtyp, frq, infodict) = state - #print 'regenerating %s' % dtyp - - MaskedArray.__setstate__(time_series, (ver, shp, typ, isf, raw, msk, flv)) - _dates = time_series._dates - #_dates.__setstate__((ver, dsh, typ, isf, dtm, frq)) #use remote typ - ndarray.__setstate__(_dates,(dsh,dtyp, isf, dtm)) - _dates.freq = frq - _dates._cachedinfo.update(dict(full=None, hasdups=None, steps=None, - toobj=None, toord=None, tostr=None)) - # Update the _optinfo dictionary - time_series._optinfo.update(infodict) - return time_series - diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py deleted file mode 100644 index 6172d69dcf..0000000000 --- a/pyspark/pyspark/context.py +++ /dev/null @@ -1,158 +0,0 @@ -import os -import atexit -from tempfile import NamedTemporaryFile - -from pyspark.broadcast import Broadcast -from pyspark.java_gateway import launch_gateway -from pyspark.serializers import dump_pickle, write_with_length, batched -from pyspark.rdd import RDD - -from py4j.java_collections import ListConverter - - -class SparkContext(object): - """ - Main entry point for Spark functionality. A SparkContext represents the - connection to a Spark cluster, and can be used to create L{RDD}s and - broadcast variables on that cluster. - """ - - gateway = launch_gateway() - jvm = gateway.jvm - _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile - _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile - - def __init__(self, master, jobName, sparkHome=None, pyFiles=None, - environment=None, batchSize=1024): - """ - Create a new SparkContext. - - @param master: Cluster URL to connect to - (e.g. mesos://host:port, spark://host:port, local[4]). - @param jobName: A name for your job, to display on the cluster web UI - @param sparkHome: Location where Spark is installed on cluster nodes. - @param pyFiles: Collection of .zip or .py files to send to the cluster - and add to PYTHONPATH. These can be paths on the local file - system or HDFS, HTTP, HTTPS, or FTP URLs. - @param environment: A dictionary of environment variables to set on - worker nodes. - @param batchSize: The number of Python objects represented as a single - Java object. Set 1 to disable batching or -1 to use an - unlimited batch size. - """ - self.master = master - self.jobName = jobName - self.sparkHome = sparkHome or None # None becomes null in Py4J - self.environment = environment or {} - self.batchSize = batchSize # -1 represents a unlimited batch size - - # Create the Java SparkContext through Py4J - empty_string_array = self.gateway.new_array(self.jvm.String, 0) - self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome, - empty_string_array) - - self.pythonExec = os.environ.get("PYSPARK_PYTHON_EXEC", 'python') - # Broadcast's __reduce__ method stores Broadcast instances here. - # This allows other code to determine which Broadcast instances have - # been pickled, so it can determine which Java broadcast objects to - # send. - self._pickled_broadcast_vars = set() - - # Deploy any code dependencies specified in the constructor - for path in (pyFiles or []): - self.addPyFile(path) - - @property - def defaultParallelism(self): - """ - Default level of parallelism to use when not given by user (e.g. for - reduce tasks) - """ - return self._jsc.sc().defaultParallelism() - - def __del__(self): - if self._jsc: - self._jsc.stop() - - def stop(self): - """ - Shut down the SparkContext. - """ - self._jsc.stop() - self._jsc = None - - def parallelize(self, c, numSlices=None): - """ - Distribute a local Python collection to form an RDD. - """ - numSlices = numSlices or self.defaultParallelism - # Calling the Java parallelize() method with an ArrayList is too slow, - # because it sends O(n) Py4J commands. As an alternative, serialized - # objects are written to a file and loaded through textFile(). - tempFile = NamedTemporaryFile(delete=False) - atexit.register(lambda: os.unlink(tempFile.name)) - if self.batchSize != 1: - c = batched(c, self.batchSize) - for x in c: - write_with_length(dump_pickle(x), tempFile) - tempFile.close() - jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices) - return RDD(jrdd, self) - - def textFile(self, name, minSplits=None): - """ - Read a text file from HDFS, a local file system (available on all - nodes), or any Hadoop-supported file system URI, and return it as an - RDD of Strings. - """ - minSplits = minSplits or min(self.defaultParallelism, 2) - jrdd = self._jsc.textFile(name, minSplits) - return RDD(jrdd, self) - - def union(self, rdds): - """ - Build the union of a list of RDDs. - """ - first = rdds[0]._jrdd - rest = [x._jrdd for x in rdds[1:]] - rest = ListConverter().convert(rest, self.gateway._gateway_client) - return RDD(self._jsc.union(first, rest), self) - - def broadcast(self, value): - """ - Broadcast a read-only variable to the cluster, returning a C{Broadcast} - object for reading it in distributed functions. The variable will be - sent to each cluster only once. - """ - jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value))) - return Broadcast(jbroadcast.id(), value, jbroadcast, - self._pickled_broadcast_vars) - - def addFile(self, path): - """ - Add a file to be downloaded into the working directory of this Spark - job on every node. The C{path} passed can be either a local file, - a file in HDFS (or other Hadoop-supported filesystems), or an HTTP, - HTTPS or FTP URI. - """ - self._jsc.sc().addFile(path) - - def clearFiles(self): - """ - Clear the job's list of files added by L{addFile} or L{addPyFile} so - that they do not get downloaded to any new nodes. - """ - # TODO: remove added .py or .zip files from the PYTHONPATH? - self._jsc.sc().clearFiles() - - def addPyFile(self, path): - """ - Add a .py or .zip dependency for all tasks to be executed on this - SparkContext in the future. The C{path} passed can be either a local - file, a file in HDFS (or other Hadoop-supported filesystems), or an - HTTP, HTTPS or FTP URI. - """ - self.addFile(path) - filename = path.split("/")[-1] - os.environ["PYTHONPATH"] = \ - "%s:%s" % (filename, os.environ["PYTHONPATH"]) diff --git a/pyspark/pyspark/java_gateway.py b/pyspark/pyspark/java_gateway.py deleted file mode 100644 index 2329e536cc..0000000000 --- a/pyspark/pyspark/java_gateway.py +++ /dev/null @@ -1,38 +0,0 @@ -import os -import sys -from subprocess import Popen, PIPE -from threading import Thread -from py4j.java_gateway import java_import, JavaGateway, GatewayClient - - -SPARK_HOME = os.environ["SPARK_HOME"] - - -def launch_gateway(): - # Launch the Py4j gateway using Spark's run command so that we pick up the - # proper classpath and SPARK_MEM settings from spark-env.sh - command = [os.path.join(SPARK_HOME, "run"), "py4j.GatewayServer", - "--die-on-broken-pipe", "0"] - proc = Popen(command, stdout=PIPE, stdin=PIPE) - # Determine which ephemeral port the server started on: - port = int(proc.stdout.readline()) - # Create a thread to echo output from the GatewayServer, which is required - # for Java log output to show up: - class EchoOutputThread(Thread): - def __init__(self, stream): - Thread.__init__(self) - self.daemon = True - self.stream = stream - - def run(self): - while True: - line = self.stream.readline() - sys.stderr.write(line) - EchoOutputThread(proc.stdout).start() - # Connect to the gateway - gateway = JavaGateway(GatewayClient(port=port), auto_convert=False) - # Import the classes used by PySpark - java_import(gateway.jvm, "spark.api.java.*") - java_import(gateway.jvm, "spark.api.python.*") - java_import(gateway.jvm, "scala.Tuple2") - return gateway diff --git a/pyspark/pyspark/join.py b/pyspark/pyspark/join.py deleted file mode 100644 index 7036c47980..0000000000 --- a/pyspark/pyspark/join.py +++ /dev/null @@ -1,92 +0,0 @@ -""" -Copyright (c) 2011, Douban Inc. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - - * Neither the name of the Douban Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -""" - - -def _do_python_join(rdd, other, numSplits, dispatch): - vs = rdd.map(lambda (k, v): (k, (1, v))) - ws = other.map(lambda (k, v): (k, (2, v))) - return vs.union(ws).groupByKey(numSplits).flatMapValues(dispatch) - - -def python_join(rdd, other, numSplits): - def dispatch(seq): - vbuf, wbuf = [], [] - for (n, v) in seq: - if n == 1: - vbuf.append(v) - elif n == 2: - wbuf.append(v) - return [(v, w) for v in vbuf for w in wbuf] - return _do_python_join(rdd, other, numSplits, dispatch) - - -def python_right_outer_join(rdd, other, numSplits): - def dispatch(seq): - vbuf, wbuf = [], [] - for (n, v) in seq: - if n == 1: - vbuf.append(v) - elif n == 2: - wbuf.append(v) - if not vbuf: - vbuf.append(None) - return [(v, w) for v in vbuf for w in wbuf] - return _do_python_join(rdd, other, numSplits, dispatch) - - -def python_left_outer_join(rdd, other, numSplits): - def dispatch(seq): - vbuf, wbuf = [], [] - for (n, v) in seq: - if n == 1: - vbuf.append(v) - elif n == 2: - wbuf.append(v) - if not wbuf: - wbuf.append(None) - return [(v, w) for v in vbuf for w in wbuf] - return _do_python_join(rdd, other, numSplits, dispatch) - - -def python_cogroup(rdd, other, numSplits): - vs = rdd.map(lambda (k, v): (k, (1, v))) - ws = other.map(lambda (k, v): (k, (2, v))) - def dispatch(seq): - vbuf, wbuf = [], [] - for (n, v) in seq: - if n == 1: - vbuf.append(v) - elif n == 2: - wbuf.append(v) - return (vbuf, wbuf) - return vs.union(ws).groupByKey(numSplits).mapValues(dispatch) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py deleted file mode 100644 index cbffb6cc1f..0000000000 --- a/pyspark/pyspark/rdd.py +++ /dev/null @@ -1,713 +0,0 @@ -import atexit -from base64 import standard_b64encode as b64enc -import copy -from collections import defaultdict -from itertools import chain, ifilter, imap, product -import operator -import os -import shlex -from subprocess import Popen, PIPE -from tempfile import NamedTemporaryFile -from threading import Thread - -from pyspark import cloudpickle -from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \ - read_from_pickle_file -from pyspark.join import python_join, python_left_outer_join, \ - python_right_outer_join, python_cogroup - -from py4j.java_collections import ListConverter, MapConverter - - -__all__ = ["RDD"] - - -class RDD(object): - """ - A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. - Represents an immutable, partitioned collection of elements that can be - operated on in parallel. - """ - - def __init__(self, jrdd, ctx): - self._jrdd = jrdd - self.is_cached = False - self.ctx = ctx - - @property - def context(self): - """ - The L{SparkContext} that this RDD was created on. - """ - return self.ctx - - def cache(self): - """ - Persist this RDD with the default storage level (C{MEMORY_ONLY}). - """ - self.is_cached = True - self._jrdd.cache() - return self - - # TODO persist(self, storageLevel) - - def map(self, f, preservesPartitioning=False): - """ - Return a new RDD containing the distinct elements in this RDD. - """ - def func(iterator): return imap(f, iterator) - return PipelinedRDD(self, func, preservesPartitioning) - - def flatMap(self, f, preservesPartitioning=False): - """ - Return a new RDD by first applying a function to all elements of this - RDD, and then flattening the results. - - >>> rdd = sc.parallelize([2, 3, 4]) - >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect()) - [1, 1, 1, 2, 2, 3] - >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect()) - [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] - """ - def func(iterator): return chain.from_iterable(imap(f, iterator)) - return self.mapPartitions(func, preservesPartitioning) - - def mapPartitions(self, f, preservesPartitioning=False): - """ - Return a new RDD by applying a function to each partition of this RDD. - - >>> rdd = sc.parallelize([1, 2, 3, 4], 2) - >>> def f(iterator): yield sum(iterator) - >>> rdd.mapPartitions(f).collect() - [3, 7] - """ - return PipelinedRDD(self, f, preservesPartitioning) - - # TODO: mapPartitionsWithSplit - - def filter(self, f): - """ - Return a new RDD containing only the elements that satisfy a predicate. - - >>> rdd = sc.parallelize([1, 2, 3, 4, 5]) - >>> rdd.filter(lambda x: x % 2 == 0).collect() - [2, 4] - """ - def func(iterator): return ifilter(f, iterator) - return self.mapPartitions(func) - - def distinct(self): - """ - Return a new RDD containing the distinct elements in this RDD. - - >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect()) - [1, 2, 3] - """ - return self.map(lambda x: (x, "")) \ - .reduceByKey(lambda x, _: x) \ - .map(lambda (x, _): x) - - # TODO: sampling needs to be re-implemented due to Batch - #def sample(self, withReplacement, fraction, seed): - # jrdd = self._jrdd.sample(withReplacement, fraction, seed) - # return RDD(jrdd, self.ctx) - - #def takeSample(self, withReplacement, num, seed): - # vals = self._jrdd.takeSample(withReplacement, num, seed) - # return [load_pickle(bytes(x)) for x in vals] - - def union(self, other): - """ - Return the union of this RDD and another one. - - >>> rdd = sc.parallelize([1, 1, 2, 3]) - >>> rdd.union(rdd).collect() - [1, 1, 2, 3, 1, 1, 2, 3] - """ - return RDD(self._jrdd.union(other._jrdd), self.ctx) - - def __add__(self, other): - """ - Return the union of this RDD and another one. - - >>> rdd = sc.parallelize([1, 1, 2, 3]) - >>> (rdd + rdd).collect() - [1, 1, 2, 3, 1, 1, 2, 3] - """ - if not isinstance(other, RDD): - raise TypeError - return self.union(other) - - # TODO: sort - - def glom(self): - """ - Return an RDD created by coalescing all elements within each partition - into a list. - - >>> rdd = sc.parallelize([1, 2, 3, 4], 2) - >>> sorted(rdd.glom().collect()) - [[1, 2], [3, 4]] - """ - def func(iterator): yield list(iterator) - return self.mapPartitions(func) - - def cartesian(self, other): - """ - Return the Cartesian product of this RDD and another one, that is, the - RDD of all pairs of elements C{(a, b)} where C{a} is in C{self} and - C{b} is in C{other}. - - >>> rdd = sc.parallelize([1, 2]) - >>> sorted(rdd.cartesian(rdd).collect()) - [(1, 1), (1, 2), (2, 1), (2, 2)] - """ - # Due to batching, we can't use the Java cartesian method. - java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx) - def unpack_batches(pair): - (x, y) = pair - if type(x) == Batch or type(y) == Batch: - xs = x.items if type(x) == Batch else [x] - ys = y.items if type(y) == Batch else [y] - for pair in product(xs, ys): - yield pair - else: - yield pair - return java_cartesian.flatMap(unpack_batches) - - def groupBy(self, f, numSplits=None): - """ - Return an RDD of grouped items. - - >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8]) - >>> result = rdd.groupBy(lambda x: x % 2).collect() - >>> sorted([(x, sorted(y)) for (x, y) in result]) - [(0, [2, 8]), (1, [1, 1, 3, 5])] - """ - return self.map(lambda x: (f(x), x)).groupByKey(numSplits) - - def pipe(self, command, env={}): - """ - Return an RDD created by piping elements to a forked external process. - - >>> sc.parallelize([1, 2, 3]).pipe('cat').collect() - ['1', '2', '3'] - """ - def func(iterator): - pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE) - def pipe_objs(out): - for obj in iterator: - out.write(str(obj).rstrip('\n') + '\n') - out.close() - Thread(target=pipe_objs, args=[pipe.stdin]).start() - return (x.rstrip('\n') for x in pipe.stdout) - return self.mapPartitions(func) - - def foreach(self, f): - """ - Applies a function to all elements of this RDD. - - >>> def f(x): print x - >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) - """ - self.map(f).collect() # Force evaluation - - def collect(self): - """ - Return a list that contains all of the elements in this RDD. - """ - picklesInJava = self._jrdd.collect().iterator() - return list(self._collect_iterator_through_file(picklesInJava)) - - def _collect_iterator_through_file(self, iterator): - # Transferring lots of data through Py4J can be slow because - # socket.readline() is inefficient. Instead, we'll dump the data to a - # file and read it back. - tempFile = NamedTemporaryFile(delete=False) - tempFile.close() - def clean_up_file(): - try: os.unlink(tempFile.name) - except: pass - atexit.register(clean_up_file) - self.ctx._writeIteratorToPickleFile(iterator, tempFile.name) - # Read the data into Python and deserialize it: - with open(tempFile.name, 'rb') as tempFile: - for item in read_from_pickle_file(tempFile): - yield item - os.unlink(tempFile.name) - - def reduce(self, f): - """ - Reduces the elements of this RDD using the specified associative binary - operator. - - >>> from operator import add - >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add) - 15 - >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add) - 10 - """ - def func(iterator): - acc = None - for obj in iterator: - if acc is None: - acc = obj - else: - acc = f(obj, acc) - if acc is not None: - yield acc - vals = self.mapPartitions(func).collect() - return reduce(f, vals) - - def fold(self, zeroValue, op): - """ - Aggregate the elements of each partition, and then the results for all - the partitions, using a given associative function and a neutral "zero - value." - - The function C{op(t1, t2)} is allowed to modify C{t1} and return it - as its result value to avoid object allocation; however, it should not - modify C{t2}. - - >>> from operator import add - >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) - 15 - """ - def func(iterator): - acc = zeroValue - for obj in iterator: - acc = op(obj, acc) - yield acc - vals = self.mapPartitions(func).collect() - return reduce(op, vals, zeroValue) - - # TODO: aggregate - - def sum(self): - """ - Add up the elements in this RDD. - - >>> sc.parallelize([1.0, 2.0, 3.0]).sum() - 6.0 - """ - return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add) - - def count(self): - """ - Return the number of elements in this RDD. - - >>> sc.parallelize([2, 3, 4]).count() - 3 - """ - return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum() - - def countByValue(self): - """ - Return the count of each unique value in this RDD as a dictionary of - (value, count) pairs. - - >>> sorted(sc.parallelize([1, 2, 1, 2, 2], 2).countByValue().items()) - [(1, 2), (2, 3)] - """ - def countPartition(iterator): - counts = defaultdict(int) - for obj in iterator: - counts[obj] += 1 - yield counts - def mergeMaps(m1, m2): - for (k, v) in m2.iteritems(): - m1[k] += v - return m1 - return self.mapPartitions(countPartition).reduce(mergeMaps) - - def take(self, num): - """ - Take the first num elements of the RDD. - - This currently scans the partitions *one by one*, so it will be slow if - a lot of partitions are required. In that case, use L{collect} to get - the whole RDD instead. - - >>> sc.parallelize([2, 3, 4, 5, 6]).take(2) - [2, 3] - >>> sc.parallelize([2, 3, 4, 5, 6]).take(10) - [2, 3, 4, 5, 6] - """ - items = [] - splits = self._jrdd.splits() - taskContext = self.ctx.jvm.spark.TaskContext(0, 0, 0) - while len(items) < num and splits: - split = splits.pop(0) - iterator = self._jrdd.iterator(split, taskContext) - items.extend(self._collect_iterator_through_file(iterator)) - return items[:num] - - def first(self): - """ - Return the first element in this RDD. - - >>> sc.parallelize([2, 3, 4]).first() - 2 - """ - return self.take(1)[0] - - def saveAsTextFile(self, path): - """ - Save this RDD as a text file, using string representations of elements. - - >>> tempFile = NamedTemporaryFile(delete=True) - >>> tempFile.close() - >>> sc.parallelize(range(10)).saveAsTextFile(tempFile.name) - >>> from fileinput import input - >>> from glob import glob - >>> ''.join(input(glob(tempFile.name + "/part-0000*"))) - '0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n' - """ - def func(iterator): - return (str(x).encode("utf-8") for x in iterator) - keyed = PipelinedRDD(self, func) - keyed._bypass_serializer = True - keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path) - - # Pair functions - - def collectAsMap(self): - """ - Return the key-value pairs in this RDD to the master as a dictionary. - - >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap() - >>> m[1] - 2 - >>> m[3] - 4 - """ - return dict(self.collect()) - - def reduceByKey(self, func, numSplits=None): - """ - Merge the values for each key using an associative reduce function. - - This will also perform the merging locally on each mapper before - sending results to a reducer, similarly to a "combiner" in MapReduce. - - Output will be hash-partitioned with C{numSplits} splits, or the - default parallelism level if C{numSplits} is not specified. - - >>> from operator import add - >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> sorted(rdd.reduceByKey(add).collect()) - [('a', 2), ('b', 1)] - """ - return self.combineByKey(lambda x: x, func, func, numSplits) - - def reduceByKeyLocally(self, func): - """ - Merge the values for each key using an associative reduce function, but - return the results immediately to the master as a dictionary. - - This will also perform the merging locally on each mapper before - sending results to a reducer, similarly to a "combiner" in MapReduce. - - >>> from operator import add - >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> sorted(rdd.reduceByKeyLocally(add).items()) - [('a', 2), ('b', 1)] - """ - def reducePartition(iterator): - m = {} - for (k, v) in iterator: - m[k] = v if k not in m else func(m[k], v) - yield m - def mergeMaps(m1, m2): - for (k, v) in m2.iteritems(): - m1[k] = v if k not in m1 else func(m1[k], v) - return m1 - return self.mapPartitions(reducePartition).reduce(mergeMaps) - - def countByKey(self): - """ - Count the number of elements for each key, and return the result to the - master as a dictionary. - - >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> sorted(rdd.countByKey().items()) - [('a', 2), ('b', 1)] - """ - return self.map(lambda x: x[0]).countByValue() - - def join(self, other, numSplits=None): - """ - Return an RDD containing all pairs of elements with matching keys in - C{self} and C{other}. - - Each pair of elements will be returned as a (k, (v1, v2)) tuple, where - (k, v1) is in C{self} and (k, v2) is in C{other}. - - Performs a hash join across the cluster. - - >>> x = sc.parallelize([("a", 1), ("b", 4)]) - >>> y = sc.parallelize([("a", 2), ("a", 3)]) - >>> sorted(x.join(y).collect()) - [('a', (1, 2)), ('a', (1, 3))] - """ - return python_join(self, other, numSplits) - - def leftOuterJoin(self, other, numSplits=None): - """ - Perform a left outer join of C{self} and C{other}. - - For each element (k, v) in C{self}, the resulting RDD will either - contain all pairs (k, (v, w)) for w in C{other}, or the pair - (k, (v, None)) if no elements in other have key k. - - Hash-partitions the resulting RDD into the given number of partitions. - - >>> x = sc.parallelize([("a", 1), ("b", 4)]) - >>> y = sc.parallelize([("a", 2)]) - >>> sorted(x.leftOuterJoin(y).collect()) - [('a', (1, 2)), ('b', (4, None))] - """ - return python_left_outer_join(self, other, numSplits) - - def rightOuterJoin(self, other, numSplits=None): - """ - Perform a right outer join of C{self} and C{other}. - - For each element (k, w) in C{other}, the resulting RDD will either - contain all pairs (k, (v, w)) for v in this, or the pair (k, (None, w)) - if no elements in C{self} have key k. - - Hash-partitions the resulting RDD into the given number of partitions. - - >>> x = sc.parallelize([("a", 1), ("b", 4)]) - >>> y = sc.parallelize([("a", 2)]) - >>> sorted(y.rightOuterJoin(x).collect()) - [('a', (2, 1)), ('b', (None, 4))] - """ - return python_right_outer_join(self, other, numSplits) - - # TODO: add option to control map-side combining - def partitionBy(self, numSplits, hashFunc=hash): - """ - Return a copy of the RDD partitioned using the specified partitioner. - - >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x)) - >>> sets = pairs.partitionBy(2).glom().collect() - >>> set(sets[0]).intersection(set(sets[1])) - set([]) - """ - if numSplits is None: - numSplits = self.ctx.defaultParallelism - # Transferring O(n) objects to Java is too expensive. Instead, we'll - # form the hash buckets in Python, transferring O(numSplits) objects - # to Java. Each object is a (splitNumber, [objects]) pair. - def add_shuffle_key(iterator): - buckets = defaultdict(list) - for (k, v) in iterator: - buckets[hashFunc(k) % numSplits].append((k, v)) - for (split, items) in buckets.iteritems(): - yield str(split) - yield dump_pickle(Batch(items)) - keyed = PipelinedRDD(self, add_shuffle_key) - keyed._bypass_serializer = True - pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() - partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits) - jrdd = pairRDD.partitionBy(partitioner) - jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) - return RDD(jrdd, self.ctx) - - # TODO: add control over map-side aggregation - def combineByKey(self, createCombiner, mergeValue, mergeCombiners, - numSplits=None): - """ - Generic function to combine the elements for each key using a custom - set of aggregation functions. - - Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined - type" C. Note that V and C can be different -- for example, one might - group an RDD of type (Int, Int) into an RDD of type (Int, List[Int]). - - Users provide three functions: - - - C{createCombiner}, which turns a V into a C (e.g., creates - a one-element list) - - C{mergeValue}, to merge a V into a C (e.g., adds it to the end of - a list) - - C{mergeCombiners}, to combine two C's into a single one. - - In addition, users can control the partitioning of the output RDD. - - >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> def f(x): return x - >>> def add(a, b): return a + str(b) - >>> sorted(x.combineByKey(str, add, add).collect()) - [('a', '11'), ('b', '1')] - """ - if numSplits is None: - numSplits = self.ctx.defaultParallelism - def combineLocally(iterator): - combiners = {} - for (k, v) in iterator: - if k not in combiners: - combiners[k] = createCombiner(v) - else: - combiners[k] = mergeValue(combiners[k], v) - return combiners.iteritems() - locally_combined = self.mapPartitions(combineLocally) - shuffled = locally_combined.partitionBy(numSplits) - def _mergeCombiners(iterator): - combiners = {} - for (k, v) in iterator: - if not k in combiners: - combiners[k] = v - else: - combiners[k] = mergeCombiners(combiners[k], v) - return combiners.iteritems() - return shuffled.mapPartitions(_mergeCombiners) - - # TODO: support variant with custom partitioner - def groupByKey(self, numSplits=None): - """ - Group the values for each key in the RDD into a single sequence. - Hash-partitions the resulting RDD with into numSplits partitions. - - >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> sorted(x.groupByKey().collect()) - [('a', [1, 1]), ('b', [1])] - """ - - def createCombiner(x): - return [x] - - def mergeValue(xs, x): - xs.append(x) - return xs - - def mergeCombiners(a, b): - return a + b - - return self.combineByKey(createCombiner, mergeValue, mergeCombiners, - numSplits) - - # TODO: add tests - def flatMapValues(self, f): - """ - Pass each value in the key-value pair RDD through a flatMap function - without changing the keys; this also retains the original RDD's - partitioning. - """ - flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) - return self.flatMap(flat_map_fn, preservesPartitioning=True) - - def mapValues(self, f): - """ - Pass each value in the key-value pair RDD through a map function - without changing the keys; this also retains the original RDD's - partitioning. - """ - map_values_fn = lambda (k, v): (k, f(v)) - return self.map(map_values_fn, preservesPartitioning=True) - - # TODO: support varargs cogroup of several RDDs. - def groupWith(self, other): - """ - Alias for cogroup. - """ - return self.cogroup(other) - - # TODO: add variant with custom parittioner - def cogroup(self, other, numSplits=None): - """ - For each key k in C{self} or C{other}, return a resulting RDD that - contains a tuple with the list of values for that key in C{self} as well - as C{other}. - - >>> x = sc.parallelize([("a", 1), ("b", 4)]) - >>> y = sc.parallelize([("a", 2)]) - >>> sorted(x.cogroup(y).collect()) - [('a', ([1], [2])), ('b', ([4], []))] - """ - return python_cogroup(self, other, numSplits) - - # TODO: `lookup` is disabled because we can't make direct comparisons based - # on the key; we need to compare the hash of the key to the hash of the - # keys in the pairs. This could be an expensive operation, since those - # hashes aren't retained. - - -class PipelinedRDD(RDD): - """ - Pipelined maps: - >>> rdd = sc.parallelize([1, 2, 3, 4]) - >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect() - [4, 8, 12, 16] - >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect() - [4, 8, 12, 16] - - Pipelined reduces: - >>> from operator import add - >>> rdd.map(lambda x: 2 * x).reduce(add) - 20 - >>> rdd.flatMap(lambda x: [x, x]).reduce(add) - 20 - """ - def __init__(self, prev, func, preservesPartitioning=False): - if isinstance(prev, PipelinedRDD) and not prev.is_cached: - prev_func = prev.func - def pipeline_func(iterator): - return func(prev_func(iterator)) - self.func = pipeline_func - self.preservesPartitioning = \ - prev.preservesPartitioning and preservesPartitioning - self._prev_jrdd = prev._prev_jrdd - else: - self.func = func - self.preservesPartitioning = preservesPartitioning - self._prev_jrdd = prev._jrdd - self.is_cached = False - self.ctx = prev.ctx - self.prev = prev - self._jrdd_val = None - self._bypass_serializer = False - - @property - def _jrdd(self): - if self._jrdd_val: - return self._jrdd_val - func = self.func - if not self._bypass_serializer and self.ctx.batchSize != 1: - oldfunc = self.func - batchSize = self.ctx.batchSize - def batched_func(iterator): - return batched(oldfunc(iterator), batchSize) - func = batched_func - cmds = [func, self._bypass_serializer] - pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) - broadcast_vars = ListConverter().convert( - [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], - self.ctx.gateway._gateway_client) - self.ctx._pickled_broadcast_vars.clear() - class_manifest = self._prev_jrdd.classManifest() - env = copy.copy(self.ctx.environment) - env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "") - env = MapConverter().convert(env, self.ctx.gateway._gateway_client) - python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), - pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, class_manifest) - self._jrdd_val = python_rdd.asJavaRDD() - return self._jrdd_val - - -def _test(): - import doctest - from pyspark.context import SparkContext - globs = globals().copy() - # The small batch size here ensures that we see multiple batches, - # even in these small test examples: - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) - doctest.testmod(globs=globs) - globs['sc'].stop() - - -if __name__ == "__main__": - _test() diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py deleted file mode 100644 index 9a5151ea00..0000000000 --- a/pyspark/pyspark/serializers.py +++ /dev/null @@ -1,78 +0,0 @@ -import struct -import cPickle - - -class Batch(object): - """ - Used to store multiple RDD entries as a single Java object. - - This relieves us from having to explicitly track whether an RDD - is stored as batches of objects and avoids problems when processing - the union() of batched and unbatched RDDs (e.g. the union() of textFile() - with another RDD). - """ - def __init__(self, items): - self.items = items - - -def batched(iterator, batchSize): - if batchSize == -1: # unlimited batch size - yield Batch(list(iterator)) - else: - items = [] - count = 0 - for item in iterator: - items.append(item) - count += 1 - if count == batchSize: - yield Batch(items) - items = [] - count = 0 - if items: - yield Batch(items) - - -def dump_pickle(obj): - return cPickle.dumps(obj, 2) - - -load_pickle = cPickle.loads - - -def read_long(stream): - length = stream.read(8) - if length == "": - raise EOFError - return struct.unpack("!q", length)[0] - - -def read_int(stream): - length = stream.read(4) - if length == "": - raise EOFError - return struct.unpack("!i", length)[0] - -def write_with_length(obj, stream): - stream.write(struct.pack("!i", len(obj))) - stream.write(obj) - - -def read_with_length(stream): - length = read_int(stream) - obj = stream.read(length) - if obj == "": - raise EOFError - return obj - - -def read_from_pickle_file(stream): - try: - while True: - obj = load_pickle(read_with_length(stream)) - if type(obj) == Batch: # We don't care about inheritance - for item in obj.items: - yield item - else: - yield obj - except EOFError: - return diff --git a/pyspark/pyspark/shell.py b/pyspark/pyspark/shell.py deleted file mode 100644 index bd39b0283f..0000000000 --- a/pyspark/pyspark/shell.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -An interactive shell. -""" -import optparse # I prefer argparse, but it's not included with Python < 2.7 -import code -import sys - -from pyspark.context import SparkContext - - -def main(master='local', ipython=False): - sc = SparkContext(master, 'PySparkShell') - user_ns = {'sc' : sc} - banner = "Spark context avaiable as sc." - if ipython: - import IPython - IPython.embed(user_ns=user_ns, banner2=banner) - else: - print banner - code.interact(local=user_ns) - - -if __name__ == '__main__': - usage = "usage: %prog [options] master" - parser = optparse.OptionParser(usage=usage) - parser.add_option("-i", "--ipython", help="Run IPython shell", - action="store_true") - (options, args) = parser.parse_args() - if len(sys.argv) > 1: - master = args[0] - else: - master = 'local' - main(master, options.ipython) diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py deleted file mode 100644 index 9f6b507dbd..0000000000 --- a/pyspark/pyspark/worker.py +++ /dev/null @@ -1,40 +0,0 @@ -""" -Worker that receives input from Piped RDD. -""" -import sys -from base64 import standard_b64decode -# CloudPickler needs to be imported so that depicklers are registered using the -# copy_reg module. -from pyspark.broadcast import Broadcast, _broadcastRegistry -from pyspark.cloudpickle import CloudPickler -from pyspark.serializers import write_with_length, read_with_length, \ - read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file - - -# Redirect stdout to stderr so that users must return values from functions. -old_stdout = sys.stdout -sys.stdout = sys.stderr - - -def load_obj(): - return load_pickle(standard_b64decode(sys.stdin.readline().strip())) - - -def main(): - num_broadcast_variables = read_int(sys.stdin) - for _ in range(num_broadcast_variables): - bid = read_long(sys.stdin) - value = read_with_length(sys.stdin) - _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value)) - func = load_obj() - bypassSerializer = load_obj() - if bypassSerializer: - dumps = lambda x: x - else: - dumps = dump_pickle - for obj in func(read_from_pickle_file(sys.stdin)): - write_with_length(dumps(obj), old_stdout) - - -if __name__ == '__main__': - main() diff --git a/pyspark/run-pyspark b/pyspark/run-pyspark deleted file mode 100755 index 4d10fbea8b..0000000000 --- a/pyspark/run-pyspark +++ /dev/null @@ -1,28 +0,0 @@ -#!/usr/bin/env bash - -# Figure out where the Scala framework is installed -FWDIR="$(cd `dirname $0`; cd ../; pwd)" - -# Export this as SPARK_HOME -export SPARK_HOME="$FWDIR" - -# Load environment variables from conf/spark-env.sh, if it exists -if [ -e $FWDIR/conf/spark-env.sh ] ; then - . $FWDIR/conf/spark-env.sh -fi - -# Figure out which Python executable to use -if [ -z "$PYSPARK_PYTHON" ] ; then - PYSPARK_PYTHON="python" -fi -export PYSPARK_PYTHON - -# Add the PySpark classes to the Python path: -export PYTHONPATH=$SPARK_HOME/pyspark/:$PYTHONPATH - -# Launch with `scala` by default: -if [[ "$SPARK_LAUNCH_WITH_SCALA" != "0" ]] ; then - export SPARK_LAUNCH_WITH_SCALA=1 -fi - -exec "$PYSPARK_PYTHON" "$@" diff --git a/python/.gitignore b/python/.gitignore new file mode 100644 index 0000000000..5c56e638f9 --- /dev/null +++ b/python/.gitignore @@ -0,0 +1,2 @@ +*.pyc +docs/ diff --git a/python/epydoc.conf b/python/epydoc.conf new file mode 100644 index 0000000000..91ac984ba2 --- /dev/null +++ b/python/epydoc.conf @@ -0,0 +1,19 @@ +[epydoc] # Epydoc section marker (required by ConfigParser) + +# Information about the project. +name: PySpark +url: http://spark-project.org + +# The list of modules to document. Modules can be named using +# dotted names, module filenames, or package directory names. +# This option may be repeated. +modules: pyspark + +# Write html output to the directory "apidocs" +output: html +target: docs/ + +private: no + +exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers + pyspark.java_gateway pyspark.examples pyspark.shell diff --git a/python/examples/kmeans.py b/python/examples/kmeans.py new file mode 100644 index 0000000000..ad2be21178 --- /dev/null +++ b/python/examples/kmeans.py @@ -0,0 +1,52 @@ +""" +This example requires numpy (http://www.numpy.org/) +""" +import sys + +import numpy as np +from pyspark import SparkContext + + +def parseVector(line): + return np.array([float(x) for x in line.split(' ')]) + + +def closestPoint(p, centers): + bestIndex = 0 + closest = float("+inf") + for i in range(len(centers)): + tempDist = np.sum((p - centers[i]) ** 2) + if tempDist < closest: + closest = tempDist + bestIndex = i + return bestIndex + + +if __name__ == "__main__": + if len(sys.argv) < 5: + print >> sys.stderr, \ + "Usage: PythonKMeans " + exit(-1) + sc = SparkContext(sys.argv[1], "PythonKMeans") + lines = sc.textFile(sys.argv[2]) + data = lines.map(parseVector).cache() + K = int(sys.argv[3]) + convergeDist = float(sys.argv[4]) + + kPoints = data.takeSample(False, K, 34) + tempDist = 1.0 + + while tempDist > convergeDist: + closest = data.map( + lambda p : (closestPoint(p, kPoints), (p, 1))) + pointStats = closest.reduceByKey( + lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2)) + newPoints = pointStats.map( + lambda (x, (y, z)): (x, y / z)).collect() + + tempDist = sum(np.sum((kPoints[x] - y) ** 2) for (x, y) in newPoints) + + for (x, y) in newPoints: + kPoints[x] = y + + print "Final centers: " + str(kPoints) diff --git a/python/examples/logistic_regression.py b/python/examples/logistic_regression.py new file mode 100755 index 0000000000..f13698a86f --- /dev/null +++ b/python/examples/logistic_regression.py @@ -0,0 +1,57 @@ +""" +This example requires numpy (http://www.numpy.org/) +""" +from collections import namedtuple +from math import exp +from os.path import realpath +import sys + +import numpy as np +from pyspark import SparkContext + + +N = 100000 # Number of data points +D = 10 # Number of dimensions +R = 0.7 # Scaling factor +ITERATIONS = 5 +np.random.seed(42) + + +DataPoint = namedtuple("DataPoint", ['x', 'y']) +from lr import DataPoint # So that DataPoint is properly serialized + + +def generateData(): + def generatePoint(i): + y = -1 if i % 2 == 0 else 1 + x = np.random.normal(size=D) + (y * R) + return DataPoint(x, y) + return [generatePoint(i) for i in range(N)] + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonLR []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonLR", pyFiles=[realpath(__file__)]) + slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 + points = sc.parallelize(generateData(), slices).cache() + + # Initialize w to a random value + w = 2 * np.random.ranf(size=D) - 1 + print "Initial w: " + str(w) + + def add(x, y): + x += y + return x + + for i in range(1, ITERATIONS + 1): + print "On iteration %i" % i + + gradient = points.map(lambda p: + (1.0 / (1.0 + exp(-p.y * np.dot(w, p.x)))) * p.y * p.x + ).reduce(add) + w -= gradient + + print "Final w: " + str(w) diff --git a/python/examples/pi.py b/python/examples/pi.py new file mode 100644 index 0000000000..127cba029b --- /dev/null +++ b/python/examples/pi.py @@ -0,0 +1,21 @@ +import sys +from random import random +from operator import add + +from pyspark import SparkContext + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonPi []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonPi") + slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 + n = 100000 * slices + def f(_): + x = random() * 2 - 1 + y = random() * 2 - 1 + return 1 if x ** 2 + y ** 2 < 1 else 0 + count = sc.parallelize(xrange(1, n+1), slices).map(f).reduce(add) + print "Pi is roughly %f" % (4.0 * count / n) diff --git a/python/examples/transitive_closure.py b/python/examples/transitive_closure.py new file mode 100644 index 0000000000..73f7f8fbaf --- /dev/null +++ b/python/examples/transitive_closure.py @@ -0,0 +1,50 @@ +import sys +from random import Random + +from pyspark import SparkContext + +numEdges = 200 +numVertices = 100 +rand = Random(42) + + +def generateGraph(): + edges = set() + while len(edges) < numEdges: + src = rand.randrange(0, numEdges) + dst = rand.randrange(0, numEdges) + if src != dst: + edges.add((src, dst)) + return edges + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonTC []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonTC") + slices = sys.argv[2] if len(sys.argv) > 2 else 2 + tc = sc.parallelize(generateGraph(), slices).cache() + + # Linear transitive closure: each round grows paths by one edge, + # by joining the graph's edges with the already-discovered paths. + # e.g. join the path (y, z) from the TC with the edge (x, y) from + # the graph to obtain the path (x, z). + + # Because join() joins on keys, the edges are stored in reversed order. + edges = tc.map(lambda (x, y): (y, x)) + + oldCount = 0L + nextCount = tc.count() + while True: + oldCount = nextCount + # Perform the join, obtaining an RDD of (y, (z, x)) pairs, + # then project the result to obtain the new (x, z) paths. + new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a)) + tc = tc.union(new_edges).distinct().cache() + nextCount = tc.count() + if nextCount == oldCount: + break + + print "TC has %i edges" % tc.count() diff --git a/python/examples/wordcount.py b/python/examples/wordcount.py new file mode 100644 index 0000000000..857160624b --- /dev/null +++ b/python/examples/wordcount.py @@ -0,0 +1,19 @@ +import sys +from operator import add + +from pyspark import SparkContext + + +if __name__ == "__main__": + if len(sys.argv) < 3: + print >> sys.stderr, \ + "Usage: PythonWordCount " + exit(-1) + sc = SparkContext(sys.argv[1], "PythonWordCount") + lines = sc.textFile(sys.argv[2], 1) + counts = lines.flatMap(lambda x: x.split(' ')) \ + .map(lambda x: (x, 1)) \ + .reduceByKey(add) + output = counts.collect() + for (word, count) in output: + print "%s : %i" % (word, count) diff --git a/python/lib/PY4J_LICENSE.txt b/python/lib/PY4J_LICENSE.txt new file mode 100644 index 0000000000..a70279ca14 --- /dev/null +++ b/python/lib/PY4J_LICENSE.txt @@ -0,0 +1,27 @@ + +Copyright (c) 2009-2011, Barthelemy Dagenais All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +- Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +- Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +- The name of the author may not be used to endorse or promote products +derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/python/lib/PY4J_VERSION.txt b/python/lib/PY4J_VERSION.txt new file mode 100644 index 0000000000..04a0cd52a8 --- /dev/null +++ b/python/lib/PY4J_VERSION.txt @@ -0,0 +1 @@ +b7924aabe9c5e63f0a4d8bbd17019534c7ec014e diff --git a/python/lib/py4j0.7.egg b/python/lib/py4j0.7.egg new file mode 100644 index 0000000000..f8a339d8ee Binary files /dev/null and b/python/lib/py4j0.7.egg differ diff --git a/python/lib/py4j0.7.jar b/python/lib/py4j0.7.jar new file mode 100644 index 0000000000..73b7ddb7d1 Binary files /dev/null and b/python/lib/py4j0.7.jar differ diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py new file mode 100644 index 0000000000..c595ae0842 --- /dev/null +++ b/python/pyspark/__init__.py @@ -0,0 +1,20 @@ +""" +PySpark is a Python API for Spark. + +Public classes: + + - L{SparkContext} + Main entry point for Spark functionality. + - L{RDD} + A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.egg")) + + +from pyspark.context import SparkContext +from pyspark.rdd import RDD + + +__all__ = ["SparkContext", "RDD"] diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py new file mode 100644 index 0000000000..93876fa738 --- /dev/null +++ b/python/pyspark/broadcast.py @@ -0,0 +1,48 @@ +""" +>>> from pyspark.context import SparkContext +>>> sc = SparkContext('local', 'test') +>>> b = sc.broadcast([1, 2, 3, 4, 5]) +>>> b.value +[1, 2, 3, 4, 5] + +>>> from pyspark.broadcast import _broadcastRegistry +>>> _broadcastRegistry[b.bid] = b +>>> from cPickle import dumps, loads +>>> loads(dumps(b)).value +[1, 2, 3, 4, 5] + +>>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() +[1, 2, 3, 4, 5, 1, 2, 3, 4, 5] + +>>> large_broadcast = sc.broadcast(list(range(10000))) +""" +# Holds broadcasted data received from Java, keyed by its id. +_broadcastRegistry = {} + + +def _from_id(bid): + from pyspark.broadcast import _broadcastRegistry + if bid not in _broadcastRegistry: + raise Exception("Broadcast variable '%s' not loaded!" % bid) + return _broadcastRegistry[bid] + + +class Broadcast(object): + def __init__(self, bid, value, java_broadcast=None, pickle_registry=None): + self.value = value + self.bid = bid + self._jbroadcast = java_broadcast + self._pickle_registry = pickle_registry + + def __reduce__(self): + self._pickle_registry.add(self) + return (_from_id, (self.bid, )) + + +def _test(): + import doctest + doctest.testmod() + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py new file mode 100644 index 0000000000..6a7c23a069 --- /dev/null +++ b/python/pyspark/cloudpickle.py @@ -0,0 +1,974 @@ +""" +This class is defined to override standard pickle functionality + +The goals of it follow: +-Serialize lambdas and nested functions to compiled byte code +-Deal with main module correctly +-Deal with other non-serializable objects + +It does not include an unpickler, as standard python unpickling suffices. + +This module was extracted from the `cloud` package, developed by `PiCloud, Inc. +`_. + +Copyright (c) 2012, Regents of the University of California. +Copyright (c) 2009 `PiCloud, Inc. `_. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the University of California, Berkeley nor the + names of its contributors may be used to endorse or promote + products derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + + +import operator +import os +import pickle +import struct +import sys +import types +from functools import partial +import itertools +from copy_reg import _extension_registry, _inverted_registry, _extension_cache +import new +import dis +import traceback + +#relevant opcodes +STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL')) +DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL')) +LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL')) +GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL] + +HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT) +EXTENDED_ARG = chr(dis.EXTENDED_ARG) + +import logging +cloudLog = logging.getLogger("Cloud.Transport") + +try: + import ctypes +except (MemoryError, ImportError): + logging.warning('Exception raised on importing ctypes. Likely python bug.. some functionality will be disabled', exc_info = True) + ctypes = None + PyObject_HEAD = None +else: + + # for reading internal structures + PyObject_HEAD = [ + ('ob_refcnt', ctypes.c_size_t), + ('ob_type', ctypes.c_void_p), + ] + + +try: + from cStringIO import StringIO +except ImportError: + from StringIO import StringIO + +# These helper functions were copied from PiCloud's util module. +def islambda(func): + return getattr(func,'func_name') == '' + +def xrange_params(xrangeobj): + """Returns a 3 element tuple describing the xrange start, step, and len + respectively + + Note: Only guarentees that elements of xrange are the same. parameters may + be different. + e.g. xrange(1,1) is interpretted as xrange(0,0); both behave the same + though w/ iteration + """ + + xrange_len = len(xrangeobj) + if not xrange_len: #empty + return (0,1,0) + start = xrangeobj[0] + if xrange_len == 1: #one element + return start, 1, 1 + return (start, xrangeobj[1] - xrangeobj[0], xrange_len) + +#debug variables intended for developer use: +printSerialization = False +printMemoization = False + +useForcedImports = True #Should I use forced imports for tracking? + + + +class CloudPickler(pickle.Pickler): + + dispatch = pickle.Pickler.dispatch.copy() + savedForceImports = False + savedDjangoEnv = False #hack tro transport django environment + + def __init__(self, file, protocol=None, min_size_to_save= 0): + pickle.Pickler.__init__(self,file,protocol) + self.modules = set() #set of modules needed to depickle + self.globals_ref = {} # map ids to dictionary. used to ensure that functions can share global env + + def dump(self, obj): + # note: not thread safe + # minimal side-effects, so not fixing + recurse_limit = 3000 + base_recurse = sys.getrecursionlimit() + if base_recurse < recurse_limit: + sys.setrecursionlimit(recurse_limit) + self.inject_addons() + try: + return pickle.Pickler.dump(self, obj) + except RuntimeError, e: + if 'recursion' in e.args[0]: + msg = """Could not pickle object as excessively deep recursion required. + Try _fast_serialization=2 or contact PiCloud support""" + raise pickle.PicklingError(msg) + finally: + new_recurse = sys.getrecursionlimit() + if new_recurse == recurse_limit: + sys.setrecursionlimit(base_recurse) + + def save_buffer(self, obj): + """Fallback to save_string""" + pickle.Pickler.save_string(self,str(obj)) + dispatch[buffer] = save_buffer + + #block broken objects + def save_unsupported(self, obj, pack=None): + raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj)) + dispatch[types.GeneratorType] = save_unsupported + + #python2.6+ supports slice pickling. some py2.5 extensions might as well. We just test it + try: + slice(0,1).__reduce__() + except TypeError: #can't pickle - + dispatch[slice] = save_unsupported + + #itertools objects do not pickle! + for v in itertools.__dict__.values(): + if type(v) is type: + dispatch[v] = save_unsupported + + + def save_dict(self, obj): + """hack fix + If the dict is a global, deal with it in a special way + """ + #print 'saving', obj + if obj is __builtins__: + self.save_reduce(_get_module_builtins, (), obj=obj) + else: + pickle.Pickler.save_dict(self, obj) + dispatch[pickle.DictionaryType] = save_dict + + + def save_module(self, obj, pack=struct.pack): + """ + Save a module as an import + """ + #print 'try save import', obj.__name__ + self.modules.add(obj) + self.save_reduce(subimport,(obj.__name__,), obj=obj) + dispatch[types.ModuleType] = save_module #new type + + def save_codeobject(self, obj, pack=struct.pack): + """ + Save a code object + """ + #print 'try to save codeobj: ', obj + args = ( + obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code, + obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name, + obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars + ) + self.save_reduce(types.CodeType, args, obj=obj) + dispatch[types.CodeType] = save_codeobject #new type + + def save_function(self, obj, name=None, pack=struct.pack): + """ Registered with the dispatch to handle all function types. + + Determines what kind of function obj is (e.g. lambda, defined at + interactive prompt, etc) and handles the pickling appropriately. + """ + write = self.write + + name = obj.__name__ + modname = pickle.whichmodule(obj, name) + #print 'which gives %s %s %s' % (modname, obj, name) + try: + themodule = sys.modules[modname] + except KeyError: # eval'd items such as namedtuple give invalid items for their function __module__ + modname = '__main__' + + if modname == '__main__': + themodule = None + + if themodule: + self.modules.add(themodule) + + if not self.savedDjangoEnv: + #hack for django - if we detect the settings module, we transport it + django_settings = os.environ.get('DJANGO_SETTINGS_MODULE', '') + if django_settings: + django_mod = sys.modules.get(django_settings) + if django_mod: + cloudLog.debug('Transporting django settings %s during save of %s', django_mod, name) + self.savedDjangoEnv = True + self.modules.add(django_mod) + write(pickle.MARK) + self.save_reduce(django_settings_load, (django_mod.__name__,), obj=django_mod) + write(pickle.POP_MARK) + + + # if func is lambda, def'ed at prompt, is in main, or is nested, then + # we'll pickle the actual function object rather than simply saving a + # reference (as is done in default pickler), via save_function_tuple. + if islambda(obj) or obj.func_code.co_filename == '' or themodule == None: + #Force server to import modules that have been imported in main + modList = None + if themodule == None and not self.savedForceImports: + mainmod = sys.modules['__main__'] + if useForcedImports and hasattr(mainmod,'___pyc_forcedImports__'): + modList = list(mainmod.___pyc_forcedImports__) + self.savedForceImports = True + self.save_function_tuple(obj, modList) + return + else: # func is nested + klass = getattr(themodule, name, None) + if klass is None or klass is not obj: + self.save_function_tuple(obj, [themodule]) + return + + if obj.__dict__: + # essentially save_reduce, but workaround needed to avoid recursion + self.save(_restore_attr) + write(pickle.MARK + pickle.GLOBAL + modname + '\n' + name + '\n') + self.memoize(obj) + self.save(obj.__dict__) + write(pickle.TUPLE + pickle.REDUCE) + else: + write(pickle.GLOBAL + modname + '\n' + name + '\n') + self.memoize(obj) + dispatch[types.FunctionType] = save_function + + def save_function_tuple(self, func, forced_imports): + """ Pickles an actual func object. + + A func comprises: code, globals, defaults, closure, and dict. We + extract and save these, injecting reducing functions at certain points + to recreate the func object. Keep in mind that some of these pieces + can contain a ref to the func itself. Thus, a naive save on these + pieces could trigger an infinite loop of save's. To get around that, + we first create a skeleton func object using just the code (this is + safe, since this won't contain a ref to the func), and memoize it as + soon as it's created. The other stuff can then be filled in later. + """ + save = self.save + write = self.write + + # save the modules (if any) + if forced_imports: + write(pickle.MARK) + save(_modules_to_main) + #print 'forced imports are', forced_imports + + forced_names = map(lambda m: m.__name__, forced_imports) + save((forced_names,)) + + #save((forced_imports,)) + write(pickle.REDUCE) + write(pickle.POP_MARK) + + code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func) + + save(_fill_function) # skeleton function updater + write(pickle.MARK) # beginning of tuple that _fill_function expects + + # create a skeleton function object and memoize it + save(_make_skel_func) + save((code, len(closure), base_globals)) + write(pickle.REDUCE) + self.memoize(func) + + # save the rest of the func data needed by _fill_function + save(f_globals) + save(defaults) + save(closure) + save(dct) + write(pickle.TUPLE) + write(pickle.REDUCE) # applies _fill_function on the tuple + + @staticmethod + def extract_code_globals(co): + """ + Find all globals names read or written to by codeblock co + """ + code = co.co_code + names = co.co_names + out_names = set() + + n = len(code) + i = 0 + extended_arg = 0 + while i < n: + op = code[i] + + i = i+1 + if op >= HAVE_ARGUMENT: + oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg + extended_arg = 0 + i = i+2 + if op == EXTENDED_ARG: + extended_arg = oparg*65536L + if op in GLOBAL_OPS: + out_names.add(names[oparg]) + #print 'extracted', out_names, ' from ', names + return out_names + + def extract_func_data(self, func): + """ + Turn the function into a tuple of data necessary to recreate it: + code, globals, defaults, closure, dict + """ + code = func.func_code + + # extract all global ref's + func_global_refs = CloudPickler.extract_code_globals(code) + if code.co_consts: # see if nested function have any global refs + for const in code.co_consts: + if type(const) is types.CodeType and const.co_names: + func_global_refs = func_global_refs.union( CloudPickler.extract_code_globals(const)) + # process all variables referenced by global environment + f_globals = {} + for var in func_global_refs: + #Some names, such as class functions are not global - we don't need them + if func.func_globals.has_key(var): + f_globals[var] = func.func_globals[var] + + # defaults requires no processing + defaults = func.func_defaults + + def get_contents(cell): + try: + return cell.cell_contents + except ValueError, e: #cell is empty error on not yet assigned + raise pickle.PicklingError('Function to be pickled has free variables that are referenced before assignment in enclosing scope') + + + # process closure + if func.func_closure: + closure = map(get_contents, func.func_closure) + else: + closure = [] + + # save the dict + dct = func.func_dict + + if printSerialization: + outvars = ['code: ' + str(code) ] + outvars.append('globals: ' + str(f_globals)) + outvars.append('defaults: ' + str(defaults)) + outvars.append('closure: ' + str(closure)) + print 'function ', func, 'is extracted to: ', ', '.join(outvars) + + base_globals = self.globals_ref.get(id(func.func_globals), {}) + self.globals_ref[id(func.func_globals)] = base_globals + + return (code, f_globals, defaults, closure, dct, base_globals) + + def save_global(self, obj, name=None, pack=struct.pack): + write = self.write + memo = self.memo + + if name is None: + name = obj.__name__ + + modname = getattr(obj, "__module__", None) + if modname is None: + modname = pickle.whichmodule(obj, name) + + try: + __import__(modname) + themodule = sys.modules[modname] + except (ImportError, KeyError, AttributeError): #should never occur + raise pickle.PicklingError( + "Can't pickle %r: Module %s cannot be found" % + (obj, modname)) + + if modname == '__main__': + themodule = None + + if themodule: + self.modules.add(themodule) + + sendRef = True + typ = type(obj) + #print 'saving', obj, typ + try: + try: #Deal with case when getattribute fails with exceptions + klass = getattr(themodule, name) + except (AttributeError): + if modname == '__builtin__': #new.* are misrepeported + modname = 'new' + __import__(modname) + themodule = sys.modules[modname] + try: + klass = getattr(themodule, name) + except AttributeError, a: + #print themodule, name, obj, type(obj) + raise pickle.PicklingError("Can't pickle builtin %s" % obj) + else: + raise + + except (ImportError, KeyError, AttributeError): + if typ == types.TypeType or typ == types.ClassType: + sendRef = False + else: #we can't deal with this + raise + else: + if klass is not obj and (typ == types.TypeType or typ == types.ClassType): + sendRef = False + if not sendRef: + #note: Third party types might crash this - add better checks! + d = dict(obj.__dict__) #copy dict proxy to a dict + if not isinstance(d.get('__dict__', None), property): # don't extract dict that are properties + d.pop('__dict__',None) + d.pop('__weakref__',None) + + # hack as __new__ is stored differently in the __dict__ + new_override = d.get('__new__', None) + if new_override: + d['__new__'] = obj.__new__ + + self.save_reduce(type(obj),(obj.__name__,obj.__bases__, + d),obj=obj) + #print 'internal reduce dask %s %s' % (obj, d) + return + + if self.proto >= 2: + code = _extension_registry.get((modname, name)) + if code: + assert code > 0 + if code <= 0xff: + write(pickle.EXT1 + chr(code)) + elif code <= 0xffff: + write("%c%c%c" % (pickle.EXT2, code&0xff, code>>8)) + else: + write(pickle.EXT4 + pack("= 2 and getattr(func, "__name__", "") == "__newobj__": + #Added fix to allow transient + cls = args[0] + if not hasattr(cls, "__new__"): + raise pickle.PicklingError( + "args[0] from __newobj__ args has no __new__") + if obj is not None and cls is not obj.__class__: + raise pickle.PicklingError( + "args[0] from __newobj__ args has the wrong class") + args = args[1:] + save(cls) + + #Don't pickle transient entries + if hasattr(obj, '__transient__'): + transient = obj.__transient__ + state = state.copy() + + for k in list(state.keys()): + if k in transient: + del state[k] + + save(args) + write(pickle.NEWOBJ) + else: + save(func) + save(args) + write(pickle.REDUCE) + + if obj is not None: + self.memoize(obj) + + # More new special cases (that work with older protocols as + # well): when __reduce__ returns a tuple with 4 or 5 items, + # the 4th and 5th item should be iterators that provide list + # items and dict items (as (key, value) tuples), or None. + + if listitems is not None: + self._batch_appends(listitems) + + if dictitems is not None: + self._batch_setitems(dictitems) + + if state is not None: + #print 'obj %s has state %s' % (obj, state) + save(state) + write(pickle.BUILD) + + + def save_xrange(self, obj): + """Save an xrange object in python 2.5 + Python 2.6 supports this natively + """ + range_params = xrange_params(obj) + self.save_reduce(_build_xrange,range_params) + + #python2.6+ supports xrange pickling. some py2.5 extensions might as well. We just test it + try: + xrange(0).__reduce__() + except TypeError: #can't pickle -- use PiCloud pickler + dispatch[xrange] = save_xrange + + def save_partial(self, obj): + """Partial objects do not serialize correctly in python2.x -- this fixes the bugs""" + self.save_reduce(_genpartial, (obj.func, obj.args, obj.keywords)) + + if sys.version_info < (2,7): #2.7 supports partial pickling + dispatch[partial] = save_partial + + + def save_file(self, obj): + """Save a file""" + import StringIO as pystringIO #we can't use cStringIO as it lacks the name attribute + from ..transport.adapter import SerializingAdapter + + if not hasattr(obj, 'name') or not hasattr(obj, 'mode'): + raise pickle.PicklingError("Cannot pickle files that do not map to an actual file") + if obj.name == '': + return self.save_reduce(getattr, (sys,'stdout'), obj=obj) + if obj.name == '': + return self.save_reduce(getattr, (sys,'stderr'), obj=obj) + if obj.name == '': + raise pickle.PicklingError("Cannot pickle standard input") + if hasattr(obj, 'isatty') and obj.isatty(): + raise pickle.PicklingError("Cannot pickle files that map to tty objects") + if 'r' not in obj.mode: + raise pickle.PicklingError("Cannot pickle files that are not opened for reading") + name = obj.name + try: + fsize = os.stat(name).st_size + except OSError: + raise pickle.PicklingError("Cannot pickle file %s as it cannot be stat" % name) + + if obj.closed: + #create an empty closed string io + retval = pystringIO.StringIO("") + retval.close() + elif not fsize: #empty file + retval = pystringIO.StringIO("") + try: + tmpfile = file(name) + tst = tmpfile.read(1) + except IOError: + raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) + tmpfile.close() + if tst != '': + raise pickle.PicklingError("Cannot pickle file %s as it does not appear to map to a physical, real file" % name) + elif fsize > SerializingAdapter.max_transmit_data: + raise pickle.PicklingError("Cannot pickle file %s as it exceeds cloudconf.py's max_transmit_data of %d" % + (name,SerializingAdapter.max_transmit_data)) + else: + try: + tmpfile = file(name) + contents = tmpfile.read(SerializingAdapter.max_transmit_data) + tmpfile.close() + except IOError: + raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) + retval = pystringIO.StringIO(contents) + curloc = obj.tell() + retval.seek(curloc) + + retval.name = name + self.save(retval) #save stringIO + self.memoize(obj) + + dispatch[file] = save_file + """Special functions for Add-on libraries""" + + def inject_numpy(self): + numpy = sys.modules.get('numpy') + if not numpy or not hasattr(numpy, 'ufunc'): + return + self.dispatch[numpy.ufunc] = self.__class__.save_ufunc + + numpy_tst_mods = ['numpy', 'scipy.special'] + def save_ufunc(self, obj): + """Hack function for saving numpy ufunc objects""" + name = obj.__name__ + for tst_mod_name in self.numpy_tst_mods: + tst_mod = sys.modules.get(tst_mod_name, None) + if tst_mod: + if name in tst_mod.__dict__: + self.save_reduce(_getobject, (tst_mod_name, name)) + return + raise pickle.PicklingError('cannot save %s. Cannot resolve what module it is defined in' % str(obj)) + + def inject_timeseries(self): + """Handle bugs with pickling scikits timeseries""" + tseries = sys.modules.get('scikits.timeseries.tseries') + if not tseries or not hasattr(tseries, 'Timeseries'): + return + self.dispatch[tseries.Timeseries] = self.__class__.save_timeseries + + def save_timeseries(self, obj): + import scikits.timeseries.tseries as ts + + func, reduce_args, state = obj.__reduce__() + if func != ts._tsreconstruct: + raise pickle.PicklingError('timeseries using unexpected reconstruction function %s' % str(func)) + state = (1, + obj.shape, + obj.dtype, + obj.flags.fnc, + obj._data.tostring(), + ts.getmaskarray(obj).tostring(), + obj._fill_value, + obj._dates.shape, + obj._dates.__array__().tostring(), + obj._dates.dtype, #added -- preserve type + obj.freq, + obj._optinfo, + ) + return self.save_reduce(_genTimeSeries, (reduce_args, state)) + + def inject_email(self): + """Block email LazyImporters from being saved""" + email = sys.modules.get('email') + if not email: + return + self.dispatch[email.LazyImporter] = self.__class__.save_unsupported + + def inject_addons(self): + """Plug in system. Register additional pickling functions if modules already loaded""" + self.inject_numpy() + self.inject_timeseries() + self.inject_email() + + """Python Imaging Library""" + def save_image(self, obj): + if not obj.im and obj.fp and 'r' in obj.fp.mode and obj.fp.name \ + and not obj.fp.closed and (not hasattr(obj, 'isatty') or not obj.isatty()): + #if image not loaded yet -- lazy load + self.save_reduce(_lazyloadImage,(obj.fp,), obj=obj) + else: + #image is loaded - just transmit it over + self.save_reduce(_generateImage, (obj.size, obj.mode, obj.tostring()), obj=obj) + + """ + def memoize(self, obj): + pickle.Pickler.memoize(self, obj) + if printMemoization: + print 'memoizing ' + str(obj) + """ + + + +# Shorthands for legacy support + +def dump(obj, file, protocol=2): + CloudPickler(file, protocol).dump(obj) + +def dumps(obj, protocol=2): + file = StringIO() + + cp = CloudPickler(file,protocol) + cp.dump(obj) + + #print 'cloud dumped', str(obj), str(cp.modules) + + return file.getvalue() + + +#hack for __import__ not working as desired +def subimport(name): + __import__(name) + return sys.modules[name] + +#hack to load django settings: +def django_settings_load(name): + modified_env = False + + if 'DJANGO_SETTINGS_MODULE' not in os.environ: + os.environ['DJANGO_SETTINGS_MODULE'] = name # must set name first due to circular deps + modified_env = True + try: + module = subimport(name) + except Exception, i: + print >> sys.stderr, 'Cloud not import django settings %s:' % (name) + print_exec(sys.stderr) + if modified_env: + del os.environ['DJANGO_SETTINGS_MODULE'] + else: + #add project directory to sys,path: + if hasattr(module,'__file__'): + dirname = os.path.split(module.__file__)[0] + '/' + sys.path.append(dirname) + +# restores function attributes +def _restore_attr(obj, attr): + for key, val in attr.items(): + setattr(obj, key, val) + return obj + +def _get_module_builtins(): + return pickle.__builtins__ + +def print_exec(stream): + ei = sys.exc_info() + traceback.print_exception(ei[0], ei[1], ei[2], None, stream) + +def _modules_to_main(modList): + """Force every module in modList to be placed into main""" + if not modList: + return + + main = sys.modules['__main__'] + for modname in modList: + if type(modname) is str: + try: + mod = __import__(modname) + except Exception, i: #catch all... + sys.stderr.write('warning: could not import %s\n. Your function may unexpectedly error due to this import failing; \ +A version mismatch is likely. Specific error was:\n' % modname) + print_exec(sys.stderr) + else: + setattr(main,mod.__name__, mod) + else: + #REVERSE COMPATIBILITY FOR CLOUD CLIENT 1.5 (WITH EPD) + #In old version actual module was sent + setattr(main,modname.__name__, modname) + +#object generators: +def _build_xrange(start, step, len): + """Built xrange explicitly""" + return xrange(start, start + step*len, step) + +def _genpartial(func, args, kwds): + if not args: + args = () + if not kwds: + kwds = {} + return partial(func, *args, **kwds) + + +def _fill_function(func, globals, defaults, closure, dict): + """ Fills in the rest of function data into the skeleton function object + that were created via _make_skel_func(). + """ + func.func_globals.update(globals) + func.func_defaults = defaults + func.func_dict = dict + + if len(closure) != len(func.func_closure): + raise pickle.UnpicklingError("closure lengths don't match up") + for i in range(len(closure)): + _change_cell_value(func.func_closure[i], closure[i]) + + return func + +def _make_skel_func(code, num_closures, base_globals = None): + """ Creates a skeleton function object that contains just the provided + code and the correct number of cells in func_closure. All other + func attributes (e.g. func_globals) are empty. + """ + #build closure (cells): + if not ctypes: + raise Exception('ctypes failed to import; cannot build function') + + cellnew = ctypes.pythonapi.PyCell_New + cellnew.restype = ctypes.py_object + cellnew.argtypes = (ctypes.py_object,) + dummy_closure = tuple(map(lambda i: cellnew(None), range(num_closures))) + + if base_globals is None: + base_globals = {} + base_globals['__builtins__'] = __builtins__ + + return types.FunctionType(code, base_globals, + None, None, dummy_closure) + +# this piece of opaque code is needed below to modify 'cell' contents +cell_changer_code = new.code( + 1, 1, 2, 0, + ''.join([ + chr(dis.opmap['LOAD_FAST']), '\x00\x00', + chr(dis.opmap['DUP_TOP']), + chr(dis.opmap['STORE_DEREF']), '\x00\x00', + chr(dis.opmap['RETURN_VALUE']) + ]), + (), (), ('newval',), '', 'cell_changer', 1, '', ('c',), () +) + +def _change_cell_value(cell, newval): + """ Changes the contents of 'cell' object to newval """ + return new.function(cell_changer_code, {}, None, (), (cell,))(newval) + +"""Constructors for 3rd party libraries +Note: These can never be renamed due to client compatibility issues""" + +def _getobject(modname, attribute): + mod = __import__(modname) + return mod.__dict__[attribute] + +def _generateImage(size, mode, str_rep): + """Generate image from string representation""" + import Image + i = Image.new(mode, size) + i.fromstring(str_rep) + return i + +def _lazyloadImage(fp): + import Image + fp.seek(0) #works in almost any case + return Image.open(fp) + +"""Timeseries""" +def _genTimeSeries(reduce_args, state): + import scikits.timeseries.tseries as ts + from numpy import ndarray + from numpy.ma import MaskedArray + + + time_series = ts._tsreconstruct(*reduce_args) + + #from setstate modified + (ver, shp, typ, isf, raw, msk, flv, dsh, dtm, dtyp, frq, infodict) = state + #print 'regenerating %s' % dtyp + + MaskedArray.__setstate__(time_series, (ver, shp, typ, isf, raw, msk, flv)) + _dates = time_series._dates + #_dates.__setstate__((ver, dsh, typ, isf, dtm, frq)) #use remote typ + ndarray.__setstate__(_dates,(dsh,dtyp, isf, dtm)) + _dates.freq = frq + _dates._cachedinfo.update(dict(full=None, hasdups=None, steps=None, + toobj=None, toord=None, tostr=None)) + # Update the _optinfo dictionary + time_series._optinfo.update(infodict) + return time_series + diff --git a/python/pyspark/context.py b/python/pyspark/context.py new file mode 100644 index 0000000000..6172d69dcf --- /dev/null +++ b/python/pyspark/context.py @@ -0,0 +1,158 @@ +import os +import atexit +from tempfile import NamedTemporaryFile + +from pyspark.broadcast import Broadcast +from pyspark.java_gateway import launch_gateway +from pyspark.serializers import dump_pickle, write_with_length, batched +from pyspark.rdd import RDD + +from py4j.java_collections import ListConverter + + +class SparkContext(object): + """ + Main entry point for Spark functionality. A SparkContext represents the + connection to a Spark cluster, and can be used to create L{RDD}s and + broadcast variables on that cluster. + """ + + gateway = launch_gateway() + jvm = gateway.jvm + _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile + _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile + + def __init__(self, master, jobName, sparkHome=None, pyFiles=None, + environment=None, batchSize=1024): + """ + Create a new SparkContext. + + @param master: Cluster URL to connect to + (e.g. mesos://host:port, spark://host:port, local[4]). + @param jobName: A name for your job, to display on the cluster web UI + @param sparkHome: Location where Spark is installed on cluster nodes. + @param pyFiles: Collection of .zip or .py files to send to the cluster + and add to PYTHONPATH. These can be paths on the local file + system or HDFS, HTTP, HTTPS, or FTP URLs. + @param environment: A dictionary of environment variables to set on + worker nodes. + @param batchSize: The number of Python objects represented as a single + Java object. Set 1 to disable batching or -1 to use an + unlimited batch size. + """ + self.master = master + self.jobName = jobName + self.sparkHome = sparkHome or None # None becomes null in Py4J + self.environment = environment or {} + self.batchSize = batchSize # -1 represents a unlimited batch size + + # Create the Java SparkContext through Py4J + empty_string_array = self.gateway.new_array(self.jvm.String, 0) + self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome, + empty_string_array) + + self.pythonExec = os.environ.get("PYSPARK_PYTHON_EXEC", 'python') + # Broadcast's __reduce__ method stores Broadcast instances here. + # This allows other code to determine which Broadcast instances have + # been pickled, so it can determine which Java broadcast objects to + # send. + self._pickled_broadcast_vars = set() + + # Deploy any code dependencies specified in the constructor + for path in (pyFiles or []): + self.addPyFile(path) + + @property + def defaultParallelism(self): + """ + Default level of parallelism to use when not given by user (e.g. for + reduce tasks) + """ + return self._jsc.sc().defaultParallelism() + + def __del__(self): + if self._jsc: + self._jsc.stop() + + def stop(self): + """ + Shut down the SparkContext. + """ + self._jsc.stop() + self._jsc = None + + def parallelize(self, c, numSlices=None): + """ + Distribute a local Python collection to form an RDD. + """ + numSlices = numSlices or self.defaultParallelism + # Calling the Java parallelize() method with an ArrayList is too slow, + # because it sends O(n) Py4J commands. As an alternative, serialized + # objects are written to a file and loaded through textFile(). + tempFile = NamedTemporaryFile(delete=False) + atexit.register(lambda: os.unlink(tempFile.name)) + if self.batchSize != 1: + c = batched(c, self.batchSize) + for x in c: + write_with_length(dump_pickle(x), tempFile) + tempFile.close() + jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices) + return RDD(jrdd, self) + + def textFile(self, name, minSplits=None): + """ + Read a text file from HDFS, a local file system (available on all + nodes), or any Hadoop-supported file system URI, and return it as an + RDD of Strings. + """ + minSplits = minSplits or min(self.defaultParallelism, 2) + jrdd = self._jsc.textFile(name, minSplits) + return RDD(jrdd, self) + + def union(self, rdds): + """ + Build the union of a list of RDDs. + """ + first = rdds[0]._jrdd + rest = [x._jrdd for x in rdds[1:]] + rest = ListConverter().convert(rest, self.gateway._gateway_client) + return RDD(self._jsc.union(first, rest), self) + + def broadcast(self, value): + """ + Broadcast a read-only variable to the cluster, returning a C{Broadcast} + object for reading it in distributed functions. The variable will be + sent to each cluster only once. + """ + jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value))) + return Broadcast(jbroadcast.id(), value, jbroadcast, + self._pickled_broadcast_vars) + + def addFile(self, path): + """ + Add a file to be downloaded into the working directory of this Spark + job on every node. The C{path} passed can be either a local file, + a file in HDFS (or other Hadoop-supported filesystems), or an HTTP, + HTTPS or FTP URI. + """ + self._jsc.sc().addFile(path) + + def clearFiles(self): + """ + Clear the job's list of files added by L{addFile} or L{addPyFile} so + that they do not get downloaded to any new nodes. + """ + # TODO: remove added .py or .zip files from the PYTHONPATH? + self._jsc.sc().clearFiles() + + def addPyFile(self, path): + """ + Add a .py or .zip dependency for all tasks to be executed on this + SparkContext in the future. The C{path} passed can be either a local + file, a file in HDFS (or other Hadoop-supported filesystems), or an + HTTP, HTTPS or FTP URI. + """ + self.addFile(path) + filename = path.split("/")[-1] + os.environ["PYTHONPATH"] = \ + "%s:%s" % (filename, os.environ["PYTHONPATH"]) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py new file mode 100644 index 0000000000..2329e536cc --- /dev/null +++ b/python/pyspark/java_gateway.py @@ -0,0 +1,38 @@ +import os +import sys +from subprocess import Popen, PIPE +from threading import Thread +from py4j.java_gateway import java_import, JavaGateway, GatewayClient + + +SPARK_HOME = os.environ["SPARK_HOME"] + + +def launch_gateway(): + # Launch the Py4j gateway using Spark's run command so that we pick up the + # proper classpath and SPARK_MEM settings from spark-env.sh + command = [os.path.join(SPARK_HOME, "run"), "py4j.GatewayServer", + "--die-on-broken-pipe", "0"] + proc = Popen(command, stdout=PIPE, stdin=PIPE) + # Determine which ephemeral port the server started on: + port = int(proc.stdout.readline()) + # Create a thread to echo output from the GatewayServer, which is required + # for Java log output to show up: + class EchoOutputThread(Thread): + def __init__(self, stream): + Thread.__init__(self) + self.daemon = True + self.stream = stream + + def run(self): + while True: + line = self.stream.readline() + sys.stderr.write(line) + EchoOutputThread(proc.stdout).start() + # Connect to the gateway + gateway = JavaGateway(GatewayClient(port=port), auto_convert=False) + # Import the classes used by PySpark + java_import(gateway.jvm, "spark.api.java.*") + java_import(gateway.jvm, "spark.api.python.*") + java_import(gateway.jvm, "scala.Tuple2") + return gateway diff --git a/python/pyspark/join.py b/python/pyspark/join.py new file mode 100644 index 0000000000..7036c47980 --- /dev/null +++ b/python/pyspark/join.py @@ -0,0 +1,92 @@ +""" +Copyright (c) 2011, Douban Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + + * Neither the name of the Douban Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + + +def _do_python_join(rdd, other, numSplits, dispatch): + vs = rdd.map(lambda (k, v): (k, (1, v))) + ws = other.map(lambda (k, v): (k, (2, v))) + return vs.union(ws).groupByKey(numSplits).flatMapValues(dispatch) + + +def python_join(rdd, other, numSplits): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + return [(v, w) for v in vbuf for w in wbuf] + return _do_python_join(rdd, other, numSplits, dispatch) + + +def python_right_outer_join(rdd, other, numSplits): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + if not vbuf: + vbuf.append(None) + return [(v, w) for v in vbuf for w in wbuf] + return _do_python_join(rdd, other, numSplits, dispatch) + + +def python_left_outer_join(rdd, other, numSplits): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + if not wbuf: + wbuf.append(None) + return [(v, w) for v in vbuf for w in wbuf] + return _do_python_join(rdd, other, numSplits, dispatch) + + +def python_cogroup(rdd, other, numSplits): + vs = rdd.map(lambda (k, v): (k, (1, v))) + ws = other.map(lambda (k, v): (k, (2, v))) + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + return (vbuf, wbuf) + return vs.union(ws).groupByKey(numSplits).mapValues(dispatch) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py new file mode 100644 index 0000000000..cbffb6cc1f --- /dev/null +++ b/python/pyspark/rdd.py @@ -0,0 +1,713 @@ +import atexit +from base64 import standard_b64encode as b64enc +import copy +from collections import defaultdict +from itertools import chain, ifilter, imap, product +import operator +import os +import shlex +from subprocess import Popen, PIPE +from tempfile import NamedTemporaryFile +from threading import Thread + +from pyspark import cloudpickle +from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \ + read_from_pickle_file +from pyspark.join import python_join, python_left_outer_join, \ + python_right_outer_join, python_cogroup + +from py4j.java_collections import ListConverter, MapConverter + + +__all__ = ["RDD"] + + +class RDD(object): + """ + A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. + Represents an immutable, partitioned collection of elements that can be + operated on in parallel. + """ + + def __init__(self, jrdd, ctx): + self._jrdd = jrdd + self.is_cached = False + self.ctx = ctx + + @property + def context(self): + """ + The L{SparkContext} that this RDD was created on. + """ + return self.ctx + + def cache(self): + """ + Persist this RDD with the default storage level (C{MEMORY_ONLY}). + """ + self.is_cached = True + self._jrdd.cache() + return self + + # TODO persist(self, storageLevel) + + def map(self, f, preservesPartitioning=False): + """ + Return a new RDD containing the distinct elements in this RDD. + """ + def func(iterator): return imap(f, iterator) + return PipelinedRDD(self, func, preservesPartitioning) + + def flatMap(self, f, preservesPartitioning=False): + """ + Return a new RDD by first applying a function to all elements of this + RDD, and then flattening the results. + + >>> rdd = sc.parallelize([2, 3, 4]) + >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect()) + [1, 1, 1, 2, 2, 3] + >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect()) + [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] + """ + def func(iterator): return chain.from_iterable(imap(f, iterator)) + return self.mapPartitions(func, preservesPartitioning) + + def mapPartitions(self, f, preservesPartitioning=False): + """ + Return a new RDD by applying a function to each partition of this RDD. + + >>> rdd = sc.parallelize([1, 2, 3, 4], 2) + >>> def f(iterator): yield sum(iterator) + >>> rdd.mapPartitions(f).collect() + [3, 7] + """ + return PipelinedRDD(self, f, preservesPartitioning) + + # TODO: mapPartitionsWithSplit + + def filter(self, f): + """ + Return a new RDD containing only the elements that satisfy a predicate. + + >>> rdd = sc.parallelize([1, 2, 3, 4, 5]) + >>> rdd.filter(lambda x: x % 2 == 0).collect() + [2, 4] + """ + def func(iterator): return ifilter(f, iterator) + return self.mapPartitions(func) + + def distinct(self): + """ + Return a new RDD containing the distinct elements in this RDD. + + >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect()) + [1, 2, 3] + """ + return self.map(lambda x: (x, "")) \ + .reduceByKey(lambda x, _: x) \ + .map(lambda (x, _): x) + + # TODO: sampling needs to be re-implemented due to Batch + #def sample(self, withReplacement, fraction, seed): + # jrdd = self._jrdd.sample(withReplacement, fraction, seed) + # return RDD(jrdd, self.ctx) + + #def takeSample(self, withReplacement, num, seed): + # vals = self._jrdd.takeSample(withReplacement, num, seed) + # return [load_pickle(bytes(x)) for x in vals] + + def union(self, other): + """ + Return the union of this RDD and another one. + + >>> rdd = sc.parallelize([1, 1, 2, 3]) + >>> rdd.union(rdd).collect() + [1, 1, 2, 3, 1, 1, 2, 3] + """ + return RDD(self._jrdd.union(other._jrdd), self.ctx) + + def __add__(self, other): + """ + Return the union of this RDD and another one. + + >>> rdd = sc.parallelize([1, 1, 2, 3]) + >>> (rdd + rdd).collect() + [1, 1, 2, 3, 1, 1, 2, 3] + """ + if not isinstance(other, RDD): + raise TypeError + return self.union(other) + + # TODO: sort + + def glom(self): + """ + Return an RDD created by coalescing all elements within each partition + into a list. + + >>> rdd = sc.parallelize([1, 2, 3, 4], 2) + >>> sorted(rdd.glom().collect()) + [[1, 2], [3, 4]] + """ + def func(iterator): yield list(iterator) + return self.mapPartitions(func) + + def cartesian(self, other): + """ + Return the Cartesian product of this RDD and another one, that is, the + RDD of all pairs of elements C{(a, b)} where C{a} is in C{self} and + C{b} is in C{other}. + + >>> rdd = sc.parallelize([1, 2]) + >>> sorted(rdd.cartesian(rdd).collect()) + [(1, 1), (1, 2), (2, 1), (2, 2)] + """ + # Due to batching, we can't use the Java cartesian method. + java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx) + def unpack_batches(pair): + (x, y) = pair + if type(x) == Batch or type(y) == Batch: + xs = x.items if type(x) == Batch else [x] + ys = y.items if type(y) == Batch else [y] + for pair in product(xs, ys): + yield pair + else: + yield pair + return java_cartesian.flatMap(unpack_batches) + + def groupBy(self, f, numSplits=None): + """ + Return an RDD of grouped items. + + >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8]) + >>> result = rdd.groupBy(lambda x: x % 2).collect() + >>> sorted([(x, sorted(y)) for (x, y) in result]) + [(0, [2, 8]), (1, [1, 1, 3, 5])] + """ + return self.map(lambda x: (f(x), x)).groupByKey(numSplits) + + def pipe(self, command, env={}): + """ + Return an RDD created by piping elements to a forked external process. + + >>> sc.parallelize([1, 2, 3]).pipe('cat').collect() + ['1', '2', '3'] + """ + def func(iterator): + pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE) + def pipe_objs(out): + for obj in iterator: + out.write(str(obj).rstrip('\n') + '\n') + out.close() + Thread(target=pipe_objs, args=[pipe.stdin]).start() + return (x.rstrip('\n') for x in pipe.stdout) + return self.mapPartitions(func) + + def foreach(self, f): + """ + Applies a function to all elements of this RDD. + + >>> def f(x): print x + >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) + """ + self.map(f).collect() # Force evaluation + + def collect(self): + """ + Return a list that contains all of the elements in this RDD. + """ + picklesInJava = self._jrdd.collect().iterator() + return list(self._collect_iterator_through_file(picklesInJava)) + + def _collect_iterator_through_file(self, iterator): + # Transferring lots of data through Py4J can be slow because + # socket.readline() is inefficient. Instead, we'll dump the data to a + # file and read it back. + tempFile = NamedTemporaryFile(delete=False) + tempFile.close() + def clean_up_file(): + try: os.unlink(tempFile.name) + except: pass + atexit.register(clean_up_file) + self.ctx._writeIteratorToPickleFile(iterator, tempFile.name) + # Read the data into Python and deserialize it: + with open(tempFile.name, 'rb') as tempFile: + for item in read_from_pickle_file(tempFile): + yield item + os.unlink(tempFile.name) + + def reduce(self, f): + """ + Reduces the elements of this RDD using the specified associative binary + operator. + + >>> from operator import add + >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add) + 15 + >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add) + 10 + """ + def func(iterator): + acc = None + for obj in iterator: + if acc is None: + acc = obj + else: + acc = f(obj, acc) + if acc is not None: + yield acc + vals = self.mapPartitions(func).collect() + return reduce(f, vals) + + def fold(self, zeroValue, op): + """ + Aggregate the elements of each partition, and then the results for all + the partitions, using a given associative function and a neutral "zero + value." + + The function C{op(t1, t2)} is allowed to modify C{t1} and return it + as its result value to avoid object allocation; however, it should not + modify C{t2}. + + >>> from operator import add + >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) + 15 + """ + def func(iterator): + acc = zeroValue + for obj in iterator: + acc = op(obj, acc) + yield acc + vals = self.mapPartitions(func).collect() + return reduce(op, vals, zeroValue) + + # TODO: aggregate + + def sum(self): + """ + Add up the elements in this RDD. + + >>> sc.parallelize([1.0, 2.0, 3.0]).sum() + 6.0 + """ + return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add) + + def count(self): + """ + Return the number of elements in this RDD. + + >>> sc.parallelize([2, 3, 4]).count() + 3 + """ + return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum() + + def countByValue(self): + """ + Return the count of each unique value in this RDD as a dictionary of + (value, count) pairs. + + >>> sorted(sc.parallelize([1, 2, 1, 2, 2], 2).countByValue().items()) + [(1, 2), (2, 3)] + """ + def countPartition(iterator): + counts = defaultdict(int) + for obj in iterator: + counts[obj] += 1 + yield counts + def mergeMaps(m1, m2): + for (k, v) in m2.iteritems(): + m1[k] += v + return m1 + return self.mapPartitions(countPartition).reduce(mergeMaps) + + def take(self, num): + """ + Take the first num elements of the RDD. + + This currently scans the partitions *one by one*, so it will be slow if + a lot of partitions are required. In that case, use L{collect} to get + the whole RDD instead. + + >>> sc.parallelize([2, 3, 4, 5, 6]).take(2) + [2, 3] + >>> sc.parallelize([2, 3, 4, 5, 6]).take(10) + [2, 3, 4, 5, 6] + """ + items = [] + splits = self._jrdd.splits() + taskContext = self.ctx.jvm.spark.TaskContext(0, 0, 0) + while len(items) < num and splits: + split = splits.pop(0) + iterator = self._jrdd.iterator(split, taskContext) + items.extend(self._collect_iterator_through_file(iterator)) + return items[:num] + + def first(self): + """ + Return the first element in this RDD. + + >>> sc.parallelize([2, 3, 4]).first() + 2 + """ + return self.take(1)[0] + + def saveAsTextFile(self, path): + """ + Save this RDD as a text file, using string representations of elements. + + >>> tempFile = NamedTemporaryFile(delete=True) + >>> tempFile.close() + >>> sc.parallelize(range(10)).saveAsTextFile(tempFile.name) + >>> from fileinput import input + >>> from glob import glob + >>> ''.join(input(glob(tempFile.name + "/part-0000*"))) + '0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n' + """ + def func(iterator): + return (str(x).encode("utf-8") for x in iterator) + keyed = PipelinedRDD(self, func) + keyed._bypass_serializer = True + keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path) + + # Pair functions + + def collectAsMap(self): + """ + Return the key-value pairs in this RDD to the master as a dictionary. + + >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap() + >>> m[1] + 2 + >>> m[3] + 4 + """ + return dict(self.collect()) + + def reduceByKey(self, func, numSplits=None): + """ + Merge the values for each key using an associative reduce function. + + This will also perform the merging locally on each mapper before + sending results to a reducer, similarly to a "combiner" in MapReduce. + + Output will be hash-partitioned with C{numSplits} splits, or the + default parallelism level if C{numSplits} is not specified. + + >>> from operator import add + >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(rdd.reduceByKey(add).collect()) + [('a', 2), ('b', 1)] + """ + return self.combineByKey(lambda x: x, func, func, numSplits) + + def reduceByKeyLocally(self, func): + """ + Merge the values for each key using an associative reduce function, but + return the results immediately to the master as a dictionary. + + This will also perform the merging locally on each mapper before + sending results to a reducer, similarly to a "combiner" in MapReduce. + + >>> from operator import add + >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(rdd.reduceByKeyLocally(add).items()) + [('a', 2), ('b', 1)] + """ + def reducePartition(iterator): + m = {} + for (k, v) in iterator: + m[k] = v if k not in m else func(m[k], v) + yield m + def mergeMaps(m1, m2): + for (k, v) in m2.iteritems(): + m1[k] = v if k not in m1 else func(m1[k], v) + return m1 + return self.mapPartitions(reducePartition).reduce(mergeMaps) + + def countByKey(self): + """ + Count the number of elements for each key, and return the result to the + master as a dictionary. + + >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(rdd.countByKey().items()) + [('a', 2), ('b', 1)] + """ + return self.map(lambda x: x[0]).countByValue() + + def join(self, other, numSplits=None): + """ + Return an RDD containing all pairs of elements with matching keys in + C{self} and C{other}. + + Each pair of elements will be returned as a (k, (v1, v2)) tuple, where + (k, v1) is in C{self} and (k, v2) is in C{other}. + + Performs a hash join across the cluster. + + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2), ("a", 3)]) + >>> sorted(x.join(y).collect()) + [('a', (1, 2)), ('a', (1, 3))] + """ + return python_join(self, other, numSplits) + + def leftOuterJoin(self, other, numSplits=None): + """ + Perform a left outer join of C{self} and C{other}. + + For each element (k, v) in C{self}, the resulting RDD will either + contain all pairs (k, (v, w)) for w in C{other}, or the pair + (k, (v, None)) if no elements in other have key k. + + Hash-partitions the resulting RDD into the given number of partitions. + + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2)]) + >>> sorted(x.leftOuterJoin(y).collect()) + [('a', (1, 2)), ('b', (4, None))] + """ + return python_left_outer_join(self, other, numSplits) + + def rightOuterJoin(self, other, numSplits=None): + """ + Perform a right outer join of C{self} and C{other}. + + For each element (k, w) in C{other}, the resulting RDD will either + contain all pairs (k, (v, w)) for v in this, or the pair (k, (None, w)) + if no elements in C{self} have key k. + + Hash-partitions the resulting RDD into the given number of partitions. + + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2)]) + >>> sorted(y.rightOuterJoin(x).collect()) + [('a', (2, 1)), ('b', (None, 4))] + """ + return python_right_outer_join(self, other, numSplits) + + # TODO: add option to control map-side combining + def partitionBy(self, numSplits, hashFunc=hash): + """ + Return a copy of the RDD partitioned using the specified partitioner. + + >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x)) + >>> sets = pairs.partitionBy(2).glom().collect() + >>> set(sets[0]).intersection(set(sets[1])) + set([]) + """ + if numSplits is None: + numSplits = self.ctx.defaultParallelism + # Transferring O(n) objects to Java is too expensive. Instead, we'll + # form the hash buckets in Python, transferring O(numSplits) objects + # to Java. Each object is a (splitNumber, [objects]) pair. + def add_shuffle_key(iterator): + buckets = defaultdict(list) + for (k, v) in iterator: + buckets[hashFunc(k) % numSplits].append((k, v)) + for (split, items) in buckets.iteritems(): + yield str(split) + yield dump_pickle(Batch(items)) + keyed = PipelinedRDD(self, add_shuffle_key) + keyed._bypass_serializer = True + pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() + partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits) + jrdd = pairRDD.partitionBy(partitioner) + jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) + return RDD(jrdd, self.ctx) + + # TODO: add control over map-side aggregation + def combineByKey(self, createCombiner, mergeValue, mergeCombiners, + numSplits=None): + """ + Generic function to combine the elements for each key using a custom + set of aggregation functions. + + Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined + type" C. Note that V and C can be different -- for example, one might + group an RDD of type (Int, Int) into an RDD of type (Int, List[Int]). + + Users provide three functions: + + - C{createCombiner}, which turns a V into a C (e.g., creates + a one-element list) + - C{mergeValue}, to merge a V into a C (e.g., adds it to the end of + a list) + - C{mergeCombiners}, to combine two C's into a single one. + + In addition, users can control the partitioning of the output RDD. + + >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> def f(x): return x + >>> def add(a, b): return a + str(b) + >>> sorted(x.combineByKey(str, add, add).collect()) + [('a', '11'), ('b', '1')] + """ + if numSplits is None: + numSplits = self.ctx.defaultParallelism + def combineLocally(iterator): + combiners = {} + for (k, v) in iterator: + if k not in combiners: + combiners[k] = createCombiner(v) + else: + combiners[k] = mergeValue(combiners[k], v) + return combiners.iteritems() + locally_combined = self.mapPartitions(combineLocally) + shuffled = locally_combined.partitionBy(numSplits) + def _mergeCombiners(iterator): + combiners = {} + for (k, v) in iterator: + if not k in combiners: + combiners[k] = v + else: + combiners[k] = mergeCombiners(combiners[k], v) + return combiners.iteritems() + return shuffled.mapPartitions(_mergeCombiners) + + # TODO: support variant with custom partitioner + def groupByKey(self, numSplits=None): + """ + Group the values for each key in the RDD into a single sequence. + Hash-partitions the resulting RDD with into numSplits partitions. + + >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(x.groupByKey().collect()) + [('a', [1, 1]), ('b', [1])] + """ + + def createCombiner(x): + return [x] + + def mergeValue(xs, x): + xs.append(x) + return xs + + def mergeCombiners(a, b): + return a + b + + return self.combineByKey(createCombiner, mergeValue, mergeCombiners, + numSplits) + + # TODO: add tests + def flatMapValues(self, f): + """ + Pass each value in the key-value pair RDD through a flatMap function + without changing the keys; this also retains the original RDD's + partitioning. + """ + flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) + return self.flatMap(flat_map_fn, preservesPartitioning=True) + + def mapValues(self, f): + """ + Pass each value in the key-value pair RDD through a map function + without changing the keys; this also retains the original RDD's + partitioning. + """ + map_values_fn = lambda (k, v): (k, f(v)) + return self.map(map_values_fn, preservesPartitioning=True) + + # TODO: support varargs cogroup of several RDDs. + def groupWith(self, other): + """ + Alias for cogroup. + """ + return self.cogroup(other) + + # TODO: add variant with custom parittioner + def cogroup(self, other, numSplits=None): + """ + For each key k in C{self} or C{other}, return a resulting RDD that + contains a tuple with the list of values for that key in C{self} as well + as C{other}. + + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2)]) + >>> sorted(x.cogroup(y).collect()) + [('a', ([1], [2])), ('b', ([4], []))] + """ + return python_cogroup(self, other, numSplits) + + # TODO: `lookup` is disabled because we can't make direct comparisons based + # on the key; we need to compare the hash of the key to the hash of the + # keys in the pairs. This could be an expensive operation, since those + # hashes aren't retained. + + +class PipelinedRDD(RDD): + """ + Pipelined maps: + >>> rdd = sc.parallelize([1, 2, 3, 4]) + >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect() + [4, 8, 12, 16] + >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect() + [4, 8, 12, 16] + + Pipelined reduces: + >>> from operator import add + >>> rdd.map(lambda x: 2 * x).reduce(add) + 20 + >>> rdd.flatMap(lambda x: [x, x]).reduce(add) + 20 + """ + def __init__(self, prev, func, preservesPartitioning=False): + if isinstance(prev, PipelinedRDD) and not prev.is_cached: + prev_func = prev.func + def pipeline_func(iterator): + return func(prev_func(iterator)) + self.func = pipeline_func + self.preservesPartitioning = \ + prev.preservesPartitioning and preservesPartitioning + self._prev_jrdd = prev._prev_jrdd + else: + self.func = func + self.preservesPartitioning = preservesPartitioning + self._prev_jrdd = prev._jrdd + self.is_cached = False + self.ctx = prev.ctx + self.prev = prev + self._jrdd_val = None + self._bypass_serializer = False + + @property + def _jrdd(self): + if self._jrdd_val: + return self._jrdd_val + func = self.func + if not self._bypass_serializer and self.ctx.batchSize != 1: + oldfunc = self.func + batchSize = self.ctx.batchSize + def batched_func(iterator): + return batched(oldfunc(iterator), batchSize) + func = batched_func + cmds = [func, self._bypass_serializer] + pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) + broadcast_vars = ListConverter().convert( + [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], + self.ctx.gateway._gateway_client) + self.ctx._pickled_broadcast_vars.clear() + class_manifest = self._prev_jrdd.classManifest() + env = copy.copy(self.ctx.environment) + env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "") + env = MapConverter().convert(env, self.ctx.gateway._gateway_client) + python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), + pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec, + broadcast_vars, class_manifest) + self._jrdd_val = python_rdd.asJavaRDD() + return self._jrdd_val + + +def _test(): + import doctest + from pyspark.context import SparkContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + doctest.testmod(globs=globs) + globs['sc'].stop() + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py new file mode 100644 index 0000000000..9a5151ea00 --- /dev/null +++ b/python/pyspark/serializers.py @@ -0,0 +1,78 @@ +import struct +import cPickle + + +class Batch(object): + """ + Used to store multiple RDD entries as a single Java object. + + This relieves us from having to explicitly track whether an RDD + is stored as batches of objects and avoids problems when processing + the union() of batched and unbatched RDDs (e.g. the union() of textFile() + with another RDD). + """ + def __init__(self, items): + self.items = items + + +def batched(iterator, batchSize): + if batchSize == -1: # unlimited batch size + yield Batch(list(iterator)) + else: + items = [] + count = 0 + for item in iterator: + items.append(item) + count += 1 + if count == batchSize: + yield Batch(items) + items = [] + count = 0 + if items: + yield Batch(items) + + +def dump_pickle(obj): + return cPickle.dumps(obj, 2) + + +load_pickle = cPickle.loads + + +def read_long(stream): + length = stream.read(8) + if length == "": + raise EOFError + return struct.unpack("!q", length)[0] + + +def read_int(stream): + length = stream.read(4) + if length == "": + raise EOFError + return struct.unpack("!i", length)[0] + +def write_with_length(obj, stream): + stream.write(struct.pack("!i", len(obj))) + stream.write(obj) + + +def read_with_length(stream): + length = read_int(stream) + obj = stream.read(length) + if obj == "": + raise EOFError + return obj + + +def read_from_pickle_file(stream): + try: + while True: + obj = load_pickle(read_with_length(stream)) + if type(obj) == Batch: # We don't care about inheritance + for item in obj.items: + yield item + else: + yield obj + except EOFError: + return diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py new file mode 100644 index 0000000000..bd39b0283f --- /dev/null +++ b/python/pyspark/shell.py @@ -0,0 +1,33 @@ +""" +An interactive shell. +""" +import optparse # I prefer argparse, but it's not included with Python < 2.7 +import code +import sys + +from pyspark.context import SparkContext + + +def main(master='local', ipython=False): + sc = SparkContext(master, 'PySparkShell') + user_ns = {'sc' : sc} + banner = "Spark context avaiable as sc." + if ipython: + import IPython + IPython.embed(user_ns=user_ns, banner2=banner) + else: + print banner + code.interact(local=user_ns) + + +if __name__ == '__main__': + usage = "usage: %prog [options] master" + parser = optparse.OptionParser(usage=usage) + parser.add_option("-i", "--ipython", help="Run IPython shell", + action="store_true") + (options, args) = parser.parse_args() + if len(sys.argv) > 1: + master = args[0] + else: + master = 'local' + main(master, options.ipython) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py new file mode 100644 index 0000000000..9f6b507dbd --- /dev/null +++ b/python/pyspark/worker.py @@ -0,0 +1,40 @@ +""" +Worker that receives input from Piped RDD. +""" +import sys +from base64 import standard_b64decode +# CloudPickler needs to be imported so that depicklers are registered using the +# copy_reg module. +from pyspark.broadcast import Broadcast, _broadcastRegistry +from pyspark.cloudpickle import CloudPickler +from pyspark.serializers import write_with_length, read_with_length, \ + read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file + + +# Redirect stdout to stderr so that users must return values from functions. +old_stdout = sys.stdout +sys.stdout = sys.stderr + + +def load_obj(): + return load_pickle(standard_b64decode(sys.stdin.readline().strip())) + + +def main(): + num_broadcast_variables = read_int(sys.stdin) + for _ in range(num_broadcast_variables): + bid = read_long(sys.stdin) + value = read_with_length(sys.stdin) + _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value)) + func = load_obj() + bypassSerializer = load_obj() + if bypassSerializer: + dumps = lambda x: x + else: + dumps = dump_pickle + for obj in func(read_from_pickle_file(sys.stdin)): + write_with_length(dumps(obj), old_stdout) + + +if __name__ == '__main__': + main() diff --git a/run b/run index ed788c4db3..08e2b2434b 100755 --- a/run +++ b/run @@ -63,7 +63,7 @@ CORE_DIR="$FWDIR/core" REPL_DIR="$FWDIR/repl" EXAMPLES_DIR="$FWDIR/examples" BAGEL_DIR="$FWDIR/bagel" -PYSPARK_DIR="$FWDIR/pyspark" +PYSPARK_DIR="$FWDIR/python" # Build up classpath CLASSPATH="$SPARK_CLASSPATH" diff --git a/run-pyspark b/run-pyspark new file mode 100755 index 0000000000..deb0d708b3 --- /dev/null +++ b/run-pyspark @@ -0,0 +1,28 @@ +#!/usr/bin/env bash + +# Figure out where the Scala framework is installed +FWDIR="$(cd `dirname $0`; pwd)" + +# Export this as SPARK_HOME +export SPARK_HOME="$FWDIR" + +# Load environment variables from conf/spark-env.sh, if it exists +if [ -e $FWDIR/conf/spark-env.sh ] ; then + . $FWDIR/conf/spark-env.sh +fi + +# Figure out which Python executable to use +if [ -z "$PYSPARK_PYTHON" ] ; then + PYSPARK_PYTHON="python" +fi +export PYSPARK_PYTHON + +# Add the PySpark classes to the Python path: +export PYTHONPATH=$SPARK_HOME/python/:$PYTHONPATH + +# Launch with `scala` by default: +if [[ "$SPARK_LAUNCH_WITH_SCALA" != "0" ]] ; then + export SPARK_LAUNCH_WITH_SCALA=1 +fi + +exec "$PYSPARK_PYTHON" "$@" diff --git a/run2.cmd b/run2.cmd index 9c50804e69..83464b1166 100644 --- a/run2.cmd +++ b/run2.cmd @@ -34,7 +34,7 @@ set CORE_DIR=%FWDIR%core set REPL_DIR=%FWDIR%repl set EXAMPLES_DIR=%FWDIR%examples set BAGEL_DIR=%FWDIR%bagel -set PYSPARK_DIR=%FWDIR%pyspark +set PYSPARK_DIR=%FWDIR%python rem Build up classpath set CLASSPATH=%SPARK_CLASSPATH%;%MESOS_CLASSPATH%;%FWDIR%conf;%CORE_DIR%\target\scala-%SCALA_VERSION%\classes -- cgit v1.2.3 From 3dc87dd923578f20f2c6945be7d8c03797e76237 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 1 Jan 2013 16:38:04 -0800 Subject: Fixed compilation bug in RDDSuite created during merge for mesos/master. --- core/src/test/scala/spark/RDDSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index eab09956bb..e5a59dc7d6 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -105,9 +105,9 @@ class RDDSuite extends FunSuite with BeforeAndAfter { sc = new SparkContext("local", "test") val onlySplit = new Split { override def index: Int = 0 } var shouldFail = true - val rdd = new RDD[Int](sc) { - override def splits: Array[Split] = Array(onlySplit) - override val dependencies = List[Dependency[_]]() + val rdd = new RDD[Int](sc, Nil) { + override def getSplits: Array[Split] = Array(onlySplit) + override val getDependencies = List[Dependency[_]]() override def compute(split: Split, context: TaskContext): Iterator[Int] = { if (shouldFail) { throw new Exception("injected failure") -- cgit v1.2.3 From ce9f1bbe20eff794cd1d588dc88f109d32588cfe Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 1 Jan 2013 21:25:49 -0800 Subject: Add `pyspark` script to replace the other scripts. Expand the PySpark programming guide. --- docs/python-programming-guide.md | 49 ++++++++++++++++++++++++++++++++++++---- docs/quick-start.md | 4 ++-- pyspark | 32 ++++++++++++++++++++++++++ pyspark-shell | 3 --- python/pyspark/shell.py | 36 ++++++++--------------------- python/run-tests | 9 ++++++++ run-pyspark | 28 ----------------------- 7 files changed, 97 insertions(+), 64 deletions(-) create mode 100755 pyspark delete mode 100755 pyspark-shell create mode 100755 python/run-tests delete mode 100755 run-pyspark diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index d88d4eb42d..d963551296 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -24,6 +24,35 @@ There are a few key differences between the Python and Scala APIs: - `sample` - `sort` +In PySpark, RDDs support the same methods as their Scala counterparts but take Python functions and return Python collection types. +Short functions can be passed to RDD methods using Python's [`lambda`](http://www.diveintopython.net/power_of_introspection/lambda_functions.html) syntax: + +{% highlight python %} +logData = sc.textFile(logFile).cache() +errors = logData.filter(lambda s: 'ERROR' in s.split()) +{% endhighlight %} + +You can also pass functions that are defined using the `def` keyword; this is useful for more complicated functions that cannot be expressed using `lambda`: + +{% highlight python %} +def is_error(line): + return 'ERROR' in line.split() +errors = logData.filter(is_error) +{% endhighlight %} + +Functions can access objects in enclosing scopes, although modifications to those objects within RDD methods will not be propagated to other tasks: + +{% highlight python %} +error_keywords = ["Exception", "Error"] +def is_error(line): + words = line.split() + return any(keyword in words for keyword in error_keywords) +errors = logData.filter(is_error) +{% endhighlight %} + +PySpark will automatically ship these functions to workers, along with any objects that they reference. +Instances of classes will be serialized and shipped to workers by PySpark, but classes themselves cannot be automatically distributed to workers. +The [Standalone Use](#standalone-use) section describes how to ship code dependencies to workers. # Installing and Configuring PySpark @@ -34,13 +63,14 @@ By default, PySpark's scripts will run programs using `python`; an alternate Pyt All of PySpark's library dependencies, including [Py4J](http://py4j.sourceforge.net/), are bundled with PySpark and automatically imported. -Standalone PySpark jobs should be run using the `run-pyspark` script, which automatically configures the Java and Python environmnt using the settings in `conf/spark-env.sh`. +Standalone PySpark jobs should be run using the `pyspark` script, which automatically configures the Java and Python environment using the settings in `conf/spark-env.sh`. The script automatically adds the `pyspark` package to the `PYTHONPATH`. # Interactive Use -PySpark's `pyspark-shell` script provides a simple way to learn the API: +The `pyspark` script launches a Python interpreter that is configured to run PySpark jobs. +When run without any input files, `pyspark` launches a shell that can be used explore data interactively, which is a simple way to learn the API: {% highlight python %} >>> words = sc.textFile("/usr/share/dict/words") @@ -48,9 +78,18 @@ PySpark's `pyspark-shell` script provides a simple way to learn the API: [u'spar', u'sparable', u'sparada', u'sparadrap', u'sparagrass'] {% endhighlight %} +By default, the `pyspark` shell creates SparkContext that runs jobs locally. +To connect to a non-local cluster, set the `MASTER` environment variable. +For example, to use the `pyspark` shell with a [standalone Spark cluster](spark-standalone.html): + +{% highlight shell %} +$ MASTER=spark://IP:PORT ./pyspark +{% endhighlight %} + + # Standalone Use -PySpark can also be used from standalone Python scripts by creating a SparkContext in the script and running the script using the `run-pyspark` script in the `pyspark` directory. +PySpark can also be used from standalone Python scripts by creating a SparkContext in your script and running the script using `pyspark`. The Quick Start guide includes a [complete example](quick-start.html#a-standalone-job-in-python) of a standalone Python job. Code dependencies can be deployed by listing them in the `pyFiles` option in the SparkContext constructor: @@ -65,8 +104,8 @@ Code dependencies can be added to an existing SparkContext using its `addPyFile( # Where to Go from Here -PySpark includes several sample programs using the Python API in `pyspark/examples`. -You can run them by passing the files to the `pyspark-run` script included in PySpark -- for example `./pyspark-run examples/wordcount.py`. +PySpark includes several sample programs using the Python API in `python/examples`. +You can run them by passing the files to the `pyspark` script -- for example `./pyspark python/examples/wordcount.py`. Each example program prints usage help when run without any arguments. We currently provide [API documentation](api/pyspark/index.html) for the Python API as Epydoc. diff --git a/docs/quick-start.md b/docs/quick-start.md index 8c25df5486..2c7cfbed25 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -258,11 +258,11 @@ We can pass Python functions to Spark, which are automatically serialized along For jobs that use custom classes or third-party libraries, we can add those code dependencies to SparkContext to ensure that they will be available on remote machines; this is described in more detail in the [Python programming guide](python-programming-guide). `SimpleJob` is simple enough that we do not need to specify any code dependencies. -We can run this job using the `run-pyspark` script in `$SPARK_HOME/pyspark`: +We can run this job using the `pyspark` script: {% highlight python %} $ cd $SPARK_HOME -$ ./pyspark/run-pyspark SimpleJob.py +$ ./pyspark SimpleJob.py ... Lines with a: 8422, Lines with b: 1836 {% endhighlight python %} diff --git a/pyspark b/pyspark new file mode 100755 index 0000000000..9e89d51ba2 --- /dev/null +++ b/pyspark @@ -0,0 +1,32 @@ +#!/usr/bin/env bash + +# Figure out where the Scala framework is installed +FWDIR="$(cd `dirname $0`; pwd)" + +# Export this as SPARK_HOME +export SPARK_HOME="$FWDIR" + +# Load environment variables from conf/spark-env.sh, if it exists +if [ -e $FWDIR/conf/spark-env.sh ] ; then + . $FWDIR/conf/spark-env.sh +fi + +# Figure out which Python executable to use +if [ -z "$PYSPARK_PYTHON" ] ; then + PYSPARK_PYTHON="python" +fi +export PYSPARK_PYTHON + +# Add the PySpark classes to the Python path: +export PYTHONPATH=$SPARK_HOME/python/:$PYTHONPATH + +# Load the PySpark shell.py script when ./pyspark is used interactively: +export OLD_PYTHONSTARTUP=$PYTHONSTARTUP +export PYTHONSTARTUP=$FWDIR/python/pyspark/shell.py + +# Launch with `scala` by default: +if [[ "$SPARK_LAUNCH_WITH_SCALA" != "0" ]] ; then + export SPARK_LAUNCH_WITH_SCALA=1 +fi + +exec "$PYSPARK_PYTHON" "$@" diff --git a/pyspark-shell b/pyspark-shell deleted file mode 100755 index 27aaac3a26..0000000000 --- a/pyspark-shell +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/env bash -FWDIR="`dirname $0`" -exec $FWDIR/run-pyspark $FWDIR/python/pyspark/shell.py "$@" diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index bd39b0283f..7e6ad3aa76 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -1,33 +1,17 @@ """ An interactive shell. -""" -import optparse # I prefer argparse, but it's not included with Python < 2.7 -import code -import sys +This fle is designed to be launched as a PYTHONSTARTUP script. +""" +import os from pyspark.context import SparkContext -def main(master='local', ipython=False): - sc = SparkContext(master, 'PySparkShell') - user_ns = {'sc' : sc} - banner = "Spark context avaiable as sc." - if ipython: - import IPython - IPython.embed(user_ns=user_ns, banner2=banner) - else: - print banner - code.interact(local=user_ns) - +sc = SparkContext(os.environ.get("MASTER", "local"), "PySparkShell") +print "Spark context avaiable as sc." -if __name__ == '__main__': - usage = "usage: %prog [options] master" - parser = optparse.OptionParser(usage=usage) - parser.add_option("-i", "--ipython", help="Run IPython shell", - action="store_true") - (options, args) = parser.parse_args() - if len(sys.argv) > 1: - master = args[0] - else: - master = 'local' - main(master, options.ipython) +# The ./pyspark script stores the old PYTHONSTARTUP value in OLD_PYTHONSTARTUP, +# which allows us to execute the user's PYTHONSTARTUP file: +_pythonstartup = os.environ.get('OLD_PYTHONSTARTUP') +if _pythonstartup and os.path.isfile(_pythonstartup): + execfile(_pythonstartup) diff --git a/python/run-tests b/python/run-tests new file mode 100755 index 0000000000..da9e24cb1f --- /dev/null +++ b/python/run-tests @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +# Figure out where the Scala framework is installed +FWDIR="$(cd `dirname $0`; cd ../; pwd)" + +$FWDIR/pyspark pyspark/rdd.py +$FWDIR/pyspark -m doctest pyspark/broadcast.py + +# TODO: in the long-run, it would be nice to use a test runner like `nose`. diff --git a/run-pyspark b/run-pyspark deleted file mode 100755 index deb0d708b3..0000000000 --- a/run-pyspark +++ /dev/null @@ -1,28 +0,0 @@ -#!/usr/bin/env bash - -# Figure out where the Scala framework is installed -FWDIR="$(cd `dirname $0`; pwd)" - -# Export this as SPARK_HOME -export SPARK_HOME="$FWDIR" - -# Load environment variables from conf/spark-env.sh, if it exists -if [ -e $FWDIR/conf/spark-env.sh ] ; then - . $FWDIR/conf/spark-env.sh -fi - -# Figure out which Python executable to use -if [ -z "$PYSPARK_PYTHON" ] ; then - PYSPARK_PYTHON="python" -fi -export PYSPARK_PYTHON - -# Add the PySpark classes to the Python path: -export PYTHONPATH=$SPARK_HOME/python/:$PYTHONPATH - -# Launch with `scala` by default: -if [[ "$SPARK_LAUNCH_WITH_SCALA" != "0" ]] ; then - export SPARK_LAUNCH_WITH_SCALA=1 -fi - -exec "$PYSPARK_PYTHON" "$@" -- cgit v1.2.3 From 493d65ce651dffc79adcdada0eeeed6452b3cc47 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 2 Jan 2013 13:35:26 -0800 Subject: Several code-quality improvements to DataHandler. - Changed to more accurate name: BufferingBlockCreator - Docstring now correctly reflects the abstraction offered by the class - Made internal methods private - Fixed indentation problems --- .../spark/streaming/BufferingBlockCreator.scala | 80 +++++++++++++++++++++ .../main/scala/spark/streaming/DataHandler.scala | 83 ---------------------- .../scala/spark/streaming/FlumeInputDStream.scala | 2 +- .../scala/spark/streaming/SocketInputDStream.scala | 2 +- .../spark/streaming/input/KafkaInputDStream.scala | 5 +- 5 files changed, 84 insertions(+), 88 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/BufferingBlockCreator.scala delete mode 100644 streaming/src/main/scala/spark/streaming/DataHandler.scala diff --git a/streaming/src/main/scala/spark/streaming/BufferingBlockCreator.scala b/streaming/src/main/scala/spark/streaming/BufferingBlockCreator.scala new file mode 100644 index 0000000000..efd2e75d40 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/BufferingBlockCreator.scala @@ -0,0 +1,80 @@ +package spark.streaming + +import java.util.concurrent.ArrayBlockingQueue +import scala.collection.mutable.ArrayBuffer +import spark.Logging +import spark.streaming.util.{RecurringTimer, SystemClock} +import spark.storage.StorageLevel + +/** + * Batches objects created by a [[spark.streaming.NetworkReceiver]] and puts them into + * appropriately named blocks at regular intervals. This class starts two threads, + * one to periodically start a new batch and prepare the previous batch of as a block, + * the other to push the blocks into the block manager. + */ +class BufferingBlockCreator[T](receiver: NetworkReceiver[T], storageLevel: StorageLevel) + extends Serializable with Logging { + + case class Block(id: String, iterator: Iterator[T], metadata: Any = null) + + val clock = new SystemClock() + val blockInterval = 200L + val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) + val blockStorageLevel = storageLevel + val blocksForPushing = new ArrayBlockingQueue[Block](1000) + val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } + + var currentBuffer = new ArrayBuffer[T] + + def start() { + blockIntervalTimer.start() + blockPushingThread.start() + logInfo("Data handler started") + } + + def stop() { + blockIntervalTimer.stop() + blockPushingThread.interrupt() + logInfo("Data handler stopped") + } + + def += (obj: T) { + currentBuffer += obj + } + + private def createBlock(blockId: String, iterator: Iterator[T]) : Block = { + new Block(blockId, iterator) + } + + private def updateCurrentBuffer(time: Long) { + try { + val newBlockBuffer = currentBuffer + currentBuffer = new ArrayBuffer[T] + if (newBlockBuffer.size > 0) { + val blockId = "input-" + receiver.streamId + "- " + (time - blockInterval) + val newBlock = createBlock(blockId, newBlockBuffer.toIterator) + blocksForPushing.add(newBlock) + } + } catch { + case ie: InterruptedException => + logInfo("Block interval timer thread interrupted") + case e: Exception => + receiver.stop() + } + } + + private def keepPushingBlocks() { + logInfo("Block pushing thread started") + try { + while(true) { + val block = blocksForPushing.take() + receiver.pushBlock(block.id, block.iterator, block.metadata, storageLevel) + } + } catch { + case ie: InterruptedException => + logInfo("Block pushing thread interrupted") + case e: Exception => + receiver.stop() + } + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/DataHandler.scala b/streaming/src/main/scala/spark/streaming/DataHandler.scala deleted file mode 100644 index 05f307a8d1..0000000000 --- a/streaming/src/main/scala/spark/streaming/DataHandler.scala +++ /dev/null @@ -1,83 +0,0 @@ -package spark.streaming - -import java.util.concurrent.ArrayBlockingQueue -import scala.collection.mutable.ArrayBuffer -import spark.Logging -import spark.streaming.util.{RecurringTimer, SystemClock} -import spark.storage.StorageLevel - - -/** - * This is a helper object that manages the data received from the socket. It divides - * the object received into small batches of 100s of milliseconds, pushes them as - * blocks into the block manager and reports the block IDs to the network input - * tracker. It starts two threads, one to periodically start a new batch and prepare - * the previous batch of as a block, the other to push the blocks into the block - * manager. - */ - class DataHandler[T](receiver: NetworkReceiver[T], storageLevel: StorageLevel) - extends Serializable with Logging { - - case class Block(id: String, iterator: Iterator[T], metadata: Any = null) - - val clock = new SystemClock() - val blockInterval = 200L - val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) - val blockStorageLevel = storageLevel - val blocksForPushing = new ArrayBlockingQueue[Block](1000) - val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } - - var currentBuffer = new ArrayBuffer[T] - - def createBlock(blockId: String, iterator: Iterator[T]) : Block = { - new Block(blockId, iterator) - } - - def start() { - blockIntervalTimer.start() - blockPushingThread.start() - logInfo("Data handler started") - } - - def stop() { - blockIntervalTimer.stop() - blockPushingThread.interrupt() - logInfo("Data handler stopped") - } - - def += (obj: T) { - currentBuffer += obj - } - - def updateCurrentBuffer(time: Long) { - try { - val newBlockBuffer = currentBuffer - currentBuffer = new ArrayBuffer[T] - if (newBlockBuffer.size > 0) { - val blockId = "input-" + receiver.streamId + "- " + (time - blockInterval) - val newBlock = createBlock(blockId, newBlockBuffer.toIterator) - blocksForPushing.add(newBlock) - } - } catch { - case ie: InterruptedException => - logInfo("Block interval timer thread interrupted") - case e: Exception => - receiver.stop() - } - } - - def keepPushingBlocks() { - logInfo("Block pushing thread started") - try { - while(true) { - val block = blocksForPushing.take() - receiver.pushBlock(block.id, block.iterator, block.metadata, storageLevel) - } - } catch { - case ie: InterruptedException => - logInfo("Block pushing thread interrupted") - case e: Exception => - receiver.stop() - } - } - } \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala index 2959ce4540..02d9811669 100644 --- a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala @@ -110,7 +110,7 @@ class FlumeReceiver( storageLevel: StorageLevel ) extends NetworkReceiver[SparkFlumeEvent](streamId) { - lazy val dataHandler = new DataHandler(this, storageLevel) + lazy val dataHandler = new BufferingBlockCreator(this, storageLevel) protected override def onStart() { val responder = new SpecificResponder( diff --git a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala index a9e37c0ff0..f7a34d2515 100644 --- a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala @@ -32,7 +32,7 @@ class SocketReceiver[T: ClassManifest]( storageLevel: StorageLevel ) extends NetworkReceiver[T](streamId) { - lazy protected val dataHandler = new DataHandler(this, storageLevel) + lazy protected val dataHandler = new BufferingBlockCreator(this, storageLevel) override def getLocationPreference = None diff --git a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala index 7c642d4802..66f60519bc 100644 --- a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala @@ -102,7 +102,7 @@ class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, val ZK_TIMEOUT = 10000 // Handles pushing data into the BlockManager - lazy protected val dataHandler = new DataHandler(this, storageLevel) + lazy protected val dataHandler = new BufferingBlockCreator(this, storageLevel) // Keeps track of the current offsets. Maps from (broker, topic, group, part) -> Offset lazy val offsets = HashMap[KafkaPartitionKey, Long]() // Connection to Kafka @@ -114,7 +114,6 @@ class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, def onStart() { - // Starting the DataHandler that buffers blocks and pushes them into them BlockManager dataHandler.start() // In case we are using multiple Threads to handle Kafka Messages @@ -181,7 +180,7 @@ class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, // NOT USED - Originally intended for fault-tolerance // class KafkaDataHandler(receiver: KafkaReceiver, storageLevel: StorageLevel) - // extends DataHandler[Any](receiver, storageLevel) { + // extends BufferingBlockCreator[Any](receiver, storageLevel) { // override def createBlock(blockId: String, iterator: Iterator[Any]) : Block = { // // Creates a new Block with Kafka-specific Metadata -- cgit v1.2.3 From 2ef993d159939e9dedf909991ec5c5789bdd3670 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 2 Jan 2013 14:17:59 -0800 Subject: BufferingBlockCreator -> NetworkReceiver.BlockGenerator --- .../spark/streaming/BufferingBlockCreator.scala | 80 ---------------------- .../streaming/dstream/FlumeInputDStream.scala | 10 +-- .../streaming/dstream/KafkaInputDStream.scala | 10 +-- .../streaming/dstream/NetworkInputDStream.scala | 75 ++++++++++++++++++++ .../streaming/dstream/SocketInputDStream.scala | 8 +-- 5 files changed, 89 insertions(+), 94 deletions(-) delete mode 100644 streaming/src/main/scala/spark/streaming/BufferingBlockCreator.scala diff --git a/streaming/src/main/scala/spark/streaming/BufferingBlockCreator.scala b/streaming/src/main/scala/spark/streaming/BufferingBlockCreator.scala deleted file mode 100644 index efd2e75d40..0000000000 --- a/streaming/src/main/scala/spark/streaming/BufferingBlockCreator.scala +++ /dev/null @@ -1,80 +0,0 @@ -package spark.streaming - -import java.util.concurrent.ArrayBlockingQueue -import scala.collection.mutable.ArrayBuffer -import spark.Logging -import spark.streaming.util.{RecurringTimer, SystemClock} -import spark.storage.StorageLevel - -/** - * Batches objects created by a [[spark.streaming.NetworkReceiver]] and puts them into - * appropriately named blocks at regular intervals. This class starts two threads, - * one to periodically start a new batch and prepare the previous batch of as a block, - * the other to push the blocks into the block manager. - */ -class BufferingBlockCreator[T](receiver: NetworkReceiver[T], storageLevel: StorageLevel) - extends Serializable with Logging { - - case class Block(id: String, iterator: Iterator[T], metadata: Any = null) - - val clock = new SystemClock() - val blockInterval = 200L - val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) - val blockStorageLevel = storageLevel - val blocksForPushing = new ArrayBlockingQueue[Block](1000) - val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } - - var currentBuffer = new ArrayBuffer[T] - - def start() { - blockIntervalTimer.start() - blockPushingThread.start() - logInfo("Data handler started") - } - - def stop() { - blockIntervalTimer.stop() - blockPushingThread.interrupt() - logInfo("Data handler stopped") - } - - def += (obj: T) { - currentBuffer += obj - } - - private def createBlock(blockId: String, iterator: Iterator[T]) : Block = { - new Block(blockId, iterator) - } - - private def updateCurrentBuffer(time: Long) { - try { - val newBlockBuffer = currentBuffer - currentBuffer = new ArrayBuffer[T] - if (newBlockBuffer.size > 0) { - val blockId = "input-" + receiver.streamId + "- " + (time - blockInterval) - val newBlock = createBlock(blockId, newBlockBuffer.toIterator) - blocksForPushing.add(newBlock) - } - } catch { - case ie: InterruptedException => - logInfo("Block interval timer thread interrupted") - case e: Exception => - receiver.stop() - } - } - - private def keepPushingBlocks() { - logInfo("Block pushing thread started") - try { - while(true) { - val block = blocksForPushing.take() - receiver.pushBlock(block.id, block.iterator, block.metadata, storageLevel) - } - } catch { - case ie: InterruptedException => - logInfo("Block pushing thread interrupted") - case e: Exception => - receiver.stop() - } - } -} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala index a6fa378d6e..ca70e72e56 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala @@ -97,13 +97,13 @@ private[streaming] object SparkFlumeEvent { private[streaming] class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { override def append(event : AvroFlumeEvent) : Status = { - receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event) + receiver.blockGenerator += SparkFlumeEvent.fromAvroFlumeEvent(event) Status.OK } override def appendBatch(events : java.util.List[AvroFlumeEvent]) : Status = { events.foreach (event => - receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event)) + receiver.blockGenerator += SparkFlumeEvent.fromAvroFlumeEvent(event)) Status.OK } } @@ -118,19 +118,19 @@ class FlumeReceiver( storageLevel: StorageLevel ) extends NetworkReceiver[SparkFlumeEvent](streamId) { - lazy val dataHandler = new BufferingBlockCreator(this, storageLevel) + lazy val blockGenerator = new BlockGenerator(storageLevel) protected override def onStart() { val responder = new SpecificResponder( classOf[AvroSourceProtocol], new FlumeEventServer(this)); val server = new NettyServer(responder, new InetSocketAddress(host, port)); - dataHandler.start() + blockGenerator.start() server.start() logInfo("Flume receiver started") } protected override def onStop() { - dataHandler.stop() + blockGenerator.stop() logInfo("Flume receiver stopped") } diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index b1941fb427..25988a2ce7 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -110,19 +110,19 @@ class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, val ZK_TIMEOUT = 10000 // Handles pushing data into the BlockManager - lazy protected val dataHandler = new BufferingBlockCreator(this, storageLevel) + lazy protected val blockGenerator = new BlockGenerator(storageLevel) // Keeps track of the current offsets. Maps from (broker, topic, group, part) -> Offset lazy val offsets = HashMap[KafkaPartitionKey, Long]() // Connection to Kafka var consumerConnector : ZookeeperConsumerConnector = null def onStop() { - dataHandler.stop() + blockGenerator.stop() } def onStart() { - dataHandler.start() + blockGenerator.start() // In case we are using multiple Threads to handle Kafka Messages val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) @@ -170,8 +170,8 @@ class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, private class MessageHandler(stream: KafkaStream[String]) extends Runnable { def run() { logInfo("Starting MessageHandler.") - stream.takeWhile { msgAndMetadata => - dataHandler += msgAndMetadata.message + stream.takeWhile { msgAndMetadata => + blockGenerator += msgAndMetadata.message // Updating the offet. The key is (broker, topic, group, partition). val key = KafkaPartitionKey(msgAndMetadata.topicInfo.brokerId, msgAndMetadata.topic, diff --git a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala index 41276da8bb..18e62a0e33 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala @@ -14,6 +14,8 @@ import akka.actor.{Props, Actor} import akka.pattern.ask import akka.dispatch.Await import akka.util.duration._ +import spark.streaming.util.{RecurringTimer, SystemClock} +import java.util.concurrent.ArrayBlockingQueue abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : StreamingContext) extends InputDStream[T](ssc_) { @@ -154,4 +156,77 @@ abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Seri tracker ! DeregisterReceiver(streamId, msg) } } + + /** + * Batches objects created by a [[spark.streaming.NetworkReceiver]] and puts them into + * appropriately named blocks at regular intervals. This class starts two threads, + * one to periodically start a new batch and prepare the previous batch of as a block, + * the other to push the blocks into the block manager. + */ + class BlockGenerator(storageLevel: StorageLevel) + extends Serializable with Logging { + + case class Block(id: String, iterator: Iterator[T], metadata: Any = null) + + val clock = new SystemClock() + val blockInterval = 200L + val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) + val blockStorageLevel = storageLevel + val blocksForPushing = new ArrayBlockingQueue[Block](1000) + val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } + + var currentBuffer = new ArrayBuffer[T] + + def start() { + blockIntervalTimer.start() + blockPushingThread.start() + logInfo("Data handler started") + } + + def stop() { + blockIntervalTimer.stop() + blockPushingThread.interrupt() + logInfo("Data handler stopped") + } + + def += (obj: T) { + currentBuffer += obj + } + + private def createBlock(blockId: String, iterator: Iterator[T]) : Block = { + new Block(blockId, iterator) + } + + private def updateCurrentBuffer(time: Long) { + try { + val newBlockBuffer = currentBuffer + currentBuffer = new ArrayBuffer[T] + if (newBlockBuffer.size > 0) { + val blockId = "input-" + NetworkReceiver.this.streamId + "- " + (time - blockInterval) + val newBlock = createBlock(blockId, newBlockBuffer.toIterator) + blocksForPushing.add(newBlock) + } + } catch { + case ie: InterruptedException => + logInfo("Block interval timer thread interrupted") + case e: Exception => + NetworkReceiver.this.stop() + } + } + + private def keepPushingBlocks() { + logInfo("Block pushing thread started") + try { + while(true) { + val block = blocksForPushing.take() + NetworkReceiver.this.pushBlock(block.id, block.iterator, block.metadata, storageLevel) + } + } catch { + case ie: InterruptedException => + logInfo("Block pushing thread interrupted") + case e: Exception => + NetworkReceiver.this.stop() + } + } + } } diff --git a/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala index 8374f131d6..8e4b20ea4c 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala @@ -29,7 +29,7 @@ class SocketReceiver[T: ClassManifest]( storageLevel: StorageLevel ) extends NetworkReceiver[T](streamId) { - lazy protected val dataHandler = new BufferingBlockCreator(this, storageLevel) + lazy protected val blockGenerator = new BlockGenerator(storageLevel) override def getLocationPreference = None @@ -37,16 +37,16 @@ class SocketReceiver[T: ClassManifest]( logInfo("Connecting to " + host + ":" + port) val socket = new Socket(host, port) logInfo("Connected to " + host + ":" + port) - dataHandler.start() + blockGenerator.start() val iterator = bytesToObjects(socket.getInputStream()) while(iterator.hasNext) { val obj = iterator.next - dataHandler += obj + blockGenerator += obj } } protected def onStop() { - dataHandler.stop() + blockGenerator.stop() } } -- cgit v1.2.3 From 33beba39656fc64984db09a82fc69ca4edcc02d4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 3 Jan 2013 14:52:21 -0800 Subject: Change PySpark RDD.take() to not call iterator(). --- core/src/main/scala/spark/api/python/PythonRDD.scala | 4 ++++ python/pyspark/context.py | 1 + python/pyspark/rdd.py | 11 +++++------ 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index cf60d14f03..79d824d494 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -10,6 +10,7 @@ import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import spark.broadcast.Broadcast import spark._ import spark.rdd.PipedRDD +import java.util private[spark] class PythonRDD[T: ClassManifest]( @@ -216,6 +217,9 @@ private[spark] object PythonRDD { } file.close() } + + def takePartition[T](rdd: RDD[T], partition: Int): java.util.Iterator[T] = + rdd.context.runJob(rdd, ((x: Iterator[T]) => x), Seq(partition), true).head } private object Pickle { diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6172d69dcf..4439356c1f 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -21,6 +21,7 @@ class SparkContext(object): jvm = gateway.jvm _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile + _takePartition = jvm.PythonRDD.takePartition def __init__(self, master, jobName, sparkHome=None, pyFiles=None, environment=None, batchSize=1024): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index cbffb6cc1f..4ba417b2a2 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -328,18 +328,17 @@ class RDD(object): a lot of partitions are required. In that case, use L{collect} to get the whole RDD instead. - >>> sc.parallelize([2, 3, 4, 5, 6]).take(2) + >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2) [2, 3] >>> sc.parallelize([2, 3, 4, 5, 6]).take(10) [2, 3, 4, 5, 6] """ items = [] - splits = self._jrdd.splits() - taskContext = self.ctx.jvm.spark.TaskContext(0, 0, 0) - while len(items) < num and splits: - split = splits.pop(0) - iterator = self._jrdd.iterator(split, taskContext) + for partition in range(self._jrdd.splits().size()): + iterator = self.ctx._takePartition(self._jrdd.rdd(), partition) items.extend(self._collect_iterator_through_file(iterator)) + if len(items) >= num: + break return items[:num] def first(self): -- cgit v1.2.3 From 8d57c78c83f74e45ce3c119e2e3915d5eac264e7 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Sat, 5 Jan 2013 10:54:05 -0600 Subject: Add PairRDDFunctions.keys and values. --- core/src/main/scala/spark/PairRDDFunctions.scala | 10 ++++++++++ core/src/test/scala/spark/ShuffleSuite.scala | 7 +++++++ 2 files changed, 17 insertions(+) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 413c944a66..ce48cea903 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -615,6 +615,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( writer.cleanup() } + /** + * Return an RDD with the keys of each tuple. + */ + def keys: RDD[K] = self.map(_._1) + + /** + * Return an RDD with the values of each tuple. + */ + def values: RDD[V] = self.map(_._2) + private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index 8170100f1d..5a867016f2 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -216,6 +216,13 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter { // Test that a shuffle on the file works, because this used to be a bug assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) } + + test("kesy and values") { + sc = new SparkContext("local", "test") + val rdd = sc.parallelize(Array((1, "a"), (2, "b"))) + assert(rdd.keys.collect().toList === List(1, 2)) + assert(rdd.values.collect().toList === List("a", "b")) + } } object ShuffleSuite { -- cgit v1.2.3 From f4e6b9361ffeec1018d5834f09db9fd86f2ba7bd Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Fri, 4 Jan 2013 22:43:22 -0600 Subject: Add RDD.collect(PartialFunction). --- core/src/main/scala/spark/RDD.scala | 7 +++++++ core/src/test/scala/spark/RDDSuite.scala | 1 + 2 files changed, 8 insertions(+) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 7e38583391..5163c80134 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -329,6 +329,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial */ def toArray(): Array[T] = collect() + /** + * Return an RDD that contains all matching values by applying `f`. + */ + def collect[U: ClassManifest](f: PartialFunction[T, U]): RDD[U] = { + filter(f.isDefinedAt).map(f) + } + /** * Reduces the elements of this RDD using the specified associative binary operator. */ diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 45e6c5f840..872b06fd08 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -35,6 +35,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter { assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) + assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4")) val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) assert(partitionSums.collect().toList === List(3, 7)) -- cgit v1.2.3 From 6a0db3b449a829f3e5cdf7229f6ee564268be1df Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Sat, 5 Jan 2013 12:56:17 -0600 Subject: Fix typo. --- core/src/test/scala/spark/ShuffleSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index 5a867016f2..bebb8ebe86 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -217,7 +217,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter { assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) } - test("kesy and values") { + test("keys and values") { sc = new SparkContext("local", "test") val rdd = sc.parallelize(Array((1, "a"), (2, "b"))) assert(rdd.keys.collect().toList === List(1, 2)) -- cgit v1.2.3 From 1fdb6946b5d076ed0f1b4d2bca2a20b6cd22cbc3 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Sat, 5 Jan 2013 13:07:59 -0600 Subject: Add RDD.tupleBy. --- core/src/main/scala/spark/RDD.scala | 7 +++++++ core/src/test/scala/spark/RDDSuite.scala | 1 + 2 files changed, 8 insertions(+) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 7e38583391..7aa4b0a173 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -510,6 +510,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial .saveAsSequenceFile(path) } + /** + * Tuples the elements of this RDD by applying `f`. + */ + def tupleBy[K](f: T => K): RDD[(K, T)] = { + map(x => (f(x), x)) + } + /** A private method for tests, to look at the contents of each partition */ private[spark] def collectPartitions(): Array[Array[T]] = { sc.runJob(this, (iter: Iterator[T]) => iter.toArray) diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 45e6c5f840..7832884224 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -35,6 +35,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter { assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) + assert(nums.tupleBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4))) val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) assert(partitionSums.collect().toList === List(3, 7)) -- cgit v1.2.3 From ecf9c0890160c69f1b64b36fa8fdea2f6dd973eb Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 5 Jan 2013 20:54:08 -0500 Subject: Fix Accumulators in Java, and add a test for them --- core/src/main/scala/spark/Accumulators.scala | 18 ++++++++- core/src/main/scala/spark/SparkContext.scala | 7 ++-- .../scala/spark/api/java/JavaSparkContext.scala | 23 +++++++---- core/src/test/scala/spark/JavaAPISuite.java | 44 ++++++++++++++++++++++ 4 files changed, 79 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala index bacd0ace37..6280f25391 100644 --- a/core/src/main/scala/spark/Accumulators.scala +++ b/core/src/main/scala/spark/Accumulators.scala @@ -38,14 +38,28 @@ class Accumulable[R, T] ( */ def += (term: T) { value_ = param.addAccumulator(value_, term) } + /** + * Add more data to this accumulator / accumulable + * @param term the data to add + */ + def add(term: T) { value_ = param.addAccumulator(value_, term) } + /** * Merge two accumulable objects together - * + * * Normally, a user will not want to use this version, but will instead call `+=`. - * @param term the other Accumulable that will get merged with this + * @param term the other `R` that will get merged with this */ def ++= (term: R) { value_ = param.addInPlace(value_, term)} + /** + * Merge two accumulable objects together + * + * Normally, a user will not want to use this version, but will instead call `add`. + * @param term the other `R` that will get merged with this + */ + def merge(term: R) { value_ = param.addInPlace(value_, term)} + /** * Access the accumulator's current value; only allowed on master. */ diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 4fd81bc63b..bbf8272eb3 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -382,11 +382,12 @@ class SparkContext( new Accumulator(initialValue, param) /** - * Create an [[spark.Accumulable]] shared variable, with a `+=` method + * Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`. + * Only the master can access the accumuable's `value`. * @tparam T accumulator type * @tparam R type that can be added to the accumulator */ - def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) = + def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) = new Accumulable(initialValue, param) /** @@ -404,7 +405,7 @@ class SparkContext( * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for * reading it in distributed functions. The variable will be sent to each cluster only once. */ - def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T] (value, isLocal) + def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) /** * Add a file to be downloaded into the working directory of this Spark job on every node. diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index b7725313c4..bf9ad7a200 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -10,7 +10,7 @@ import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import spark.{Accumulator, AccumulatorParam, RDD, SparkContext} +import spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, RDD, SparkContext} import spark.SparkContext.IntAccumulatorParam import spark.SparkContext.DoubleAccumulatorParam import spark.broadcast.Broadcast @@ -265,25 +265,32 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork /** * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `add` method. Only the master can access the accumulator's `value`. */ - def intAccumulator(initialValue: Int): Accumulator[Int] = - sc.accumulator(initialValue)(IntAccumulatorParam) + def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] = + sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]] /** * Create an [[spark.Accumulator]] double variable, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `add` method. Only the master can access the accumulator's `value`. */ - def doubleAccumulator(initialValue: Double): Accumulator[Double] = - sc.accumulator(initialValue)(DoubleAccumulatorParam) + def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] = + sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]] /** * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `add` method. Only the master can access the accumulator's `value`. */ def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] = sc.accumulator(initialValue)(accumulatorParam) + /** + * Create an [[spark.Accumulable]] shared variable of the given type, to which tasks can + * "add" values with `add`. Only the master can access the accumuable's `value`. + */ + def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] = + sc.accumulable(initialValue)(param) + /** * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for * reading it in distributed functions. The variable will be sent to each cluster only once. diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 33d5fc2d89..b99e790093 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -581,4 +581,48 @@ public class JavaAPISuite implements Serializable { JavaPairRDD zipped = rdd.zip(doubles); zipped.count(); } + + @Test + public void accumulators() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + + final Accumulator intAccum = sc.intAccumulator(10); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + intAccum.add(x); + } + }); + Assert.assertEquals((Integer) 25, intAccum.value()); + + final Accumulator doubleAccum = sc.doubleAccumulator(10.0); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + doubleAccum.add((double) x); + } + }); + Assert.assertEquals((Double) 25.0, doubleAccum.value()); + + // Try a custom accumulator type + AccumulatorParam floatAccumulatorParam = new AccumulatorParam() { + public Float addInPlace(Float r, Float t) { + return r + t; + } + + public Float addAccumulator(Float r, Float t) { + return r + t; + } + + public Float zero(Float initialValue) { + return 0.0f; + } + }; + + final Accumulator floatAccum = sc.accumulator((Float) 10.0f, floatAccumulatorParam); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + floatAccum.add((float) x); + } + }); + Assert.assertEquals((Float) 25.0f, floatAccum.value()); + } } -- cgit v1.2.3 From 86af64b0a6fde5a6418727a77b43bdfeda1b81cd Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 5 Jan 2013 20:54:08 -0500 Subject: Fix Accumulators in Java, and add a test for them --- core/src/main/scala/spark/Accumulators.scala | 18 ++++++++- core/src/main/scala/spark/SparkContext.scala | 7 ++-- .../scala/spark/api/java/JavaSparkContext.scala | 23 +++++++---- core/src/test/scala/spark/JavaAPISuite.java | 44 ++++++++++++++++++++++ 4 files changed, 79 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala index bacd0ace37..6280f25391 100644 --- a/core/src/main/scala/spark/Accumulators.scala +++ b/core/src/main/scala/spark/Accumulators.scala @@ -38,14 +38,28 @@ class Accumulable[R, T] ( */ def += (term: T) { value_ = param.addAccumulator(value_, term) } + /** + * Add more data to this accumulator / accumulable + * @param term the data to add + */ + def add(term: T) { value_ = param.addAccumulator(value_, term) } + /** * Merge two accumulable objects together - * + * * Normally, a user will not want to use this version, but will instead call `+=`. - * @param term the other Accumulable that will get merged with this + * @param term the other `R` that will get merged with this */ def ++= (term: R) { value_ = param.addInPlace(value_, term)} + /** + * Merge two accumulable objects together + * + * Normally, a user will not want to use this version, but will instead call `add`. + * @param term the other `R` that will get merged with this + */ + def merge(term: R) { value_ = param.addInPlace(value_, term)} + /** * Access the accumulator's current value; only allowed on master. */ diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 4fd81bc63b..bbf8272eb3 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -382,11 +382,12 @@ class SparkContext( new Accumulator(initialValue, param) /** - * Create an [[spark.Accumulable]] shared variable, with a `+=` method + * Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`. + * Only the master can access the accumuable's `value`. * @tparam T accumulator type * @tparam R type that can be added to the accumulator */ - def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) = + def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) = new Accumulable(initialValue, param) /** @@ -404,7 +405,7 @@ class SparkContext( * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for * reading it in distributed functions. The variable will be sent to each cluster only once. */ - def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T] (value, isLocal) + def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) /** * Add a file to be downloaded into the working directory of this Spark job on every node. diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index b7725313c4..bf9ad7a200 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -10,7 +10,7 @@ import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import spark.{Accumulator, AccumulatorParam, RDD, SparkContext} +import spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, RDD, SparkContext} import spark.SparkContext.IntAccumulatorParam import spark.SparkContext.DoubleAccumulatorParam import spark.broadcast.Broadcast @@ -265,25 +265,32 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork /** * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `add` method. Only the master can access the accumulator's `value`. */ - def intAccumulator(initialValue: Int): Accumulator[Int] = - sc.accumulator(initialValue)(IntAccumulatorParam) + def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] = + sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]] /** * Create an [[spark.Accumulator]] double variable, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `add` method. Only the master can access the accumulator's `value`. */ - def doubleAccumulator(initialValue: Double): Accumulator[Double] = - sc.accumulator(initialValue)(DoubleAccumulatorParam) + def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] = + sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]] /** * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `add` method. Only the master can access the accumulator's `value`. */ def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] = sc.accumulator(initialValue)(accumulatorParam) + /** + * Create an [[spark.Accumulable]] shared variable of the given type, to which tasks can + * "add" values with `add`. Only the master can access the accumuable's `value`. + */ + def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] = + sc.accumulable(initialValue)(param) + /** * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for * reading it in distributed functions. The variable will be sent to each cluster only once. diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 33d5fc2d89..b99e790093 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -581,4 +581,48 @@ public class JavaAPISuite implements Serializable { JavaPairRDD zipped = rdd.zip(doubles); zipped.count(); } + + @Test + public void accumulators() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + + final Accumulator intAccum = sc.intAccumulator(10); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + intAccum.add(x); + } + }); + Assert.assertEquals((Integer) 25, intAccum.value()); + + final Accumulator doubleAccum = sc.doubleAccumulator(10.0); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + doubleAccum.add((double) x); + } + }); + Assert.assertEquals((Double) 25.0, doubleAccum.value()); + + // Try a custom accumulator type + AccumulatorParam floatAccumulatorParam = new AccumulatorParam() { + public Float addInPlace(Float r, Float t) { + return r + t; + } + + public Float addAccumulator(Float r, Float t) { + return r + t; + } + + public Float zero(Float initialValue) { + return 0.0f; + } + }; + + final Accumulator floatAccum = sc.accumulator((Float) 10.0f, floatAccumulatorParam); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + floatAccum.add((float) x); + } + }); + Assert.assertEquals((Float) 25.0f, floatAccum.value()); + } } -- cgit v1.2.3 From 0982572519655354b10987de4f68e29b8331bd2a Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 5 Jan 2013 22:11:28 -0500 Subject: Add methods called just 'accumulator' for int/double in Java API --- core/src/main/scala/spark/api/java/JavaSparkContext.scala | 13 +++++++++++++ core/src/test/scala/spark/JavaAPISuite.java | 4 ++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index bf9ad7a200..88ab2846be 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -277,6 +277,19 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] = sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]] + /** + * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + */ + def accumulator(initialValue: Int): Accumulator[java.lang.Integer] = intAccumulator(initialValue) + + /** + * Create an [[spark.Accumulator]] double variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + */ + def accumulator(initialValue: Double): Accumulator[java.lang.Double] = + doubleAccumulator(initialValue) + /** * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values * to using the `add` method. Only the master can access the accumulator's `value`. diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index b99e790093..912f8de05d 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -586,7 +586,7 @@ public class JavaAPISuite implements Serializable { public void accumulators() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - final Accumulator intAccum = sc.intAccumulator(10); + final Accumulator intAccum = sc.accumulator(10); rdd.foreach(new VoidFunction() { public void call(Integer x) { intAccum.add(x); @@ -594,7 +594,7 @@ public class JavaAPISuite implements Serializable { }); Assert.assertEquals((Integer) 25, intAccum.value()); - final Accumulator doubleAccum = sc.doubleAccumulator(10.0); + final Accumulator doubleAccum = sc.accumulator(10.0); rdd.foreach(new VoidFunction() { public void call(Integer x) { doubleAccum.add((double) x); -- cgit v1.2.3 From 8fd3a70c188182105f81f5143ec65e74663582d5 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 5 Jan 2013 22:46:45 -0500 Subject: Add PairRDD.keys() and values() to Java API --- core/src/main/scala/spark/api/java/JavaPairRDD.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala index 5c2be534ff..8ce32e0e2f 100644 --- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala @@ -471,6 +471,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x) fromRDD(new OrderedRDDFunctions(rdd).sortByKey(ascending)) } + + /** + * Return an RDD with the keys of each tuple. + */ + def keys(): JavaRDD[K] = JavaRDD.fromRDD[K](rdd.map(_._1)) + + /** + * Return an RDD with the values of each tuple. + */ + def values(): JavaRDD[V] = JavaRDD.fromRDD[V](rdd.map(_._2)) } object JavaPairRDD { -- cgit v1.2.3 From 8dc06069fe2330c3ee0fcaaeb0ae6e627a5887c3 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Sun, 6 Jan 2013 15:21:45 -0600 Subject: Rename RDD.tupleBy to keyBy. --- core/src/main/scala/spark/RDD.scala | 4 ++-- core/src/test/scala/spark/RDDSuite.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 7aa4b0a173..5ce524c0e7 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -511,9 +511,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial } /** - * Tuples the elements of this RDD by applying `f`. + * Creates tuples of the elements in this RDD by applying `f`. */ - def tupleBy[K](f: T => K): RDD[(K, T)] = { + def keyBy[K](f: T => K): RDD[(K, T)] = { map(x => (f(x), x)) } diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 7832884224..77bff8aba1 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -35,7 +35,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter { assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) - assert(nums.tupleBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4))) + assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4))) val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) assert(partitionSums.collect().toList === List(3, 7)) -- cgit v1.2.3 From 934ecc829aa06ce4d9ded3596b86b4733ed2a123 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 6 Jan 2013 14:15:07 -0800 Subject: Removed streaming-env.sh.template --- conf/streaming-env.sh.template | 22 ---------------------- run | 4 ---- sentences.txt | 3 --- 3 files changed, 29 deletions(-) delete mode 100755 conf/streaming-env.sh.template delete mode 100644 sentences.txt diff --git a/conf/streaming-env.sh.template b/conf/streaming-env.sh.template deleted file mode 100755 index 1ea9ba5541..0000000000 --- a/conf/streaming-env.sh.template +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/env bash - -# This file contains a few additional setting that are useful for -# running streaming jobs in Spark. Copy this file as streaming-env.sh . -# Note that this shell script will be read after spark-env.sh, so settings -# in this file may override similar settings (if present) in spark-env.sh . - - -# Using concurrent GC is strongly recommended as it can significantly -# reduce GC related pauses. - -SPARK_JAVA_OPTS+=" -XX:+UseConcMarkSweepGC" - -# Using Kryo serialization can improve serialization performance -# and therefore the throughput of the Spark Streaming programs. However, -# using Kryo serialization with custom classes may required you to -# register the classes with Kryo. Refer to the Spark documentation -# for more details. - -# SPARK_JAVA_OPTS+=" -Dspark.serializer=spark.KryoSerializer" - -export SPARK_JAVA_OPTS diff --git a/run b/run index 27506c33e2..2f61cb2a87 100755 --- a/run +++ b/run @@ -13,10 +13,6 @@ if [ -e $FWDIR/conf/spark-env.sh ] ; then . $FWDIR/conf/spark-env.sh fi -if [ -e $FWDIR/conf/streaming-env.sh ] ; then - . $FWDIR/conf/streaming-env.sh -fi - if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then if [ `command -v scala` ]; then RUNNER="scala" diff --git a/sentences.txt b/sentences.txt deleted file mode 100644 index fedf96c66e..0000000000 --- a/sentences.txt +++ /dev/null @@ -1,3 +0,0 @@ -Hello world! -What's up? -There is no cow level -- cgit v1.2.3 From af8738dfb592eb37d4d6c91e42624e844d4e493b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 6 Jan 2013 19:31:54 -0800 Subject: Moved Spark Streaming examples to examples sub-project. --- .../spark/streaming/examples/FileStream.scala | 46 ++++++++++++ .../examples/FileStreamWithCheckpoint.scala | 75 +++++++++++++++++++ .../spark/streaming/examples/FlumeEventCount.scala | 43 +++++++++++ .../scala/spark/streaming/examples/GrepRaw.scala | 32 ++++++++ .../spark/streaming/examples/KafkaWordCount.scala | 69 ++++++++++++++++++ .../spark/streaming/examples/QueueStream.scala | 39 ++++++++++ .../streaming/examples/TopKWordCountRaw.scala | 49 +++++++++++++ .../spark/streaming/examples/WordCountHdfs.scala | 25 +++++++ .../streaming/examples/WordCountNetwork.scala | 25 +++++++ .../spark/streaming/examples/WordCountRaw.scala | 43 +++++++++++ .../examples/clickstream/PageViewGenerator.scala | 85 ++++++++++++++++++++++ .../examples/clickstream/PageViewStream.scala | 84 +++++++++++++++++++++ project/SparkBuild.scala | 2 +- .../spark/streaming/examples/FileStream.scala | 46 ------------ .../examples/FileStreamWithCheckpoint.scala | 75 ------------------- .../spark/streaming/examples/FlumeEventCount.scala | 43 ----------- .../scala/spark/streaming/examples/GrepRaw.scala | 33 --------- .../spark/streaming/examples/KafkaWordCount.scala | 69 ------------------ .../spark/streaming/examples/QueueStream.scala | 39 ---------- .../streaming/examples/TopKWordCountRaw.scala | 49 ------------- .../spark/streaming/examples/WordCountHdfs.scala | 25 ------- .../streaming/examples/WordCountNetwork.scala | 25 ------- .../spark/streaming/examples/WordCountRaw.scala | 43 ----------- .../examples/clickstream/PageViewGenerator.scala | 85 ---------------------- .../examples/clickstream/PageViewStream.scala | 84 --------------------- 25 files changed, 616 insertions(+), 617 deletions(-) create mode 100644 examples/src/main/scala/spark/streaming/examples/FileStream.scala create mode 100644 examples/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala create mode 100644 examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala create mode 100644 examples/src/main/scala/spark/streaming/examples/GrepRaw.scala create mode 100644 examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala create mode 100644 examples/src/main/scala/spark/streaming/examples/QueueStream.scala create mode 100644 examples/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala create mode 100644 examples/src/main/scala/spark/streaming/examples/WordCountHdfs.scala create mode 100644 examples/src/main/scala/spark/streaming/examples/WordCountNetwork.scala create mode 100644 examples/src/main/scala/spark/streaming/examples/WordCountRaw.scala create mode 100644 examples/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala create mode 100644 examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/FileStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/QueueStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala diff --git a/examples/src/main/scala/spark/streaming/examples/FileStream.scala b/examples/src/main/scala/spark/streaming/examples/FileStream.scala new file mode 100644 index 0000000000..81938d30d4 --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/FileStream.scala @@ -0,0 +1,46 @@ +package spark.streaming.examples + +import spark.streaming.StreamingContext +import spark.streaming.StreamingContext._ +import spark.streaming.Seconds +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration + + +object FileStream { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: FileStream ") + System.exit(1) + } + + // Create the context + val ssc = new StreamingContext(args(0), "FileStream", Seconds(1)) + + // Create the new directory + val directory = new Path(args(1)) + val fs = directory.getFileSystem(new Configuration()) + if (fs.exists(directory)) throw new Exception("This directory already exists") + fs.mkdirs(directory) + fs.deleteOnExit(directory) + + // Create the FileInputDStream on the directory and use the + // stream to count words in new files created + val inputStream = ssc.textFileStream(directory.toString) + val words = inputStream.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + + // Creating new files in the directory + val text = "This is a text file" + for (i <- 1 to 30) { + ssc.sc.parallelize((1 to (i * 10)).map(_ => text), 10) + .saveAsTextFile(new Path(directory, i.toString).toString) + Thread.sleep(1000) + } + Thread.sleep(5000) // Waiting for the file to be processed + ssc.stop() + System.exit(0) + } +} \ No newline at end of file diff --git a/examples/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/examples/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala new file mode 100644 index 0000000000..b7bc15a1d5 --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala @@ -0,0 +1,75 @@ +package spark.streaming.examples + +import spark.streaming._ +import spark.streaming.StreamingContext._ +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration + +object FileStreamWithCheckpoint { + + def main(args: Array[String]) { + + if (args.size != 3) { + println("FileStreamWithCheckpoint ") + println("FileStreamWithCheckpoint restart ") + System.exit(-1) + } + + val directory = new Path(args(1)) + val checkpointDir = args(2) + + val ssc: StreamingContext = { + + if (args(0) == "restart") { + + // Recreated streaming context from specified checkpoint file + new StreamingContext(checkpointDir) + + } else { + + // Create directory if it does not exist + val fs = directory.getFileSystem(new Configuration()) + if (!fs.exists(directory)) fs.mkdirs(directory) + + // Create new streaming context + val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint", Seconds(1)) + ssc_.checkpoint(checkpointDir) + + // Setup the streaming computation + val inputStream = ssc_.textFileStream(directory.toString) + val words = inputStream.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + + ssc_ + } + } + + // Start the stream computation + startFileWritingThread(directory.toString) + ssc.start() + } + + def startFileWritingThread(directory: String) { + + val fs = new Path(directory).getFileSystem(new Configuration()) + + val fileWritingThread = new Thread() { + override def run() { + val r = new scala.util.Random() + val text = "This is a sample text file with a random number " + while(true) { + val number = r.nextInt() + val file = new Path(directory, number.toString) + val fos = fs.create(file) + fos.writeChars(text + number) + fos.close() + println("Created text file " + file) + Thread.sleep(1000) + } + } + } + fileWritingThread.start() + } + +} diff --git a/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala b/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala new file mode 100644 index 0000000000..e60ce483a3 --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala @@ -0,0 +1,43 @@ +package spark.streaming.examples + +import spark.util.IntParam +import spark.storage.StorageLevel +import spark.streaming._ + +/** + * Produce a streaming count of events received from Flume. + * + * This should be used in conjunction with an AvroSink in Flume. It will start + * an Avro server on at the request host:port address and listen for requests. + * Your Flume AvroSink should be pointed to this address. + * + * Usage: FlumeEventCount + * + * is a Spark master URL + * is the host the Flume receiver will be started on - a receiver + * creates a server and listens for flume events. + * is the port the Flume receiver will listen on. + */ +object FlumeEventCount { + def main(args: Array[String]) { + if (args.length != 3) { + System.err.println( + "Usage: FlumeEventCount ") + System.exit(1) + } + + val Array(master, host, IntParam(port)) = args + + val batchInterval = Milliseconds(2000) + // Create the context and set the batch size + val ssc = new StreamingContext(master, "FlumeEventCount", batchInterval) + + // Create a flume stream + val stream = ssc.flumeStream(host,port,StorageLevel.MEMORY_ONLY) + + // Print out the count of events received from this server in each batch + stream.count().map(cnt => "Received " + cnt + " flume events." ).print() + + ssc.start() + } +} diff --git a/examples/src/main/scala/spark/streaming/examples/GrepRaw.scala b/examples/src/main/scala/spark/streaming/examples/GrepRaw.scala new file mode 100644 index 0000000000..812faa368a --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/GrepRaw.scala @@ -0,0 +1,32 @@ +package spark.streaming.examples + +import spark.util.IntParam +import spark.storage.StorageLevel + +import spark.streaming._ +import spark.streaming.util.RawTextHelper._ + +object GrepRaw { + def main(args: Array[String]) { + if (args.length != 5) { + System.err.println("Usage: GrepRaw ") + System.exit(1) + } + + val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args + + // Create the context + val ssc = new StreamingContext(master, "GrepRaw", Milliseconds(batchMillis)) + + // Warm up the JVMs on master and slave for JIT compilation to kick in + warmUp(ssc.sc) + + + val rawStreams = (1 to numStreams).map(_ => + ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray + val union = ssc.union(rawStreams) + union.filter(_.contains("Alice")).count().foreach(r => + println("Grep count: " + r.collect().mkString)) + ssc.start() + } +} diff --git a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala new file mode 100644 index 0000000000..fe55db6e2c --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala @@ -0,0 +1,69 @@ +package spark.streaming.examples + +import java.util.Properties +import kafka.message.Message +import kafka.producer.SyncProducerConfig +import kafka.producer._ +import spark.SparkContext +import spark.streaming._ +import spark.streaming.StreamingContext._ +import spark.storage.StorageLevel +import spark.streaming.util.RawTextHelper._ + +object KafkaWordCount { + def main(args: Array[String]) { + + if (args.length < 6) { + System.err.println("Usage: KafkaWordCount ") + System.exit(1) + } + + val Array(master, hostname, port, group, topics, numThreads) = args + + val sc = new SparkContext(master, "KafkaWordCount") + val ssc = new StreamingContext(sc, Seconds(2)) + ssc.checkpoint("checkpoint") + + val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap + val lines = ssc.kafkaStream[String](hostname, port.toInt, group, topicpMap) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2) + wordCounts.print() + + ssc.start() + } +} + +// Produces some random words between 1 and 100. +object KafkaWordCountProducer { + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: KafkaWordCountProducer ") + System.exit(1) + } + + val Array(hostname, port, topic, messagesPerSec, wordsPerMessage) = args + + // Zookeper connection properties + val props = new Properties() + props.put("zk.connect", hostname + ":" + port) + props.put("serializer.class", "kafka.serializer.StringEncoder") + + val config = new ProducerConfig(props) + val producer = new Producer[String, String](config) + + // Send some messages + while(true) { + val messages = (1 to messagesPerSec.toInt).map { messageNum => + (1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString).mkString(" ") + }.toArray + println(messages.mkString(",")) + val data = new ProducerData[String, String](topic, messages) + producer.send(data) + Thread.sleep(100) + } + } + +} + diff --git a/examples/src/main/scala/spark/streaming/examples/QueueStream.scala b/examples/src/main/scala/spark/streaming/examples/QueueStream.scala new file mode 100644 index 0000000000..2a265d021d --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/QueueStream.scala @@ -0,0 +1,39 @@ +package spark.streaming.examples + +import spark.RDD +import spark.streaming.{Seconds, StreamingContext} +import spark.streaming.StreamingContext._ + +import scala.collection.mutable.SynchronizedQueue + +object QueueStream { + + def main(args: Array[String]) { + if (args.length < 1) { + System.err.println("Usage: QueueStream ") + System.exit(1) + } + + // Create the context + val ssc = new StreamingContext(args(0), "QueueStream", Seconds(1)) + + // Create the queue through which RDDs can be pushed to + // a QueueInputDStream + val rddQueue = new SynchronizedQueue[RDD[Int]]() + + // Create the QueueInputDStream and use it do some processing + val inputStream = ssc.queueStream(rddQueue) + val mappedStream = inputStream.map(x => (x % 10, 1)) + val reducedStream = mappedStream.reduceByKey(_ + _) + reducedStream.print() + ssc.start() + + // Create and push some RDDs into + for (i <- 1 to 30) { + rddQueue += ssc.sc.makeRDD(1 to 1000, 10) + Thread.sleep(1000) + } + ssc.stop() + System.exit(0) + } +} \ No newline at end of file diff --git a/examples/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/examples/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala new file mode 100644 index 0000000000..338834bc3c --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala @@ -0,0 +1,49 @@ +package spark.streaming.examples + +import spark.storage.StorageLevel +import spark.util.IntParam + +import spark.streaming._ +import spark.streaming.StreamingContext._ +import spark.streaming.util.RawTextHelper._ + +import java.util.UUID + +object TopKWordCountRaw { + + def main(args: Array[String]) { + if (args.length != 4) { + System.err.println("Usage: WordCountRaw <# streams> ") + System.exit(1) + } + + val Array(master, IntParam(numStreams), IntParam(port), checkpointDir) = args + val k = 10 + + // Create the context, and set the checkpoint directory. + // Checkpoint directory is necessary for achieving fault-tolerance, by saving counts + // periodically to HDFS + val ssc = new StreamingContext(master, "TopKWordCountRaw", Seconds(1)) + ssc.checkpoint(checkpointDir + "/" + UUID.randomUUID.toString, Seconds(1)) + + // Warm up the JVMs on master and slave for JIT compilation to kick in + /*warmUp(ssc.sc)*/ + + // Set up the raw network streams that will connect to localhost:port to raw test + // senders on the slaves and generate top K words of last 30 seconds + val lines = (1 to numStreams).map(_ => { + ssc.rawNetworkStream[String]("localhost", port, StorageLevel.MEMORY_ONLY_SER_2) + }) + val union = ssc.union(lines) + val counts = union.mapPartitions(splitAndCountPartitions) + val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) + val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) + partialTopKWindowedCounts.foreach(rdd => { + val collectedCounts = rdd.collect + println("Collected " + collectedCounts.size + " words from partial top words") + println("Top " + k + " words are " + topK(collectedCounts.toIterator, k).mkString(",")) + }) + + ssc.start() + } +} diff --git a/examples/src/main/scala/spark/streaming/examples/WordCountHdfs.scala b/examples/src/main/scala/spark/streaming/examples/WordCountHdfs.scala new file mode 100644 index 0000000000..867a8f42c4 --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/WordCountHdfs.scala @@ -0,0 +1,25 @@ +package spark.streaming.examples + +import spark.streaming.{Seconds, StreamingContext} +import spark.streaming.StreamingContext._ + +object WordCountHdfs { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: WordCountHdfs ") + System.exit(1) + } + + // Create the context + val ssc = new StreamingContext(args(0), "WordCountHdfs", Seconds(2)) + + // Create the FileInputDStream on the directory and use the + // stream to count words in new files created + val lines = ssc.textFileStream(args(1)) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + } +} + diff --git a/examples/src/main/scala/spark/streaming/examples/WordCountNetwork.scala b/examples/src/main/scala/spark/streaming/examples/WordCountNetwork.scala new file mode 100644 index 0000000000..eadda60563 --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/WordCountNetwork.scala @@ -0,0 +1,25 @@ +package spark.streaming.examples + +import spark.streaming.{Seconds, StreamingContext} +import spark.streaming.StreamingContext._ + +object WordCountNetwork { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: WordCountNetwork \n" + + "In local mode, should be 'local[n]' with n > 1") + System.exit(1) + } + + // Create the context and set the batch size + val ssc = new StreamingContext(args(0), "WordCountNetwork", Seconds(1)) + + // Create a NetworkInputDStream on target ip:port and count the + // words in input stream of \n delimited test (eg. generated by 'nc') + val lines = ssc.networkTextStream(args(1), args(2).toInt) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + } +} diff --git a/examples/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/examples/src/main/scala/spark/streaming/examples/WordCountRaw.scala new file mode 100644 index 0000000000..d93335a8ce --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -0,0 +1,43 @@ +package spark.streaming.examples + +import spark.storage.StorageLevel +import spark.util.IntParam + +import spark.streaming._ +import spark.streaming.StreamingContext._ +import spark.streaming.util.RawTextHelper._ + +import java.util.UUID + +object WordCountRaw { + + def main(args: Array[String]) { + if (args.length != 4) { + System.err.println("Usage: WordCountRaw <# streams> ") + System.exit(1) + } + + val Array(master, IntParam(numStreams), IntParam(port), checkpointDir) = args + + // Create the context, and set the checkpoint directory. + // Checkpoint directory is necessary for achieving fault-tolerance, by saving counts + // periodically to HDFS + val ssc = new StreamingContext(master, "WordCountRaw", Seconds(1)) + ssc.checkpoint(checkpointDir + "/" + UUID.randomUUID.toString, Seconds(1)) + + // Warm up the JVMs on master and slave for JIT compilation to kick in + warmUp(ssc.sc) + + // Set up the raw network streams that will connect to localhost:port to raw test + // senders on the slaves and generate count of words of last 30 seconds + val lines = (1 to numStreams).map(_ => { + ssc.rawNetworkStream[String]("localhost", port, StorageLevel.MEMORY_ONLY_SER_2) + }) + val union = ssc.union(lines) + val counts = union.mapPartitions(splitAndCountPartitions) + val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) + windowedCounts.foreach(r => println("# unique words = " + r.count())) + + ssc.start() + } +} diff --git a/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala b/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala new file mode 100644 index 0000000000..4c6e08bc74 --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala @@ -0,0 +1,85 @@ +package spark.streaming.examples.clickstream + +import java.net.{InetAddress,ServerSocket,Socket,SocketException} +import java.io.{InputStreamReader, BufferedReader, PrintWriter} +import util.Random + +/** Represents a page view on a website with associated dimension data.*/ +class PageView(val url : String, val status : Int, val zipCode : Int, val userID : Int) { + override def toString() : String = { + "%s\t%s\t%s\t%s\n".format(url, status, zipCode, userID) + } +} +object PageView { + def fromString(in : String) : PageView = { + val parts = in.split("\t") + new PageView(parts(0), parts(1).toInt, parts(2).toInt, parts(3).toInt) + } +} + +/** Generates streaming events to simulate page views on a website. + * + * This should be used in tandem with PageViewStream.scala. Example: + * $ ./run spark.streaming.examples.clickstream.PageViewGenerator 44444 10 + * $ ./run spark.streaming.examples.clickstream.PageViewStream errorRatePerZipCode localhost 44444 + * */ +object PageViewGenerator { + val pages = Map("http://foo.com/" -> .7, + "http://foo.com/news" -> 0.2, + "http://foo.com/contact" -> .1) + val httpStatus = Map(200 -> .95, + 404 -> .05) + val userZipCode = Map(94709 -> .5, + 94117 -> .5) + val userID = Map((1 to 100).map(_ -> .01):_*) + + + def pickFromDistribution[T](inputMap : Map[T, Double]) : T = { + val rand = new Random().nextDouble() + var total = 0.0 + for ((item, prob) <- inputMap) { + total = total + prob + if (total > rand) { + return item + } + } + return inputMap.take(1).head._1 // Shouldn't get here if probabilities add up to 1.0 + } + + def getNextClickEvent() : String = { + val id = pickFromDistribution(userID) + val page = pickFromDistribution(pages) + val status = pickFromDistribution(httpStatus) + val zipCode = pickFromDistribution(userZipCode) + new PageView(page, status, zipCode, id).toString() + } + + def main(args : Array[String]) { + if (args.length != 2) { + System.err.println("Usage: PageViewGenerator ") + System.exit(1) + } + val port = args(0).toInt + val viewsPerSecond = args(1).toFloat + val sleepDelayMs = (1000.0 / viewsPerSecond).toInt + val listener = new ServerSocket(port) + println("Listening on port: " + port) + + while (true) { + val socket = listener.accept() + new Thread() { + override def run = { + println("Got client connected from: " + socket.getInetAddress) + val out = new PrintWriter(socket.getOutputStream(), true) + + while (true) { + Thread.sleep(sleepDelayMs) + out.write(getNextClickEvent()) + out.flush() + } + socket.close() + } + }.start() + } + } +} diff --git a/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala b/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala new file mode 100644 index 0000000000..a191321d91 --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala @@ -0,0 +1,84 @@ +package spark.streaming.examples.clickstream + +import spark.streaming.{Seconds, StreamingContext} +import spark.streaming.StreamingContext._ +import spark.SparkContext._ + +/** Analyses a streaming dataset of web page views. This class demonstrates several types of + * operators available in Spark streaming. + * + * This should be used in tandem with PageViewStream.scala. Example: + * $ ./run spark.streaming.examples.clickstream.PageViewGenerator 44444 10 + * $ ./run spark.streaming.examples.clickstream.PageViewStream errorRatePerZipCode localhost 44444 + * */ +object PageViewStream { + def main(args: Array[String]) { + if (args.length != 3) { + System.err.println("Usage: PageViewStream ") + System.err.println(" must be one of pageCounts, slidingPageCounts," + + " errorRatePerZipCode, activeUserCount, popularUsersSeen") + System.exit(1) + } + val metric = args(0) + val host = args(1) + val port = args(2).toInt + + // Create the context + val ssc = new StreamingContext("local[2]", "PageViewStream", Seconds(1)) + + // Create a NetworkInputDStream on target host:port and convert each line to a PageView + val pageViews = ssc.networkTextStream(host, port) + .flatMap(_.split("\n")) + .map(PageView.fromString(_)) + + // Return a count of views per URL seen in each batch + val pageCounts = pageViews.map(view => ((view.url, 1))).countByKey() + + // Return a sliding window of page views per URL in the last ten seconds + val slidingPageCounts = pageViews.map(view => ((view.url, 1))) + .window(Seconds(10), Seconds(2)) + .countByKey() + + + // Return the rate of error pages (a non 200 status) in each zip code over the last 30 seconds + val statusesPerZipCode = pageViews.window(Seconds(30), Seconds(2)) + .map(view => ((view.zipCode, view.status))) + .groupByKey() + val errorRatePerZipCode = statusesPerZipCode.map{ + case(zip, statuses) => + val normalCount = statuses.filter(_ == 200).size + val errorCount = statuses.size - normalCount + val errorRatio = errorCount.toFloat / statuses.size + if (errorRatio > 0.05) {"%s: **%s**".format(zip, errorRatio)} + else {"%s: %s".format(zip, errorRatio)} + } + + // Return the number unique users in last 15 seconds + val activeUserCount = pageViews.window(Seconds(15), Seconds(2)) + .map(view => (view.userID, 1)) + .groupByKey() + .count() + .map("Unique active users: " + _) + + // An external dataset we want to join to this stream + val userList = ssc.sc.parallelize( + Map(1 -> "Patrick Wendell", 2->"Reynold Xin", 3->"Matei Zaharia").toSeq) + + metric match { + case "pageCounts" => pageCounts.print() + case "slidingPageCounts" => slidingPageCounts.print() + case "errorRatePerZipCode" => errorRatePerZipCode.print() + case "activeUserCount" => activeUserCount.print() + case "popularUsersSeen" => + // Look for users in our existing dataset and print it out if we have a match + pageViews.map(view => (view.userID, 1)) + .foreach((rdd, time) => rdd.join(userList) + .map(_._2._2) + .take(10) + .foreach(u => println("Saw user %s at time %s".format(u, time)))) + case _ => println("Invalid metric entered: " + metric) + } + + ssc.start() + } +} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index a3f901a081..6ba3026bcc 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -23,7 +23,7 @@ object SparkBuild extends Build { lazy val repl = Project("repl", file("repl"), settings = replSettings) dependsOn (core) - lazy val examples = Project("examples", file("examples"), settings = examplesSettings) dependsOn (core) + lazy val examples = Project("examples", file("examples"), settings = examplesSettings) dependsOn (core) dependsOn (streaming) lazy val bagel = Project("bagel", file("bagel"), settings = bagelSettings) dependsOn (core) diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStream.scala b/streaming/src/main/scala/spark/streaming/examples/FileStream.scala deleted file mode 100644 index 81938d30d4..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/FileStream.scala +++ /dev/null @@ -1,46 +0,0 @@ -package spark.streaming.examples - -import spark.streaming.StreamingContext -import spark.streaming.StreamingContext._ -import spark.streaming.Seconds -import org.apache.hadoop.fs.Path -import org.apache.hadoop.conf.Configuration - - -object FileStream { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: FileStream ") - System.exit(1) - } - - // Create the context - val ssc = new StreamingContext(args(0), "FileStream", Seconds(1)) - - // Create the new directory - val directory = new Path(args(1)) - val fs = directory.getFileSystem(new Configuration()) - if (fs.exists(directory)) throw new Exception("This directory already exists") - fs.mkdirs(directory) - fs.deleteOnExit(directory) - - // Create the FileInputDStream on the directory and use the - // stream to count words in new files created - val inputStream = ssc.textFileStream(directory.toString) - val words = inputStream.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - ssc.start() - - // Creating new files in the directory - val text = "This is a text file" - for (i <- 1 to 30) { - ssc.sc.parallelize((1 to (i * 10)).map(_ => text), 10) - .saveAsTextFile(new Path(directory, i.toString).toString) - Thread.sleep(1000) - } - Thread.sleep(5000) // Waiting for the file to be processed - ssc.stop() - System.exit(0) - } -} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala deleted file mode 100644 index b7bc15a1d5..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala +++ /dev/null @@ -1,75 +0,0 @@ -package spark.streaming.examples - -import spark.streaming._ -import spark.streaming.StreamingContext._ -import org.apache.hadoop.fs.Path -import org.apache.hadoop.conf.Configuration - -object FileStreamWithCheckpoint { - - def main(args: Array[String]) { - - if (args.size != 3) { - println("FileStreamWithCheckpoint ") - println("FileStreamWithCheckpoint restart ") - System.exit(-1) - } - - val directory = new Path(args(1)) - val checkpointDir = args(2) - - val ssc: StreamingContext = { - - if (args(0) == "restart") { - - // Recreated streaming context from specified checkpoint file - new StreamingContext(checkpointDir) - - } else { - - // Create directory if it does not exist - val fs = directory.getFileSystem(new Configuration()) - if (!fs.exists(directory)) fs.mkdirs(directory) - - // Create new streaming context - val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint", Seconds(1)) - ssc_.checkpoint(checkpointDir) - - // Setup the streaming computation - val inputStream = ssc_.textFileStream(directory.toString) - val words = inputStream.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - - ssc_ - } - } - - // Start the stream computation - startFileWritingThread(directory.toString) - ssc.start() - } - - def startFileWritingThread(directory: String) { - - val fs = new Path(directory).getFileSystem(new Configuration()) - - val fileWritingThread = new Thread() { - override def run() { - val r = new scala.util.Random() - val text = "This is a sample text file with a random number " - while(true) { - val number = r.nextInt() - val file = new Path(directory, number.toString) - val fos = fs.create(file) - fos.writeChars(text + number) - fos.close() - println("Created text file " + file) - Thread.sleep(1000) - } - } - } - fileWritingThread.start() - } - -} diff --git a/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala b/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala deleted file mode 100644 index e60ce483a3..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala +++ /dev/null @@ -1,43 +0,0 @@ -package spark.streaming.examples - -import spark.util.IntParam -import spark.storage.StorageLevel -import spark.streaming._ - -/** - * Produce a streaming count of events received from Flume. - * - * This should be used in conjunction with an AvroSink in Flume. It will start - * an Avro server on at the request host:port address and listen for requests. - * Your Flume AvroSink should be pointed to this address. - * - * Usage: FlumeEventCount - * - * is a Spark master URL - * is the host the Flume receiver will be started on - a receiver - * creates a server and listens for flume events. - * is the port the Flume receiver will listen on. - */ -object FlumeEventCount { - def main(args: Array[String]) { - if (args.length != 3) { - System.err.println( - "Usage: FlumeEventCount ") - System.exit(1) - } - - val Array(master, host, IntParam(port)) = args - - val batchInterval = Milliseconds(2000) - // Create the context and set the batch size - val ssc = new StreamingContext(master, "FlumeEventCount", batchInterval) - - // Create a flume stream - val stream = ssc.flumeStream(host,port,StorageLevel.MEMORY_ONLY) - - // Print out the count of events received from this server in each batch - stream.count().map(cnt => "Received " + cnt + " flume events." ).print() - - ssc.start() - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala deleted file mode 100644 index dfaaf03f03..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala +++ /dev/null @@ -1,33 +0,0 @@ -package spark.streaming.examples - -import spark.util.IntParam -import spark.storage.StorageLevel - -import spark.streaming._ -import spark.streaming.StreamingContext._ -import spark.streaming.util.RawTextHelper._ - -object GrepRaw { - def main(args: Array[String]) { - if (args.length != 5) { - System.err.println("Usage: GrepRaw ") - System.exit(1) - } - - val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args - - // Create the context - val ssc = new StreamingContext(master, "GrepRaw", Milliseconds(batchMillis)) - - // Warm up the JVMs on master and slave for JIT compilation to kick in - warmUp(ssc.sc) - - - val rawStreams = (1 to numStreams).map(_ => - ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray - val union = ssc.union(rawStreams) - union.filter(_.contains("Alice")).count().foreach(r => - println("Grep count: " + r.collect().mkString)) - ssc.start() - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala deleted file mode 100644 index fe55db6e2c..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala +++ /dev/null @@ -1,69 +0,0 @@ -package spark.streaming.examples - -import java.util.Properties -import kafka.message.Message -import kafka.producer.SyncProducerConfig -import kafka.producer._ -import spark.SparkContext -import spark.streaming._ -import spark.streaming.StreamingContext._ -import spark.storage.StorageLevel -import spark.streaming.util.RawTextHelper._ - -object KafkaWordCount { - def main(args: Array[String]) { - - if (args.length < 6) { - System.err.println("Usage: KafkaWordCount ") - System.exit(1) - } - - val Array(master, hostname, port, group, topics, numThreads) = args - - val sc = new SparkContext(master, "KafkaWordCount") - val ssc = new StreamingContext(sc, Seconds(2)) - ssc.checkpoint("checkpoint") - - val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap - val lines = ssc.kafkaStream[String](hostname, port.toInt, group, topicpMap) - val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2) - wordCounts.print() - - ssc.start() - } -} - -// Produces some random words between 1 and 100. -object KafkaWordCountProducer { - - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: KafkaWordCountProducer ") - System.exit(1) - } - - val Array(hostname, port, topic, messagesPerSec, wordsPerMessage) = args - - // Zookeper connection properties - val props = new Properties() - props.put("zk.connect", hostname + ":" + port) - props.put("serializer.class", "kafka.serializer.StringEncoder") - - val config = new ProducerConfig(props) - val producer = new Producer[String, String](config) - - // Send some messages - while(true) { - val messages = (1 to messagesPerSec.toInt).map { messageNum => - (1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString).mkString(" ") - }.toArray - println(messages.mkString(",")) - val data = new ProducerData[String, String](topic, messages) - producer.send(data) - Thread.sleep(100) - } - } - -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/QueueStream.scala b/streaming/src/main/scala/spark/streaming/examples/QueueStream.scala deleted file mode 100644 index 2a265d021d..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/QueueStream.scala +++ /dev/null @@ -1,39 +0,0 @@ -package spark.streaming.examples - -import spark.RDD -import spark.streaming.{Seconds, StreamingContext} -import spark.streaming.StreamingContext._ - -import scala.collection.mutable.SynchronizedQueue - -object QueueStream { - - def main(args: Array[String]) { - if (args.length < 1) { - System.err.println("Usage: QueueStream ") - System.exit(1) - } - - // Create the context - val ssc = new StreamingContext(args(0), "QueueStream", Seconds(1)) - - // Create the queue through which RDDs can be pushed to - // a QueueInputDStream - val rddQueue = new SynchronizedQueue[RDD[Int]]() - - // Create the QueueInputDStream and use it do some processing - val inputStream = ssc.queueStream(rddQueue) - val mappedStream = inputStream.map(x => (x % 10, 1)) - val reducedStream = mappedStream.reduceByKey(_ + _) - reducedStream.print() - ssc.start() - - // Create and push some RDDs into - for (i <- 1 to 30) { - rddQueue += ssc.sc.makeRDD(1 to 1000, 10) - Thread.sleep(1000) - } - ssc.stop() - System.exit(0) - } -} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala deleted file mode 100644 index 338834bc3c..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala +++ /dev/null @@ -1,49 +0,0 @@ -package spark.streaming.examples - -import spark.storage.StorageLevel -import spark.util.IntParam - -import spark.streaming._ -import spark.streaming.StreamingContext._ -import spark.streaming.util.RawTextHelper._ - -import java.util.UUID - -object TopKWordCountRaw { - - def main(args: Array[String]) { - if (args.length != 4) { - System.err.println("Usage: WordCountRaw <# streams> ") - System.exit(1) - } - - val Array(master, IntParam(numStreams), IntParam(port), checkpointDir) = args - val k = 10 - - // Create the context, and set the checkpoint directory. - // Checkpoint directory is necessary for achieving fault-tolerance, by saving counts - // periodically to HDFS - val ssc = new StreamingContext(master, "TopKWordCountRaw", Seconds(1)) - ssc.checkpoint(checkpointDir + "/" + UUID.randomUUID.toString, Seconds(1)) - - // Warm up the JVMs on master and slave for JIT compilation to kick in - /*warmUp(ssc.sc)*/ - - // Set up the raw network streams that will connect to localhost:port to raw test - // senders on the slaves and generate top K words of last 30 seconds - val lines = (1 to numStreams).map(_ => { - ssc.rawNetworkStream[String]("localhost", port, StorageLevel.MEMORY_ONLY_SER_2) - }) - val union = ssc.union(lines) - val counts = union.mapPartitions(splitAndCountPartitions) - val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) - val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) - partialTopKWindowedCounts.foreach(rdd => { - val collectedCounts = rdd.collect - println("Collected " + collectedCounts.size + " words from partial top words") - println("Top " + k + " words are " + topK(collectedCounts.toIterator, k).mkString(",")) - }) - - ssc.start() - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala deleted file mode 100644 index 867a8f42c4..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala +++ /dev/null @@ -1,25 +0,0 @@ -package spark.streaming.examples - -import spark.streaming.{Seconds, StreamingContext} -import spark.streaming.StreamingContext._ - -object WordCountHdfs { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: WordCountHdfs ") - System.exit(1) - } - - // Create the context - val ssc = new StreamingContext(args(0), "WordCountHdfs", Seconds(2)) - - // Create the FileInputDStream on the directory and use the - // stream to count words in new files created - val lines = ssc.textFileStream(args(1)) - val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - ssc.start() - } -} - diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala deleted file mode 100644 index eadda60563..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala +++ /dev/null @@ -1,25 +0,0 @@ -package spark.streaming.examples - -import spark.streaming.{Seconds, StreamingContext} -import spark.streaming.StreamingContext._ - -object WordCountNetwork { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: WordCountNetwork \n" + - "In local mode, should be 'local[n]' with n > 1") - System.exit(1) - } - - // Create the context and set the batch size - val ssc = new StreamingContext(args(0), "WordCountNetwork", Seconds(1)) - - // Create a NetworkInputDStream on target ip:port and count the - // words in input stream of \n delimited test (eg. generated by 'nc') - val lines = ssc.networkTextStream(args(1), args(2).toInt) - val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - ssc.start() - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala deleted file mode 100644 index d93335a8ce..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ /dev/null @@ -1,43 +0,0 @@ -package spark.streaming.examples - -import spark.storage.StorageLevel -import spark.util.IntParam - -import spark.streaming._ -import spark.streaming.StreamingContext._ -import spark.streaming.util.RawTextHelper._ - -import java.util.UUID - -object WordCountRaw { - - def main(args: Array[String]) { - if (args.length != 4) { - System.err.println("Usage: WordCountRaw <# streams> ") - System.exit(1) - } - - val Array(master, IntParam(numStreams), IntParam(port), checkpointDir) = args - - // Create the context, and set the checkpoint directory. - // Checkpoint directory is necessary for achieving fault-tolerance, by saving counts - // periodically to HDFS - val ssc = new StreamingContext(master, "WordCountRaw", Seconds(1)) - ssc.checkpoint(checkpointDir + "/" + UUID.randomUUID.toString, Seconds(1)) - - // Warm up the JVMs on master and slave for JIT compilation to kick in - warmUp(ssc.sc) - - // Set up the raw network streams that will connect to localhost:port to raw test - // senders on the slaves and generate count of words of last 30 seconds - val lines = (1 to numStreams).map(_ => { - ssc.rawNetworkStream[String]("localhost", port, StorageLevel.MEMORY_ONLY_SER_2) - }) - val union = ssc.union(lines) - val counts = union.mapPartitions(splitAndCountPartitions) - val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) - windowedCounts.foreach(r => println("# unique words = " + r.count())) - - ssc.start() - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala b/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala deleted file mode 100644 index 4c6e08bc74..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala +++ /dev/null @@ -1,85 +0,0 @@ -package spark.streaming.examples.clickstream - -import java.net.{InetAddress,ServerSocket,Socket,SocketException} -import java.io.{InputStreamReader, BufferedReader, PrintWriter} -import util.Random - -/** Represents a page view on a website with associated dimension data.*/ -class PageView(val url : String, val status : Int, val zipCode : Int, val userID : Int) { - override def toString() : String = { - "%s\t%s\t%s\t%s\n".format(url, status, zipCode, userID) - } -} -object PageView { - def fromString(in : String) : PageView = { - val parts = in.split("\t") - new PageView(parts(0), parts(1).toInt, parts(2).toInt, parts(3).toInt) - } -} - -/** Generates streaming events to simulate page views on a website. - * - * This should be used in tandem with PageViewStream.scala. Example: - * $ ./run spark.streaming.examples.clickstream.PageViewGenerator 44444 10 - * $ ./run spark.streaming.examples.clickstream.PageViewStream errorRatePerZipCode localhost 44444 - * */ -object PageViewGenerator { - val pages = Map("http://foo.com/" -> .7, - "http://foo.com/news" -> 0.2, - "http://foo.com/contact" -> .1) - val httpStatus = Map(200 -> .95, - 404 -> .05) - val userZipCode = Map(94709 -> .5, - 94117 -> .5) - val userID = Map((1 to 100).map(_ -> .01):_*) - - - def pickFromDistribution[T](inputMap : Map[T, Double]) : T = { - val rand = new Random().nextDouble() - var total = 0.0 - for ((item, prob) <- inputMap) { - total = total + prob - if (total > rand) { - return item - } - } - return inputMap.take(1).head._1 // Shouldn't get here if probabilities add up to 1.0 - } - - def getNextClickEvent() : String = { - val id = pickFromDistribution(userID) - val page = pickFromDistribution(pages) - val status = pickFromDistribution(httpStatus) - val zipCode = pickFromDistribution(userZipCode) - new PageView(page, status, zipCode, id).toString() - } - - def main(args : Array[String]) { - if (args.length != 2) { - System.err.println("Usage: PageViewGenerator ") - System.exit(1) - } - val port = args(0).toInt - val viewsPerSecond = args(1).toFloat - val sleepDelayMs = (1000.0 / viewsPerSecond).toInt - val listener = new ServerSocket(port) - println("Listening on port: " + port) - - while (true) { - val socket = listener.accept() - new Thread() { - override def run = { - println("Got client connected from: " + socket.getInetAddress) - val out = new PrintWriter(socket.getOutputStream(), true) - - while (true) { - Thread.sleep(sleepDelayMs) - out.write(getNextClickEvent()) - out.flush() - } - socket.close() - } - }.start() - } - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala b/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala deleted file mode 100644 index a191321d91..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala +++ /dev/null @@ -1,84 +0,0 @@ -package spark.streaming.examples.clickstream - -import spark.streaming.{Seconds, StreamingContext} -import spark.streaming.StreamingContext._ -import spark.SparkContext._ - -/** Analyses a streaming dataset of web page views. This class demonstrates several types of - * operators available in Spark streaming. - * - * This should be used in tandem with PageViewStream.scala. Example: - * $ ./run spark.streaming.examples.clickstream.PageViewGenerator 44444 10 - * $ ./run spark.streaming.examples.clickstream.PageViewStream errorRatePerZipCode localhost 44444 - * */ -object PageViewStream { - def main(args: Array[String]) { - if (args.length != 3) { - System.err.println("Usage: PageViewStream ") - System.err.println(" must be one of pageCounts, slidingPageCounts," + - " errorRatePerZipCode, activeUserCount, popularUsersSeen") - System.exit(1) - } - val metric = args(0) - val host = args(1) - val port = args(2).toInt - - // Create the context - val ssc = new StreamingContext("local[2]", "PageViewStream", Seconds(1)) - - // Create a NetworkInputDStream on target host:port and convert each line to a PageView - val pageViews = ssc.networkTextStream(host, port) - .flatMap(_.split("\n")) - .map(PageView.fromString(_)) - - // Return a count of views per URL seen in each batch - val pageCounts = pageViews.map(view => ((view.url, 1))).countByKey() - - // Return a sliding window of page views per URL in the last ten seconds - val slidingPageCounts = pageViews.map(view => ((view.url, 1))) - .window(Seconds(10), Seconds(2)) - .countByKey() - - - // Return the rate of error pages (a non 200 status) in each zip code over the last 30 seconds - val statusesPerZipCode = pageViews.window(Seconds(30), Seconds(2)) - .map(view => ((view.zipCode, view.status))) - .groupByKey() - val errorRatePerZipCode = statusesPerZipCode.map{ - case(zip, statuses) => - val normalCount = statuses.filter(_ == 200).size - val errorCount = statuses.size - normalCount - val errorRatio = errorCount.toFloat / statuses.size - if (errorRatio > 0.05) {"%s: **%s**".format(zip, errorRatio)} - else {"%s: %s".format(zip, errorRatio)} - } - - // Return the number unique users in last 15 seconds - val activeUserCount = pageViews.window(Seconds(15), Seconds(2)) - .map(view => (view.userID, 1)) - .groupByKey() - .count() - .map("Unique active users: " + _) - - // An external dataset we want to join to this stream - val userList = ssc.sc.parallelize( - Map(1 -> "Patrick Wendell", 2->"Reynold Xin", 3->"Matei Zaharia").toSeq) - - metric match { - case "pageCounts" => pageCounts.print() - case "slidingPageCounts" => slidingPageCounts.print() - case "errorRatePerZipCode" => errorRatePerZipCode.print() - case "activeUserCount" => activeUserCount.print() - case "popularUsersSeen" => - // Look for users in our existing dataset and print it out if we have a match - pageViews.map(view => (view.userID, 1)) - .foreach((rdd, time) => rdd.join(userList) - .map(_._2._2) - .take(10) - .foreach(u => println("Saw user %s at time %s".format(u, time)))) - case _ => println("Invalid metric entered: " + metric) - } - - ssc.start() - } -} -- cgit v1.2.3 From 1346126485444afc065bf4951c4bedebe5c95ce4 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 7 Jan 2013 12:11:27 -0800 Subject: Changed cleanup to clearOldValues for TimeStampedHashMap and TimeStampedHashSet. --- core/src/main/scala/spark/CacheTracker.scala | 4 ++-- core/src/main/scala/spark/MapOutputTracker.scala | 4 ++-- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 6 +++--- core/src/main/scala/spark/scheduler/ResultTask.scala | 2 +- core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 2 +- core/src/main/scala/spark/util/TimeStampedHashMap.scala | 7 +++++-- core/src/main/scala/spark/util/TimeStampedHashSet.scala | 5 ++++- 7 files changed, 18 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index 7d320c4fe5..86ad737583 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -39,7 +39,7 @@ private[spark] class CacheTrackerActor extends Actor with Logging { private val slaveCapacity = new HashMap[String, Long] private val slaveUsage = new HashMap[String, Long] - private val metadataCleaner = new MetadataCleaner("CacheTrackerActor", locs.cleanup) + private val metadataCleaner = new MetadataCleaner("CacheTrackerActor", locs.clearOldValues) private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L) private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L) @@ -120,7 +120,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b // Remembers which splits are currently being loaded (on worker nodes) val loading = new HashSet[String] - val metadataCleaner = new MetadataCleaner("CacheTracker", registeredRddIds.cleanup) + val metadataCleaner = new MetadataCleaner("CacheTracker", registeredRddIds.clearOldValues) // Send a message to the trackerActor and get its result within a default timeout, or // throw a SparkException if this fails. diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 5ebdba0fc8..a2fa2d1ea7 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -178,8 +178,8 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea } def cleanup(cleanupTime: Long) { - mapStatuses.cleanup(cleanupTime) - cachedSerializedStatuses.cleanup(cleanupTime) + mapStatuses.clearOldValues(cleanupTime) + cachedSerializedStatuses.clearOldValues(cleanupTime) } def stop() { diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 9387ba19a3..59f2099e91 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -599,15 +599,15 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with def cleanup(cleanupTime: Long) { var sizeBefore = idToStage.size - idToStage.cleanup(cleanupTime) + idToStage.clearOldValues(cleanupTime) logInfo("idToStage " + sizeBefore + " --> " + idToStage.size) sizeBefore = shuffleToMapStage.size - shuffleToMapStage.cleanup(cleanupTime) + shuffleToMapStage.clearOldValues(cleanupTime) logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size) sizeBefore = pendingTasks.size - pendingTasks.cleanup(cleanupTime) + pendingTasks.clearOldValues(cleanupTime) logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size) } diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala index 7ec6564105..74a63c1af1 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -12,7 +12,7 @@ private[spark] object ResultTask { // expensive on the master node if it needs to launch thousands of tasks. val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.cleanup) + val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.clearOldValues) def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = { synchronized { diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index feb63abb61..19f5328eee 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -23,7 +23,7 @@ private[spark] object ShuffleMapTask { // expensive on the master node if it needs to launch thousands of tasks. val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.cleanup) + val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.clearOldValues) def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { synchronized { diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala index 7e785182ea..bb7c5c01c8 100644 --- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala @@ -7,7 +7,7 @@ import scala.collection.mutable.Map /** * This is a custom implementation of scala.collection.mutable.Map which stores the insertion * time stamp along with each key-value pair. Key-value pairs that are older than a particular - * threshold time can them be removed using the cleanup method. This is intended to be a drop-in + * threshold time can them be removed using the clearOldValues method. This is intended to be a drop-in * replacement of scala.collection.mutable.HashMap. */ class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging { @@ -74,7 +74,10 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging { } } - def cleanup(threshTime: Long) { + /** + * Removes old key-value pairs that have timestamp earlier than `threshTime` + */ + def clearOldValues(threshTime: Long) { val iterator = internalMap.entrySet().iterator() while(iterator.hasNext) { val entry = iterator.next() diff --git a/core/src/main/scala/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/spark/util/TimeStampedHashSet.scala index 539dd75844..5f1cc93752 100644 --- a/core/src/main/scala/spark/util/TimeStampedHashSet.scala +++ b/core/src/main/scala/spark/util/TimeStampedHashSet.scala @@ -52,7 +52,10 @@ class TimeStampedHashSet[A] extends Set[A] { } } - def cleanup(threshTime: Long) { + /** + * Removes old values that have timestamp earlier than `threshTime` + */ + def clearOldValues(threshTime: Long) { val iterator = internalMap.entrySet().iterator() while(iterator.hasNext) { val entry = iterator.next() -- cgit v1.2.3 From 9c32f300fb4151a2b563bf3d2e46469722e016e1 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 7 Jan 2013 16:50:23 -0500 Subject: Add Accumulable.setValue for easier use in Java --- core/src/main/scala/spark/Accumulators.scala | 20 +++++++++++++++----- core/src/test/scala/spark/JavaAPISuite.java | 4 ++++ 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala index 6280f25391..b644aba5f8 100644 --- a/core/src/main/scala/spark/Accumulators.scala +++ b/core/src/main/scala/spark/Accumulators.scala @@ -63,9 +63,12 @@ class Accumulable[R, T] ( /** * Access the accumulator's current value; only allowed on master. */ - def value = { - if (!deserialized) value_ - else throw new UnsupportedOperationException("Can't read accumulator value in task") + def value: R = { + if (!deserialized) { + value_ + } else { + throw new UnsupportedOperationException("Can't read accumulator value in task") + } } /** @@ -82,10 +85,17 @@ class Accumulable[R, T] ( /** * Set the accumulator's value; only allowed on master. */ - def value_= (r: R) { - if (!deserialized) value_ = r + def value_= (newValue: R) { + if (!deserialized) value_ = newValue else throw new UnsupportedOperationException("Can't assign accumulator value in task") } + + /** + * Set the accumulator's value; only allowed on master + */ + def setValue(newValue: R) { + this.value = newValue + } // Called by Java when deserializing an object private def readObject(in: ObjectInputStream) { diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 912f8de05d..0817d1146c 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -624,5 +624,9 @@ public class JavaAPISuite implements Serializable { } }); Assert.assertEquals((Float) 25.0f, floatAccum.value()); + + // Test the setValue method + floatAccum.setValue(5.0f); + Assert.assertEquals((Float) 5.0f, floatAccum.value()); } } -- cgit v1.2.3 From 237bac36e9dca8828192994dad323b8da1619267 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 7 Jan 2013 14:37:21 -0800 Subject: Renamed examples and added documentation. --- core/src/main/scala/spark/RDDCheckpointData.scala | 4 +- docs/streaming-programming-guide.md | 14 ++-- .../spark/streaming/examples/FileStream.scala | 46 ------------- .../examples/FileStreamWithCheckpoint.scala | 75 ---------------------- .../spark/streaming/examples/FlumeEventCount.scala | 2 +- .../scala/spark/streaming/examples/GrepRaw.scala | 32 --------- .../spark/streaming/examples/HdfsWordCount.scala | 36 +++++++++++ .../streaming/examples/NetworkWordCount.scala | 36 +++++++++++ .../spark/streaming/examples/RawNetworkGrep.scala | 46 +++++++++++++ .../streaming/examples/TopKWordCountRaw.scala | 49 -------------- .../spark/streaming/examples/WordCountHdfs.scala | 25 -------- .../streaming/examples/WordCountNetwork.scala | 25 -------- .../spark/streaming/examples/WordCountRaw.scala | 43 ------------- .../scala/spark/streaming/StreamingContext.scala | 38 ++++++++--- .../spark/streaming/dstream/FileInputDStream.scala | 16 ++--- .../scala/spark/streaming/InputStreamsSuite.scala | 2 +- 16 files changed, 163 insertions(+), 326 deletions(-) delete mode 100644 examples/src/main/scala/spark/streaming/examples/FileStream.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/GrepRaw.scala create mode 100644 examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala create mode 100644 examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala create mode 100644 examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/WordCountHdfs.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/WordCountNetwork.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/WordCountRaw.scala diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index 7af830940f..e270b6312e 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -65,7 +65,7 @@ extends Logging with Serializable { cpRDD = Some(newRDD) rdd.changeDependencies(newRDD) cpState = Checkpointed - RDDCheckpointData.checkpointCompleted() + RDDCheckpointData.clearTaskCaches() logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) } } @@ -90,7 +90,7 @@ extends Logging with Serializable { } private[spark] object RDDCheckpointData { - def checkpointCompleted() { + def clearTaskCaches() { ShuffleMapTask.clearCache() ResultTask.clearCache() } diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index fc2ea2ef79..05a88ce7bd 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -187,8 +187,8 @@ Conversely, the computation can be stopped by using ssc.stop() {% endhighlight %} -# Example - WordCountNetwork.scala -A good example to start off is the spark.streaming.examples.WordCountNetwork. This example counts the words received from a network server every second. Given below is the relevant sections of the source code. You can find the full source code in /streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala. +# Example - NetworkWordCount.scala +A good example to start off is the spark.streaming.examples.NetworkWordCount. This example counts the words received from a network server every second. Given below is the relevant sections of the source code. You can find the full source code in /streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala. {% highlight scala %} import spark.streaming.{Seconds, StreamingContext} @@ -196,7 +196,7 @@ import spark.streaming.StreamingContext._ ... // Create the context and set up a network input stream to receive from a host:port -val ssc = new StreamingContext(args(0), "WordCountNetwork", Seconds(1)) +val ssc = new StreamingContext(args(0), "NetworkWordCount", Seconds(1)) val lines = ssc.networkTextStream(args(1), args(2).toInt) // Split the lines into words, count them, and print some of the counts on the master @@ -214,13 +214,13 @@ To run this example on your local machine, you need to first run a Netcat server $ nc -lk 9999 {% endhighlight %} -Then, in a different terminal, you can start WordCountNetwork by using +Then, in a different terminal, you can start NetworkWordCount by using {% highlight bash %} -$ ./run spark.streaming.examples.WordCountNetwork local[2] localhost 9999 +$ ./run spark.streaming.examples.NetworkWordCount local[2] localhost 9999 {% endhighlight %} -This will make WordCountNetwork connect to the netcat server. Any lines typed in the terminal running the netcat server will be counted and printed on screen. +This will make NetworkWordCount connect to the netcat server. Any lines typed in the terminal running the netcat server will be counted and printed on screen. - + - + @@ -88,55 +88,60 @@ DStreams support many of the transformations available on normal Spark RDD's: - + - + + + + +
    @@ -240,7 +240,7 @@ hello world {% highlight bash %} -# TERMINAL 2: RUNNING WordCountNetwork +# TERMINAL 2: RUNNING NetworkWordCount ... 2012-12-31 18:47:10,446 INFO SparkContext: Job finished: run at ThreadPoolExecutor.java:886, took 0.038817 s ------------------------------------------- diff --git a/examples/src/main/scala/spark/streaming/examples/FileStream.scala b/examples/src/main/scala/spark/streaming/examples/FileStream.scala deleted file mode 100644 index 81938d30d4..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/FileStream.scala +++ /dev/null @@ -1,46 +0,0 @@ -package spark.streaming.examples - -import spark.streaming.StreamingContext -import spark.streaming.StreamingContext._ -import spark.streaming.Seconds -import org.apache.hadoop.fs.Path -import org.apache.hadoop.conf.Configuration - - -object FileStream { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: FileStream ") - System.exit(1) - } - - // Create the context - val ssc = new StreamingContext(args(0), "FileStream", Seconds(1)) - - // Create the new directory - val directory = new Path(args(1)) - val fs = directory.getFileSystem(new Configuration()) - if (fs.exists(directory)) throw new Exception("This directory already exists") - fs.mkdirs(directory) - fs.deleteOnExit(directory) - - // Create the FileInputDStream on the directory and use the - // stream to count words in new files created - val inputStream = ssc.textFileStream(directory.toString) - val words = inputStream.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - ssc.start() - - // Creating new files in the directory - val text = "This is a text file" - for (i <- 1 to 30) { - ssc.sc.parallelize((1 to (i * 10)).map(_ => text), 10) - .saveAsTextFile(new Path(directory, i.toString).toString) - Thread.sleep(1000) - } - Thread.sleep(5000) // Waiting for the file to be processed - ssc.stop() - System.exit(0) - } -} \ No newline at end of file diff --git a/examples/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/examples/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala deleted file mode 100644 index b7bc15a1d5..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala +++ /dev/null @@ -1,75 +0,0 @@ -package spark.streaming.examples - -import spark.streaming._ -import spark.streaming.StreamingContext._ -import org.apache.hadoop.fs.Path -import org.apache.hadoop.conf.Configuration - -object FileStreamWithCheckpoint { - - def main(args: Array[String]) { - - if (args.size != 3) { - println("FileStreamWithCheckpoint ") - println("FileStreamWithCheckpoint restart ") - System.exit(-1) - } - - val directory = new Path(args(1)) - val checkpointDir = args(2) - - val ssc: StreamingContext = { - - if (args(0) == "restart") { - - // Recreated streaming context from specified checkpoint file - new StreamingContext(checkpointDir) - - } else { - - // Create directory if it does not exist - val fs = directory.getFileSystem(new Configuration()) - if (!fs.exists(directory)) fs.mkdirs(directory) - - // Create new streaming context - val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint", Seconds(1)) - ssc_.checkpoint(checkpointDir) - - // Setup the streaming computation - val inputStream = ssc_.textFileStream(directory.toString) - val words = inputStream.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - - ssc_ - } - } - - // Start the stream computation - startFileWritingThread(directory.toString) - ssc.start() - } - - def startFileWritingThread(directory: String) { - - val fs = new Path(directory).getFileSystem(new Configuration()) - - val fileWritingThread = new Thread() { - override def run() { - val r = new scala.util.Random() - val text = "This is a sample text file with a random number " - while(true) { - val number = r.nextInt() - val file = new Path(directory, number.toString) - val fos = fs.create(file) - fos.writeChars(text + number) - fos.close() - println("Created text file " + file) - Thread.sleep(1000) - } - } - } - fileWritingThread.start() - } - -} diff --git a/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala b/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala index e60ce483a3..461929fba2 100644 --- a/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala +++ b/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala @@ -5,7 +5,7 @@ import spark.storage.StorageLevel import spark.streaming._ /** - * Produce a streaming count of events received from Flume. + * Produces a count of events received from Flume. * * This should be used in conjunction with an AvroSink in Flume. It will start * an Avro server on at the request host:port address and listen for requests. diff --git a/examples/src/main/scala/spark/streaming/examples/GrepRaw.scala b/examples/src/main/scala/spark/streaming/examples/GrepRaw.scala deleted file mode 100644 index 812faa368a..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/GrepRaw.scala +++ /dev/null @@ -1,32 +0,0 @@ -package spark.streaming.examples - -import spark.util.IntParam -import spark.storage.StorageLevel - -import spark.streaming._ -import spark.streaming.util.RawTextHelper._ - -object GrepRaw { - def main(args: Array[String]) { - if (args.length != 5) { - System.err.println("Usage: GrepRaw ") - System.exit(1) - } - - val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args - - // Create the context - val ssc = new StreamingContext(master, "GrepRaw", Milliseconds(batchMillis)) - - // Warm up the JVMs on master and slave for JIT compilation to kick in - warmUp(ssc.sc) - - - val rawStreams = (1 to numStreams).map(_ => - ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray - val union = ssc.union(rawStreams) - union.filter(_.contains("Alice")).count().foreach(r => - println("Grep count: " + r.collect().mkString)) - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala b/examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala new file mode 100644 index 0000000000..8530f5c175 --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala @@ -0,0 +1,36 @@ +package spark.streaming.examples + +import spark.streaming.{Seconds, StreamingContext} +import spark.streaming.StreamingContext._ + + +/** + * Counts words in new text files created in the given directory + * Usage: HdfsWordCount + * is the Spark master URL. + * is the directory that Spark Streaming will use to find and read new text files. + * + * To run this on your local machine on directory `localdir`, run this example + * `$ ./run spark.streaming.examples.HdfsWordCount local[2] localdir` + * Then create a text file in `localdir` and the words in the file will get counted. + */ +object HdfsWordCount { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: HdfsWordCount ") + System.exit(1) + } + + // Create the context + val ssc = new StreamingContext(args(0), "HdfsWordCount", Seconds(2)) + + // Create the FileInputDStream on the directory and use the + // stream to count words in new files created + val lines = ssc.textFileStream(args(1)) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + } +} + diff --git a/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala b/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala new file mode 100644 index 0000000000..43c01d5db2 --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala @@ -0,0 +1,36 @@ +package spark.streaming.examples + +import spark.streaming.{Seconds, StreamingContext} +import spark.streaming.StreamingContext._ + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + * Usage: NetworkWordCount + * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * and describe the TCP server that Spark Streaming would connect to receive data. + * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ ./run spark.streaming.examples.NetworkWordCount local[2] localhost 9999` + */ +object NetworkWordCount { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: NetworkWordCount \n" + + "In local mode, should be 'local[n]' with n > 1") + System.exit(1) + } + + // Create the context and set the batch size + val ssc = new StreamingContext(args(0), "NetworkWordCount", Seconds(1)) + + // Create a NetworkInputDStream on target ip:port and count the + // words in input stream of \n delimited test (eg. generated by 'nc') + val lines = ssc.networkTextStream(args(1), args(2).toInt) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + } +} diff --git a/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala b/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala new file mode 100644 index 0000000000..2eec777c54 --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala @@ -0,0 +1,46 @@ +package spark.streaming.examples + +import spark.util.IntParam +import spark.storage.StorageLevel + +import spark.streaming._ +import spark.streaming.util.RawTextHelper + +/** + * Receives text from multiple rawNetworkStreams and counts how many '\n' delimited + * lines have the word 'the' in them. This is useful for benchmarking purposes. This + * will only work with spark.streaming.util.RawTextSender running on all worker nodes + * and with Spark using Kryo serialization (set Java property "spark.serializer" to + * "spark.KryoSerializer"). + * Usage: RawNetworkGrep + * is the Spark master URL + * is the number rawNetworkStreams, which should be same as number + * of work nodes in the cluster + * is "localhost". + * is the port on which RawTextSender is running in the worker nodes. + * is the Spark Streaming batch duration in milliseconds. + */ + +object RawNetworkGrep { + def main(args: Array[String]) { + if (args.length != 5) { + System.err.println("Usage: RawNetworkGrep ") + System.exit(1) + } + + val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args + + // Create the context + val ssc = new StreamingContext(master, "RawNetworkGrep", Milliseconds(batchMillis)) + + // Warm up the JVMs on master and slave for JIT compilation to kick in + RawTextHelper.warmUp(ssc.sc) + + val rawStreams = (1 to numStreams).map(_ => + ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray + val union = ssc.union(rawStreams) + union.filter(_.contains("the")).count().foreach(r => + println("Grep count: " + r.collect().mkString)) + ssc.start() + } +} diff --git a/examples/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/examples/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala deleted file mode 100644 index 338834bc3c..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala +++ /dev/null @@ -1,49 +0,0 @@ -package spark.streaming.examples - -import spark.storage.StorageLevel -import spark.util.IntParam - -import spark.streaming._ -import spark.streaming.StreamingContext._ -import spark.streaming.util.RawTextHelper._ - -import java.util.UUID - -object TopKWordCountRaw { - - def main(args: Array[String]) { - if (args.length != 4) { - System.err.println("Usage: WordCountRaw <# streams> ") - System.exit(1) - } - - val Array(master, IntParam(numStreams), IntParam(port), checkpointDir) = args - val k = 10 - - // Create the context, and set the checkpoint directory. - // Checkpoint directory is necessary for achieving fault-tolerance, by saving counts - // periodically to HDFS - val ssc = new StreamingContext(master, "TopKWordCountRaw", Seconds(1)) - ssc.checkpoint(checkpointDir + "/" + UUID.randomUUID.toString, Seconds(1)) - - // Warm up the JVMs on master and slave for JIT compilation to kick in - /*warmUp(ssc.sc)*/ - - // Set up the raw network streams that will connect to localhost:port to raw test - // senders on the slaves and generate top K words of last 30 seconds - val lines = (1 to numStreams).map(_ => { - ssc.rawNetworkStream[String]("localhost", port, StorageLevel.MEMORY_ONLY_SER_2) - }) - val union = ssc.union(lines) - val counts = union.mapPartitions(splitAndCountPartitions) - val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) - val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k)) - partialTopKWindowedCounts.foreach(rdd => { - val collectedCounts = rdd.collect - println("Collected " + collectedCounts.size + " words from partial top words") - println("Top " + k + " words are " + topK(collectedCounts.toIterator, k).mkString(",")) - }) - - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/WordCountHdfs.scala b/examples/src/main/scala/spark/streaming/examples/WordCountHdfs.scala deleted file mode 100644 index 867a8f42c4..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/WordCountHdfs.scala +++ /dev/null @@ -1,25 +0,0 @@ -package spark.streaming.examples - -import spark.streaming.{Seconds, StreamingContext} -import spark.streaming.StreamingContext._ - -object WordCountHdfs { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: WordCountHdfs ") - System.exit(1) - } - - // Create the context - val ssc = new StreamingContext(args(0), "WordCountHdfs", Seconds(2)) - - // Create the FileInputDStream on the directory and use the - // stream to count words in new files created - val lines = ssc.textFileStream(args(1)) - val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - ssc.start() - } -} - diff --git a/examples/src/main/scala/spark/streaming/examples/WordCountNetwork.scala b/examples/src/main/scala/spark/streaming/examples/WordCountNetwork.scala deleted file mode 100644 index eadda60563..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/WordCountNetwork.scala +++ /dev/null @@ -1,25 +0,0 @@ -package spark.streaming.examples - -import spark.streaming.{Seconds, StreamingContext} -import spark.streaming.StreamingContext._ - -object WordCountNetwork { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: WordCountNetwork \n" + - "In local mode, should be 'local[n]' with n > 1") - System.exit(1) - } - - // Create the context and set the batch size - val ssc = new StreamingContext(args(0), "WordCountNetwork", Seconds(1)) - - // Create a NetworkInputDStream on target ip:port and count the - // words in input stream of \n delimited test (eg. generated by 'nc') - val lines = ssc.networkTextStream(args(1), args(2).toInt) - val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/examples/src/main/scala/spark/streaming/examples/WordCountRaw.scala deleted file mode 100644 index d93335a8ce..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ /dev/null @@ -1,43 +0,0 @@ -package spark.streaming.examples - -import spark.storage.StorageLevel -import spark.util.IntParam - -import spark.streaming._ -import spark.streaming.StreamingContext._ -import spark.streaming.util.RawTextHelper._ - -import java.util.UUID - -object WordCountRaw { - - def main(args: Array[String]) { - if (args.length != 4) { - System.err.println("Usage: WordCountRaw <# streams> ") - System.exit(1) - } - - val Array(master, IntParam(numStreams), IntParam(port), checkpointDir) = args - - // Create the context, and set the checkpoint directory. - // Checkpoint directory is necessary for achieving fault-tolerance, by saving counts - // periodically to HDFS - val ssc = new StreamingContext(master, "WordCountRaw", Seconds(1)) - ssc.checkpoint(checkpointDir + "/" + UUID.randomUUID.toString, Seconds(1)) - - // Warm up the JVMs on master and slave for JIT compilation to kick in - warmUp(ssc.sc) - - // Set up the raw network streams that will connect to localhost:port to raw test - // senders on the slaves and generate count of words of last 30 seconds - val lines = (1 to numStreams).map(_ => { - ssc.rawNetworkStream[String]("localhost", port, StorageLevel.MEMORY_ONLY_SER_2) - }) - val union = ssc.union(lines) - val counts = union.mapPartitions(splitAndCountPartitions) - val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10) - windowedCounts.foreach(r => println("# unique words = " + r.count())) - - ssc.start() - } -} diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 7256e41af9..215246ba2e 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -154,7 +154,7 @@ class StreamingContext private ( storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 ): DStream[T] = { val inputStream = new KafkaInputDStream[T](this, hostname, port, groupId, topics, initialOffsets, storageLevel) - graph.addInputStream(inputStream) + registerInputStream(inputStream) inputStream } @@ -192,7 +192,7 @@ class StreamingContext private ( storageLevel: StorageLevel ): DStream[T] = { val inputStream = new SocketInputDStream[T](this, hostname, port, converter, storageLevel) - graph.addInputStream(inputStream) + registerInputStream(inputStream) inputStream } @@ -208,7 +208,7 @@ class StreamingContext private ( storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 ): DStream[SparkFlumeEvent] = { val inputStream = new FlumeInputDStream(this, hostname, port, storageLevel) - graph.addInputStream(inputStream) + registerInputStream(inputStream) inputStream } @@ -228,13 +228,14 @@ class StreamingContext private ( storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 ): DStream[T] = { val inputStream = new RawInputDStream[T](this, hostname, port, storageLevel) - graph.addInputStream(inputStream) + registerInputStream(inputStream) inputStream } /** * Creates a input stream that monitors a Hadoop-compatible filesystem * for new files and reads them using the given key-value types and input format. + * File names starting with . are ignored. * @param directory HDFS directory to monitor for new file * @tparam K Key type for reading HDFS file * @tparam V Value type for reading HDFS file @@ -244,16 +245,37 @@ class StreamingContext private ( K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K, V]: ClassManifest - ](directory: String): DStream[(K, V)] = { + ] (directory: String): DStream[(K, V)] = { val inputStream = new FileInputDStream[K, V, F](this, directory) - graph.addInputStream(inputStream) + registerInputStream(inputStream) + inputStream + } + + /** + * Creates a input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them using the given key-value types and input format. + * @param directory HDFS directory to monitor for new file + * @param filter Function to filter paths to process + * @param newFilesOnly Should process only new files and ignore existing files in the directory + * @tparam K Key type for reading HDFS file + * @tparam V Value type for reading HDFS file + * @tparam F Input format for reading HDFS file + */ + def fileStream[ + K: ClassManifest, + V: ClassManifest, + F <: NewInputFormat[K, V]: ClassManifest + ] (directory: String, filter: Path => Boolean, newFilesOnly: Boolean): DStream[(K, V)] = { + val inputStream = new FileInputDStream[K, V, F](this, directory, filter, newFilesOnly) + registerInputStream(inputStream) inputStream } + /** * Creates a input stream that monitors a Hadoop-compatible filesystem * for new files and reads them as text files (using key as LongWritable, value - * as Text and input format as TextInputFormat). + * as Text and input format as TextInputFormat). File names starting with . are ignored. * @param directory HDFS directory to monitor for new file */ def textFileStream(directory: String): DStream[String] = { @@ -274,7 +296,7 @@ class StreamingContext private ( defaultRDD: RDD[T] = null ): DStream[T] = { val inputStream = new QueueInputDStream(this, queue, oneAtATime, defaultRDD) - graph.addInputStream(inputStream) + registerInputStream(inputStream) inputStream } diff --git a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala index cf72095324..1e6ad84b44 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala @@ -14,7 +14,7 @@ private[streaming] class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( @transient ssc_ : StreamingContext, directory: String, - filter: PathFilter = FileInputDStream.defaultPathFilter, + filter: Path => Boolean = FileInputDStream.defaultFilter, newFilesOnly: Boolean = true) extends InputDStream[(K, V)](ssc_) { @@ -60,7 +60,7 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K val latestModTimeFiles = new HashSet[String]() def accept(path: Path): Boolean = { - if (!filter.accept(path)) { + if (!filter(path)) { return false } else { val modTime = fs.getFileStatus(path).getModificationTime() @@ -95,16 +95,8 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K } } +private[streaming] object FileInputDStream { - val defaultPathFilter = new PathFilter with Serializable { - def accept(path: Path): Boolean = { - val file = path.getName() - if (file.startsWith(".") || file.endsWith("_tmp")) { - return false - } else { - return true - } - } - } + def defaultFilter(path: Path): Boolean = !path.getName().startsWith(".") } diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index 76b528bec3..00ee903c1e 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -318,7 +318,7 @@ class TestServer(port: Int) extends Logging { } } } catch { - case e: SocketException => println(e) + case e: SocketException => logInfo(e) } finally { logInfo("Connection closed") if (!clientSocket.isClosed) clientSocket.close() -- cgit v1.2.3 From 3b0a3b89ac508b57b8afbd1ca7024ee558a5d1af Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 7 Jan 2013 14:55:49 -0800 Subject: Added better docs for RDDCheckpointData --- core/src/main/scala/spark/RDDCheckpointData.scala | 10 +++++++++- core/src/main/scala/spark/rdd/CheckpointRDD.scala | 4 ++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index e270b6312e..d845a522e4 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -14,15 +14,23 @@ private[spark] object CheckpointState extends Enumeration { } /** - * This class contains all the information of the regarding RDD checkpointing. + * This class contains all the information related to RDD checkpointing. Each instance of this class + * is associated with a RDD. It manages process of checkpointing of the associated RDD, as well as, + * manages the post-checkpoint state by providing the updated splits, iterator and preferred locations + * of the checkpointed RDD. */ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) extends Logging with Serializable { import CheckpointState._ + // The checkpoint state of the associated RDD. var cpState = Initialized + + // The file to which the associated RDD has been checkpointed to @transient var cpFile: Option[String] = None + + // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD. @transient var cpRDD: Option[RDD[T]] = None // Mark the RDD for checkpointing diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala index 1a88d402c3..86c63ca2f4 100644 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -13,6 +13,10 @@ private[spark] class CheckpointRDDSplit(idx: Int, val splitFile: String) extends override val index: Int = idx } +/** + * This RDD represents a RDD checkpoint file (similar to HadoopRDD). + */ +private[spark] class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String) extends RDD[T](sc, Nil) { -- cgit v1.2.3 From e60514d79e753f38db06f8d3df20e113ba4dc11a Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 7 Jan 2013 15:16:16 -0800 Subject: Fixed bug --- streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index 00ee903c1e..e71ba6ddc1 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -318,7 +318,7 @@ class TestServer(port: Int) extends Logging { } } } catch { - case e: SocketException => logInfo(e) + case e: SocketException => logError("TestServer error", e) } finally { logInfo("Connection closed") if (!clientSocket.isClosed) clientSocket.close() -- cgit v1.2.3 From f8d579a0c05b7d29b59e541b483ded471d14ec17 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Thu, 27 Dec 2012 13:30:07 -0800 Subject: Remove dependencies on sun jvm classes. Instead use reflection to infer HotSpot options and total physical memory size --- core/src/main/scala/spark/SizeEstimator.scala | 13 ++++++++++--- .../spark/deploy/worker/WorkerArguments.scala | 22 +++++++++++++++++++--- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/spark/SizeEstimator.scala b/core/src/main/scala/spark/SizeEstimator.scala index 7c3e8640e9..d4e1157250 100644 --- a/core/src/main/scala/spark/SizeEstimator.scala +++ b/core/src/main/scala/spark/SizeEstimator.scala @@ -9,7 +9,6 @@ import java.util.Random import javax.management.MBeanServer import java.lang.management.ManagementFactory -import com.sun.management.HotSpotDiagnosticMXBean import scala.collection.mutable.ArrayBuffer @@ -76,12 +75,20 @@ private[spark] object SizeEstimator extends Logging { if (System.getProperty("spark.test.useCompressedOops") != null) { return System.getProperty("spark.test.useCompressedOops").toBoolean } + try { val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic" val server = ManagementFactory.getPlatformMBeanServer() + + // NOTE: This should throw an exception in non-Sun JVMs + val hotSpotMBeanClass = Class.forName("com.sun.management.HotSpotDiagnosticMXBean") + val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption", + Class.forName("java.lang.String")) + val bean = ManagementFactory.newPlatformMXBeanProxy(server, - hotSpotMBeanName, classOf[HotSpotDiagnosticMXBean]) - return bean.getVMOption("UseCompressedOops").getValue.toBoolean + hotSpotMBeanName, hotSpotMBeanClass) + // TODO: We could use reflection on the VMOption returned ? + return getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true") } catch { case e: Exception => { // Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala index 340920025b..37524a7c82 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala @@ -104,9 +104,25 @@ private[spark] class WorkerArguments(args: Array[String]) { } def inferDefaultMemory(): Int = { - val bean = ManagementFactory.getOperatingSystemMXBean - .asInstanceOf[com.sun.management.OperatingSystemMXBean] - val totalMb = (bean.getTotalPhysicalMemorySize / 1024 / 1024).toInt + val ibmVendor = System.getProperty("java.vendor").contains("IBM") + var totalMb = 0 + try { + val bean = ManagementFactory.getOperatingSystemMXBean() + if (ibmVendor) { + val beanClass = Class.forName("com.ibm.lang.management.OperatingSystemMXBean") + val method = beanClass.getDeclaredMethod("getTotalPhysicalMemory") + totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt + } else { + val beanClass = Class.forName("com.sun.management.OperatingSystemMXBean") + val method = beanClass.getDeclaredMethod("getTotalPhysicalMemorySize") + totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt + } + } catch { + case e: Exception => { + totalMb = 2*1024 + System.out.println("Failed to get total physical memory. Using " + totalMb + " MB") + } + } // Leave out 1 GB for the operating system, but don't return a negative memory size math.max(totalMb - 1024, 512) } -- cgit v1.2.3 From aed368a970bbaee4bdf297ba3f6f1b0fa131452c Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sat, 29 Dec 2012 16:23:43 -0800 Subject: Update Hadoop dependency to 1.0.3 as 0.20 has Sun specific dependencies. Also fix SequenceFileRDDFunctions to pick the right type conversion across Hadoop versions --- core/src/main/scala/spark/SequenceFileRDDFunctions.scala | 8 +++++++- project/SparkBuild.scala | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala index a34aee69c1..6b4a11d6d3 100644 --- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala @@ -42,7 +42,13 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) { classManifest[T].erasure } else { - implicitly[T => Writable].getClass.getMethods()(0).getReturnType + // We get the type of the Writable class by looking at the apply method which converts + // from T to Writable. Since we have two apply methods we filter out the one which + // is of the form "java.lang.Object apply(java.lang.Object)" + implicitly[T => Writable].getClass.getDeclaredMethods().filter( + m => m.getReturnType().toString != "java.lang.Object" && + m.getName() == "apply")(0).getReturnType + } // TODO: use something like WritableConverter to avoid reflection } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 842d0fa96b..7c7c33131a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -10,7 +10,7 @@ import twirl.sbt.TwirlPlugin._ object SparkBuild extends Build { // Hadoop version to build against. For example, "0.20.2", "0.20.205.0", or // "1.0.3" for Apache releases, or "0.20.2-cdh3u5" for Cloudera Hadoop. - val HADOOP_VERSION = "0.20.205.0" + val HADOOP_VERSION = "1.0.3" val HADOOP_MAJOR_VERSION = "1" // For Hadoop 2 versions such as "2.0.0-mr1-cdh4.1.1", set the HADOOP_MAJOR_VERSION to "2" -- cgit v1.2.3 From 77d751731ccd06e161e3ef10540f8165d964282f Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sat, 29 Dec 2012 18:28:00 -0800 Subject: Remove unused BoundedMemoryCache file and associated test case. --- core/src/main/scala/spark/BoundedMemoryCache.scala | 118 --------------------- .../test/scala/spark/BoundedMemoryCacheSuite.scala | 58 ---------- 2 files changed, 176 deletions(-) delete mode 100644 core/src/main/scala/spark/BoundedMemoryCache.scala delete mode 100644 core/src/test/scala/spark/BoundedMemoryCacheSuite.scala diff --git a/core/src/main/scala/spark/BoundedMemoryCache.scala b/core/src/main/scala/spark/BoundedMemoryCache.scala deleted file mode 100644 index e8392a194f..0000000000 --- a/core/src/main/scala/spark/BoundedMemoryCache.scala +++ /dev/null @@ -1,118 +0,0 @@ -package spark - -import java.util.LinkedHashMap - -/** - * An implementation of Cache that estimates the sizes of its entries and attempts to limit its - * total memory usage to a fraction of the JVM heap. Objects' sizes are estimated using - * SizeEstimator, which has limitations; most notably, we will overestimate total memory used if - * some cache entries have pointers to a shared object. Nonetheless, this Cache should work well - * when most of the space is used by arrays of primitives or of simple classes. - */ -private[spark] class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging { - logInfo("BoundedMemoryCache.maxBytes = " + maxBytes) - - def this() { - this(BoundedMemoryCache.getMaxBytes) - } - - private var currentBytes = 0L - private val map = new LinkedHashMap[(Any, Int), Entry](32, 0.75f, true) - - override def get(datasetId: Any, partition: Int): Any = { - synchronized { - val entry = map.get((datasetId, partition)) - if (entry != null) { - entry.value - } else { - null - } - } - } - - override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = { - val key = (datasetId, partition) - logInfo("Asked to add key " + key) - val size = estimateValueSize(key, value) - synchronized { - if (size > getCapacity) { - return CachePutFailure() - } else if (ensureFreeSpace(datasetId, size)) { - logInfo("Adding key " + key) - map.put(key, new Entry(value, size)) - currentBytes += size - logInfo("Number of entries is now " + map.size) - return CachePutSuccess(size) - } else { - logInfo("Didn't add key " + key + " because we would have evicted part of same dataset") - return CachePutFailure() - } - } - } - - override def getCapacity: Long = maxBytes - - /** - * Estimate sizeOf 'value' - */ - private def estimateValueSize(key: (Any, Int), value: Any) = { - val startTime = System.currentTimeMillis - val size = SizeEstimator.estimate(value.asInstanceOf[AnyRef]) - val timeTaken = System.currentTimeMillis - startTime - logInfo("Estimated size for key %s is %d".format(key, size)) - logInfo("Size estimation for key %s took %d ms".format(key, timeTaken)) - size - } - - /** - * Remove least recently used entries from the map until at least space bytes are free, in order - * to make space for a partition from the given dataset ID. If this cannot be done without - * evicting other data from the same dataset, returns false; otherwise, returns true. Assumes - * that a lock is held on the BoundedMemoryCache. - */ - private def ensureFreeSpace(datasetId: Any, space: Long): Boolean = { - logInfo("ensureFreeSpace(%s, %d) called with curBytes=%d, maxBytes=%d".format( - datasetId, space, currentBytes, maxBytes)) - val iter = map.entrySet.iterator // Will give entries in LRU order - while (maxBytes - currentBytes < space && iter.hasNext) { - val mapEntry = iter.next() - val (entryDatasetId, entryPartition) = mapEntry.getKey - if (entryDatasetId == datasetId) { - // Cannot make space without removing part of the same dataset, or a more recently used one - return false - } - reportEntryDropped(entryDatasetId, entryPartition, mapEntry.getValue) - currentBytes -= mapEntry.getValue.size - iter.remove() - } - return true - } - - protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) { - logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size)) - // TODO: remove BoundedMemoryCache - - val (keySpaceId, innerDatasetId) = datasetId.asInstanceOf[(Any, Any)] - innerDatasetId match { - case rddId: Int => - SparkEnv.get.cacheTracker.dropEntry(rddId, partition) - case broadcastUUID: java.util.UUID => - // TODO: Maybe something should be done if the broadcasted variable falls out of cache - case _ => - } - } -} - -// An entry in our map; stores a cached object and its size in bytes -private[spark] case class Entry(value: Any, size: Long) - -private[spark] object BoundedMemoryCache { - /** - * Get maximum cache capacity from system configuration - */ - def getMaxBytes: Long = { - val memoryFractionToUse = System.getProperty("spark.boundedMemoryCache.memoryFraction", "0.66").toDouble - (Runtime.getRuntime.maxMemory * memoryFractionToUse).toLong - } -} - diff --git a/core/src/test/scala/spark/BoundedMemoryCacheSuite.scala b/core/src/test/scala/spark/BoundedMemoryCacheSuite.scala deleted file mode 100644 index 37cafd1e8e..0000000000 --- a/core/src/test/scala/spark/BoundedMemoryCacheSuite.scala +++ /dev/null @@ -1,58 +0,0 @@ -package spark - -import org.scalatest.FunSuite -import org.scalatest.PrivateMethodTester -import org.scalatest.matchers.ShouldMatchers - -// TODO: Replace this with a test of MemoryStore -class BoundedMemoryCacheSuite extends FunSuite with PrivateMethodTester with ShouldMatchers { - test("constructor test") { - val cache = new BoundedMemoryCache(60) - expect(60)(cache.getCapacity) - } - - test("caching") { - // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case - val oldArch = System.setProperty("os.arch", "amd64") - val oldOops = System.setProperty("spark.test.useCompressedOops", "true") - val initialize = PrivateMethod[Unit]('initialize) - SizeEstimator invokePrivate initialize() - - val cache = new BoundedMemoryCache(60) { - //TODO sorry about this, but there is not better way how to skip 'cacheTracker.dropEntry' - override protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) { - logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size)) - } - } - - // NOTE: The String class definition changed in JDK 7 to exclude the int fields count and length - // This means that the size of strings will be lesser by 8 bytes in JDK 7 compared to JDK 6. - // http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html - // Work around to check for either. - - //should be OK - cache.put("1", 0, "Meh") should (equal (CachePutSuccess(56)) or equal (CachePutSuccess(48))) - - //we cannot add this to cache (there is not enough space in cache) & we cannot evict the only value from - //cache because it's from the same dataset - expect(CachePutFailure())(cache.put("1", 1, "Meh")) - - //should be OK, dataset '1' can be evicted from cache - cache.put("2", 0, "Meh") should (equal (CachePutSuccess(56)) or equal (CachePutSuccess(48))) - - //should fail, cache should obey it's capacity - expect(CachePutFailure())(cache.put("3", 0, "Very_long_and_useless_string")) - - if (oldArch != null) { - System.setProperty("os.arch", oldArch) - } else { - System.clearProperty("os.arch") - } - - if (oldOops != null) { - System.setProperty("spark.test.useCompressedOops", oldOops) - } else { - System.clearProperty("spark.test.useCompressedOops") - } - } -} -- cgit v1.2.3 From 55c66d365f76f3e5ecc6b850ba81c84b320f6772 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 7 Jan 2013 15:19:33 -0800 Subject: Use a dummy string class in Size Estimator tests to make it resistant to jdk versions --- core/src/test/scala/spark/SizeEstimatorSuite.scala | 33 ++++++++++++++-------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/core/src/test/scala/spark/SizeEstimatorSuite.scala b/core/src/test/scala/spark/SizeEstimatorSuite.scala index 17f366212b..bf3b2e1eed 100644 --- a/core/src/test/scala/spark/SizeEstimatorSuite.scala +++ b/core/src/test/scala/spark/SizeEstimatorSuite.scala @@ -20,6 +20,15 @@ class DummyClass4(val d: DummyClass3) { val x: Int = 0 } +object DummyString { + def apply(str: String) : DummyString = new DummyString(str.toArray) +} +class DummyString(val arr: Array[Char]) { + override val hashCode: Int = 0 + // JDK-7 has an extra hash32 field http://hg.openjdk.java.net/jdk7u/jdk7u6/jdk/rev/11987e85555f + @transient val hash32: Int = 0 +} + class SizeEstimatorSuite extends FunSuite with BeforeAndAfterAll with PrivateMethodTester with ShouldMatchers { @@ -50,10 +59,10 @@ class SizeEstimatorSuite // http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html // Work around to check for either. test("strings") { - SizeEstimator.estimate("") should (equal (48) or equal (40)) - SizeEstimator.estimate("a") should (equal (56) or equal (48)) - SizeEstimator.estimate("ab") should (equal (56) or equal (48)) - SizeEstimator.estimate("abcdefgh") should (equal(64) or equal(56)) + SizeEstimator.estimate(DummyString("")) should (equal (48) or equal (40)) + SizeEstimator.estimate(DummyString("a")) should (equal (56) or equal (48)) + SizeEstimator.estimate(DummyString("ab")) should (equal (56) or equal (48)) + SizeEstimator.estimate(DummyString("abcdefgh")) should (equal(64) or equal(56)) } test("primitive arrays") { @@ -105,10 +114,10 @@ class SizeEstimatorSuite val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() - expect(40)(SizeEstimator.estimate("")) - expect(48)(SizeEstimator.estimate("a")) - expect(48)(SizeEstimator.estimate("ab")) - expect(56)(SizeEstimator.estimate("abcdefgh")) + expect(40)(SizeEstimator.estimate(DummyString(""))) + expect(48)(SizeEstimator.estimate(DummyString("a"))) + expect(48)(SizeEstimator.estimate(DummyString("ab"))) + expect(56)(SizeEstimator.estimate(DummyString("abcdefgh"))) resetOrClear("os.arch", arch) } @@ -124,10 +133,10 @@ class SizeEstimatorSuite val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() - SizeEstimator.estimate("") should (equal (64) or equal (56)) - SizeEstimator.estimate("a") should (equal (72) or equal (64)) - SizeEstimator.estimate("ab") should (equal (72) or equal (64)) - SizeEstimator.estimate("abcdefgh") should (equal (80) or equal (72)) + SizeEstimator.estimate(DummyString("")) should (equal (64) or equal (56)) + SizeEstimator.estimate(DummyString("a")) should (equal (72) or equal (64)) + SizeEstimator.estimate(DummyString("ab")) should (equal (72) or equal (64)) + SizeEstimator.estimate(DummyString("abcdefgh")) should (equal (80) or equal (72)) resetOrClear("os.arch", arch) resetOrClear("spark.test.useCompressedOops", oops) -- cgit v1.2.3 From 4719e6d8fe6d93734f5bbe6c91dcc4616c1ed317 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 7 Jan 2013 16:06:07 -0800 Subject: Changed locations for unit test logs. --- bagel/src/test/resources/log4j.properties | 4 ++-- core/src/test/resources/log4j.properties | 4 ++-- repl/src/test/resources/log4j.properties | 4 ++-- streaming/src/test/resources/log4j.properties | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties index 4c99e450bc..83d05cab2f 100644 --- a/bagel/src/test/resources/log4j.properties +++ b/bagel/src/test/resources/log4j.properties @@ -1,8 +1,8 @@ -# Set everything to be logged to the console +# Set everything to be logged to the file bagel/target/unit-tests.log log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false -log4j.appender.file.file=spark-tests.log +log4j.appender.file.file=bagel/target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties index 5ed388e91b..6ec89c0184 100644 --- a/core/src/test/resources/log4j.properties +++ b/core/src/test/resources/log4j.properties @@ -1,8 +1,8 @@ -# Set everything to be logged to the file spark-tests.log +# Set everything to be logged to the file core/target/unit-tests.log log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false -log4j.appender.file.file=spark-tests.log +log4j.appender.file.file=core/target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n diff --git a/repl/src/test/resources/log4j.properties b/repl/src/test/resources/log4j.properties index 4c99e450bc..cfb1a390e6 100644 --- a/repl/src/test/resources/log4j.properties +++ b/repl/src/test/resources/log4j.properties @@ -1,8 +1,8 @@ -# Set everything to be logged to the console +# Set everything to be logged to the repl/target/unit-tests.log log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false -log4j.appender.file.file=spark-tests.log +log4j.appender.file.file=repl/target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index 33bafebaab..edfa1243fa 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -1,8 +1,8 @@ -# Set everything to be logged to the file streaming-tests.log +# Set everything to be logged to the file streaming/target/unit-tests.log log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false -log4j.appender.file.file=streaming-tests.log +log4j.appender.file.file=streaming/target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n -- cgit v1.2.3 From fb3d4d5e85cd4b094411bb08a32ab50cc62dc151 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 7 Jan 2013 16:46:06 -0800 Subject: Make default hadoop version 1.0.3 in pom.xml --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index b33cee26b8..fe5b1d0ee4 100644 --- a/pom.xml +++ b/pom.xml @@ -489,7 +489,7 @@ org.apache.hadoop hadoop-core - 0.20.205.0 + 1.0.3 -- cgit v1.2.3 From b1336e2fe458b92dcf60dcd249c41c7bdcc8be6d Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 7 Jan 2013 17:00:32 -0800 Subject: Update expected size of strings to match our dummy string class --- core/src/test/scala/spark/SizeEstimatorSuite.scala | 31 +++++++++------------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/core/src/test/scala/spark/SizeEstimatorSuite.scala b/core/src/test/scala/spark/SizeEstimatorSuite.scala index bf3b2e1eed..e235ef2f67 100644 --- a/core/src/test/scala/spark/SizeEstimatorSuite.scala +++ b/core/src/test/scala/spark/SizeEstimatorSuite.scala @@ -3,7 +3,6 @@ package spark import org.scalatest.FunSuite import org.scalatest.BeforeAndAfterAll import org.scalatest.PrivateMethodTester -import org.scalatest.matchers.ShouldMatchers class DummyClass1 {} @@ -30,7 +29,7 @@ class DummyString(val arr: Array[Char]) { } class SizeEstimatorSuite - extends FunSuite with BeforeAndAfterAll with PrivateMethodTester with ShouldMatchers { + extends FunSuite with BeforeAndAfterAll with PrivateMethodTester { var oldArch: String = _ var oldOops: String = _ @@ -54,15 +53,13 @@ class SizeEstimatorSuite expect(48)(SizeEstimator.estimate(new DummyClass4(new DummyClass3))) } - // NOTE: The String class definition changed in JDK 7 to exclude the int fields count and length. - // This means that the size of strings will be lesser by 8 bytes in JDK 7 compared to JDK 6. - // http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html - // Work around to check for either. + // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors + // (Sun vs IBM). Use a DummyString class to make tests deterministic. test("strings") { - SizeEstimator.estimate(DummyString("")) should (equal (48) or equal (40)) - SizeEstimator.estimate(DummyString("a")) should (equal (56) or equal (48)) - SizeEstimator.estimate(DummyString("ab")) should (equal (56) or equal (48)) - SizeEstimator.estimate(DummyString("abcdefgh")) should (equal(64) or equal(56)) + expect(40)(SizeEstimator.estimate(DummyString(""))) + expect(48)(SizeEstimator.estimate(DummyString("a"))) + expect(48)(SizeEstimator.estimate(DummyString("ab"))) + expect(56)(SizeEstimator.estimate(DummyString("abcdefgh"))) } test("primitive arrays") { @@ -122,10 +119,8 @@ class SizeEstimatorSuite resetOrClear("os.arch", arch) } - // NOTE: The String class definition changed in JDK 7 to exclude the int fields count and length. - // This means that the size of strings will be lesser by 8 bytes in JDK 7 compared to JDK 6. - // http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html - // Work around to check for either. + // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors + // (Sun vs IBM). Use a DummyString class to make tests deterministic. test("64-bit arch with no compressed oops") { val arch = System.setProperty("os.arch", "amd64") val oops = System.setProperty("spark.test.useCompressedOops", "false") @@ -133,10 +128,10 @@ class SizeEstimatorSuite val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() - SizeEstimator.estimate(DummyString("")) should (equal (64) or equal (56)) - SizeEstimator.estimate(DummyString("a")) should (equal (72) or equal (64)) - SizeEstimator.estimate(DummyString("ab")) should (equal (72) or equal (64)) - SizeEstimator.estimate(DummyString("abcdefgh")) should (equal (80) or equal (72)) + expect(56)(SizeEstimator.estimate(DummyString(""))) + expect(64)(SizeEstimator.estimate(DummyString("a"))) + expect(64)(SizeEstimator.estimate(DummyString("ab"))) + expect(72)(SizeEstimator.estimate(DummyString("abcdefgh"))) resetOrClear("os.arch", arch) resetOrClear("spark.test.useCompressedOops", oops) -- cgit v1.2.3 From 4bbe07e5ece81fa874d2412bcc165179313a7619 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 7 Jan 2013 17:46:22 -0800 Subject: Activate hadoop1 profile by default for maven builds --- bagel/pom.xml | 3 +++ core/pom.xml | 5 ++++- examples/pom.xml | 3 +++ pom.xml | 3 +++ repl-bin/pom.xml | 3 +++ repl/pom.xml | 3 +++ 6 files changed, 19 insertions(+), 1 deletion(-) diff --git a/bagel/pom.xml b/bagel/pom.xml index a8256a6e8b..4ca643bbb7 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -45,6 +45,9 @@ hadoop1 + + true + org.spark-project diff --git a/core/pom.xml b/core/pom.xml index ae52c20657..cd789a7db0 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -159,6 +159,9 @@ hadoop1 + + true + org.apache.hadoop @@ -267,4 +270,4 @@ - \ No newline at end of file + diff --git a/examples/pom.xml b/examples/pom.xml index 782c026d73..9e638c8284 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -45,6 +45,9 @@ hadoop1 + + true + org.spark-project diff --git a/pom.xml b/pom.xml index fe5b1d0ee4..0e2d93c170 100644 --- a/pom.xml +++ b/pom.xml @@ -481,6 +481,9 @@ hadoop1 + + true + 1 diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index 0667b71cc7..aa9895eda2 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -70,6 +70,9 @@ hadoop1 + + true + hadoop1 diff --git a/repl/pom.xml b/repl/pom.xml index 114e3e9932..ba7a051310 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -72,6 +72,9 @@ hadoop1 + + true + hadoop1 -- cgit v1.2.3 From 8c1b87251210bb5553e6a3b6f9648b178b221a3b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 7 Jan 2013 17:48:10 -0800 Subject: Moved Twitter example to the where the other examples are. --- .../streaming/examples/twitter/TwitterBasic.scala | 43 +++++++++++++++ .../examples/twitter/TwitterInputDStream.scala | 62 ++++++++++++++++++++++ .../spark/streaming/TwitterInputDStream.scala | 58 -------------------- .../spark/streaming/examples/TwitterBasic.scala | 46 ---------------- 4 files changed, 105 insertions(+), 104 deletions(-) create mode 100644 examples/src/main/scala/spark/streaming/examples/twitter/TwitterBasic.scala create mode 100644 examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/TwitterInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/examples/TwitterBasic.scala diff --git a/examples/src/main/scala/spark/streaming/examples/twitter/TwitterBasic.scala b/examples/src/main/scala/spark/streaming/examples/twitter/TwitterBasic.scala new file mode 100644 index 0000000000..22a927e87f --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/twitter/TwitterBasic.scala @@ -0,0 +1,43 @@ +package spark.streaming.examples.twitter + +import spark.streaming.{Seconds, StreamingContext} +import spark.streaming.StreamingContext._ + +object TwitterBasic { + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: TwitterBasic " + + " [filter1] [filter2] ... [filter n]") + System.exit(1) + } + + val Array(master, username, password) = args.slice(0, 3) + val filters = args.slice(3, args.length) + + val ssc = new StreamingContext(master, "TwitterBasic", Seconds(2)) + val stream = new TwitterInputDStream(ssc, username, password, filters) + ssc.graph.addInputStream(stream) + + val hashTags = stream.flatMap( + status => status.getText.split(" ").filter(_.startsWith("#"))) + + // Word count over hashtags + val counts = hashTags.map((_, 1)).reduceByKeyAndWindow(_ + _, Seconds(60)) + // TODO: Sorts on one node - should do with global sorting once streaming supports it + counts.foreach(rdd => { + val topList = rdd.collect().sortBy(-_._2).take(5) + if (!topList.isEmpty) { + println("\nPopular topics in last 60 seconds:") + topList.foreach{case (tag, count) => println("%s (%s tweets)".format(tag, count))} + } + }) + + // Print number of tweets in the window + stream.window(Seconds(60)).count().foreach(rdd => + if (rdd.count() != 0) { + println("Window size: %s tweets".format(rdd.take(1)(0))) + } + ) + ssc.start() + } +} diff --git a/examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala b/examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala new file mode 100644 index 0000000000..1e842d2c8e --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala @@ -0,0 +1,62 @@ +package spark.streaming.examples.twitter + +import spark.RDD +import spark.streaming._ +import spark.streaming.dstream.InputDStream +import spark.streaming.StreamingContext._ + +import twitter4j._ +import twitter4j.auth.BasicAuthorization +import collection.mutable.ArrayBuffer +import collection.JavaConversions._ + +/* A stream of Twitter statuses, potentially filtered by one or more keywords. +* +* @constructor create a new Twitter stream using the supplied username and password to authenticate. +* An optional set of string filters can be used to restrict the set of tweets. The Twitter API is +* such that this may return a sampled subset of all tweets during each interval. +*/ +class TwitterInputDStream( + @transient ssc_ : StreamingContext, + username: String, + password: String, + filters: Seq[String] + ) extends InputDStream[Status](ssc_) { + val statuses: ArrayBuffer[Status] = ArrayBuffer() + var twitterStream: TwitterStream = _ + + override def start() = { + twitterStream = new TwitterStreamFactory() + .getInstance(new BasicAuthorization(username, password)) + twitterStream.addListener(new StatusListener { + def onStatus(status: Status) = { + statuses += status + } + // Unimplemented + def onDeletionNotice(statusDeletionNotice: StatusDeletionNotice) {} + def onTrackLimitationNotice(i: Int) {} + def onScrubGeo(l: Long, l1: Long) {} + def onStallWarning(stallWarning: StallWarning) {} + def onException(e: Exception) {} + }) + + val query: FilterQuery = new FilterQuery + if (filters.size > 0) { + query.track(filters.toArray) + twitterStream.filter(query) + } else { + twitterStream.sample() + } + } + + override def stop() = { + twitterStream.shutdown() + } + + override def compute(validTime: Time): Option[RDD[Status]] = { + // Flush the current tweet buffer + val rdd = Some(ssc.sc.parallelize(statuses)) + statuses.foreach(x => statuses -= x) + rdd + } +} diff --git a/streaming/src/main/scala/spark/streaming/TwitterInputDStream.scala b/streaming/src/main/scala/spark/streaming/TwitterInputDStream.scala deleted file mode 100644 index adf1ed15c9..0000000000 --- a/streaming/src/main/scala/spark/streaming/TwitterInputDStream.scala +++ /dev/null @@ -1,58 +0,0 @@ -package spark.streaming - -import spark.RDD -import twitter4j._ -import twitter4j.auth.BasicAuthorization -import collection.mutable.ArrayBuffer -import collection.JavaConversions._ - -/* A stream of Twitter statuses, potentially filtered by one or more keywords. -* -* @constructor create a new Twitter stream using the supplied username and password to authenticate. -* An optional set of string filters can be used to restrict the set of tweets. The Twitter API is -* such that this may return a sampled subset of all tweets during each interval. -*/ -class TwitterInputDStream( - @transient ssc_ : StreamingContext, - username: String, - password: String, - filters: Seq[String] - ) extends InputDStream[Status](ssc_) { - val statuses: ArrayBuffer[Status] = ArrayBuffer() - var twitterStream: TwitterStream = _ - - override def start() = { - twitterStream = new TwitterStreamFactory() - .getInstance(new BasicAuthorization(username, password)) - twitterStream.addListener(new StatusListener { - def onStatus(status: Status) = { - statuses += status - } - // Unimplemented - def onDeletionNotice(statusDeletionNotice: StatusDeletionNotice) {} - def onTrackLimitationNotice(i: Int) {} - def onScrubGeo(l: Long, l1: Long) {} - def onStallWarning(stallWarning: StallWarning) {} - def onException(e: Exception) {} - }) - - val query: FilterQuery = new FilterQuery - if (filters.size > 0) { - query.track(filters.toArray) - twitterStream.filter(query) - } else { - twitterStream.sample() - } - } - - override def stop() = { - twitterStream.shutdown() - } - - override def compute(validTime: Time): Option[RDD[Status]] = { - // Flush the current tweet buffer - val rdd = Some(ssc.sc.parallelize(statuses)) - statuses.foreach(x => statuses -= x) - rdd - } -} diff --git a/streaming/src/main/scala/spark/streaming/examples/TwitterBasic.scala b/streaming/src/main/scala/spark/streaming/examples/TwitterBasic.scala deleted file mode 100644 index 19b3cad6ad..0000000000 --- a/streaming/src/main/scala/spark/streaming/examples/TwitterBasic.scala +++ /dev/null @@ -1,46 +0,0 @@ -package spark.streaming.examples - -import spark.streaming.StreamingContext._ -import spark.streaming.{TwitterInputDStream, Seconds, StreamingContext} - -object TwitterBasic { - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: TwitterBasic " + - " [filter1] [filter2] ... [filter n]") - System.exit(1) - } - - val Array(master, username, password) = args.slice(0, 3) - val filters = args.slice(3, args.length) - - val ssc = new StreamingContext(master, "TwitterBasic", Seconds(2)) - val stream = new TwitterInputDStream(ssc, username, password, filters) - ssc.graph.addInputStream(stream) - - val hashTags = stream.flatMap( - status => status.getText.split(" ").filter(_.startsWith("#"))) - - // Word count over hashtags - val counts = hashTags.map((_, 1)).reduceByKeyAndWindow(_ + _, Seconds(60)) - // TODO: Sorts on one node - should do with global sorting once streaming supports it - val topCounts = counts.collect().map(_.sortBy(-_._2).take(5)) - - // Print popular hashtags - topCounts.foreachRDD(rdd => { - if (rdd.count() != 0) { - val topList = rdd.take(1)(0) - println("\nPopular topics in last 60 seconds:") - topList.foreach{case (tag, count) => println("%s (%s tweets)".format(tag, count))} - } - }) - - // Print number of tweets in the window - stream.window(Seconds(60)).count().foreachRDD(rdd => - if (rdd.count() != 0) { - println("Window size: %s tweets".format(rdd.take(1)(0))) - } - ) - ssc.start() - } -} -- cgit v1.2.3 From c41042c816c2d6299aa7d93529b7c39db5d5c03a Mon Sep 17 00:00:00 2001 From: Mikhail Bautin Date: Wed, 26 Dec 2012 15:52:51 -0800 Subject: Log preferred hosts --- core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index cf4aae03a7..dda7a6c64a 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -201,7 +201,9 @@ private[spark] class TaskSetManager( val taskId = sched.newTaskId() // Figure out whether this should count as a preferred launch val preferred = isPreferredLocation(task, host) - val prefStr = if (preferred) "preferred" else "non-preferred" + val prefStr = if (preferred) "preferred" else + "non-preferred, not one of " + + task.preferredLocations.mkString(", ") logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( taskSet.id, index, taskId, slaveId, host, prefStr)) // Do various bookkeeping -- cgit v1.2.3 From 4725b0f6439337c7a0f5f6fc7034c6f6b9488ae9 Mon Sep 17 00:00:00 2001 From: Mikhail Bautin Date: Mon, 7 Jan 2013 20:07:08 -0800 Subject: Fixing if/else coding style for preferred hosts logging --- core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index dda7a6c64a..a842afcdeb 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -201,9 +201,8 @@ private[spark] class TaskSetManager( val taskId = sched.newTaskId() // Figure out whether this should count as a preferred launch val preferred = isPreferredLocation(task, host) - val prefStr = if (preferred) "preferred" else - "non-preferred, not one of " + - task.preferredLocations.mkString(", ") + val prefStr = if (preferred) "preferred" + else "non-preferred, not one of " + task.preferredLocations.mkString(", ") logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( taskSet.id, index, taskId, slaveId, host, prefStr)) // Do various bookkeeping -- cgit v1.2.3 From 6c502e37934c737494c7288e63f0ed82b21604b5 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 7 Jan 2013 21:56:24 -0800 Subject: Making the Twitter example distributed. This adds a distributed (receiver-based) implementation of the Twitter dstream. It also changes the example to perform a distributed sort rather than collecting the dataset at one node. --- .../streaming/examples/twitter/TwitterBasic.scala | 55 ++++++++++++++-------- .../examples/twitter/TwitterInputDStream.scala | 44 ++++++++++------- 2 files changed, 62 insertions(+), 37 deletions(-) diff --git a/examples/src/main/scala/spark/streaming/examples/twitter/TwitterBasic.scala b/examples/src/main/scala/spark/streaming/examples/twitter/TwitterBasic.scala index 22a927e87f..377bc0c98e 100644 --- a/examples/src/main/scala/spark/streaming/examples/twitter/TwitterBasic.scala +++ b/examples/src/main/scala/spark/streaming/examples/twitter/TwitterBasic.scala @@ -1,8 +1,15 @@ package spark.streaming.examples.twitter -import spark.streaming.{Seconds, StreamingContext} import spark.streaming.StreamingContext._ +import spark.streaming.{Seconds, StreamingContext} +import spark.SparkContext._ +import spark.storage.StorageLevel +/** + * Calculates popular hashtags (topics) over sliding 10 and 60 second windows from a Twitter + * stream. The stream is instantiated with credentials and optionally filters supplied by the + * command line arguments. + */ object TwitterBasic { def main(args: Array[String]) { if (args.length < 3) { @@ -15,29 +22,39 @@ object TwitterBasic { val filters = args.slice(3, args.length) val ssc = new StreamingContext(master, "TwitterBasic", Seconds(2)) - val stream = new TwitterInputDStream(ssc, username, password, filters) - ssc.graph.addInputStream(stream) - - val hashTags = stream.flatMap( - status => status.getText.split(" ").filter(_.startsWith("#"))) - - // Word count over hashtags - val counts = hashTags.map((_, 1)).reduceByKeyAndWindow(_ + _, Seconds(60)) - // TODO: Sorts on one node - should do with global sorting once streaming supports it - counts.foreach(rdd => { - val topList = rdd.collect().sortBy(-_._2).take(5) - if (!topList.isEmpty) { - println("\nPopular topics in last 60 seconds:") - topList.foreach{case (tag, count) => println("%s (%s tweets)".format(tag, count))} + val stream = new TwitterInputDStream(ssc, username, password, filters, + StorageLevel.MEMORY_ONLY_SER) + ssc.registerInputStream(stream) + + val hashTags = stream.flatMap(status => status.getText.split(" ").filter(_.startsWith("#"))) + + val topCounts60 = hashTags.map((_, 1)).reduceByKeyAndWindow(_ + _, Seconds(60)) + .map{case (topic, count) => (count, topic)} + .transform(_.sortByKey(false)) + + val topCounts10 = hashTags.map((_, 1)).reduceByKeyAndWindow(_ + _, Seconds(10)) + .map{case (topic, count) => (count, topic)} + .transform(_.sortByKey(false)) + + + // Print popular hashtags + topCounts60.foreach(rdd => { + if (rdd.count() != 0) { + val topList = rdd.take(5) + println("\nPopular topics in last 60 seconds (%s total):".format(rdd.count())) + topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} } }) - // Print number of tweets in the window - stream.window(Seconds(60)).count().foreach(rdd => + topCounts10.foreach(rdd => { if (rdd.count() != 0) { - println("Window size: %s tweets".format(rdd.take(1)(0))) + val topList = rdd.take(5) + println("\nPopular topics in last 10 seconds (%s total):".format(rdd.count())) + topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} } - ) + }) + ssc.start() } + } diff --git a/examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala b/examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala index 1e842d2c8e..c7e4855f3b 100644 --- a/examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala +++ b/examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala @@ -1,13 +1,11 @@ package spark.streaming.examples.twitter -import spark.RDD +import spark._ import spark.streaming._ -import spark.streaming.dstream.InputDStream -import spark.streaming.StreamingContext._ - +import dstream.{NetworkReceiver, NetworkInputDStream} +import storage.StorageLevel import twitter4j._ import twitter4j.auth.BasicAuthorization -import collection.mutable.ArrayBuffer import collection.JavaConversions._ /* A stream of Twitter statuses, potentially filtered by one or more keywords. @@ -20,17 +18,31 @@ class TwitterInputDStream( @transient ssc_ : StreamingContext, username: String, password: String, - filters: Seq[String] - ) extends InputDStream[Status](ssc_) { - val statuses: ArrayBuffer[Status] = ArrayBuffer() + filters: Seq[String], + storageLevel: StorageLevel + ) extends NetworkInputDStream[Status](ssc_) { + + override def createReceiver(): NetworkReceiver[Status] = { + new TwitterReceiver(id, username, password, filters, storageLevel) + } +} + +class TwitterReceiver(streamId: Int, + username: String, + password: String, + filters: Seq[String], + storageLevel: StorageLevel + ) extends NetworkReceiver[Status](streamId) { var twitterStream: TwitterStream = _ + lazy val blockGenerator = new BlockGenerator(storageLevel) - override def start() = { + protected override def onStart() { + blockGenerator.start() twitterStream = new TwitterStreamFactory() .getInstance(new BasicAuthorization(username, password)) twitterStream.addListener(new StatusListener { def onStatus(status: Status) = { - statuses += status + blockGenerator += status } // Unimplemented def onDeletionNotice(statusDeletionNotice: StatusDeletionNotice) {} @@ -47,16 +59,12 @@ class TwitterInputDStream( } else { twitterStream.sample() } + logInfo("Twitter receiver started") } - override def stop() = { + protected override def onStop() { + blockGenerator.stop() twitterStream.shutdown() - } - - override def compute(validTime: Time): Option[RDD[Status]] = { - // Flush the current tweet buffer - val rdd = Some(ssc.sc.parallelize(statuses)) - statuses.foreach(x => statuses -= x) - rdd + logInfo("Twitter receiver stopped") } } -- cgit v1.2.3 From f7adb382ace7f54c5093bf90574b3f9dd0d35534 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Tue, 8 Jan 2013 03:19:43 -0800 Subject: Activate hadoop1 if property hadoop is missing. hadoop2 can be activated now by using -Dhadoop -Phadoop2. --- bagel/pom.xml | 4 +++- core/pom.xml | 4 +++- examples/pom.xml | 4 +++- pom.xml | 4 +++- repl-bin/pom.xml | 4 +++- repl/pom.xml | 4 +++- 6 files changed, 18 insertions(+), 6 deletions(-) diff --git a/bagel/pom.xml b/bagel/pom.xml index 4ca643bbb7..85b2077026 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -46,7 +46,9 @@ hadoop1 - true + + !hadoop + diff --git a/core/pom.xml b/core/pom.xml index cd789a7db0..005d8fe498 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -160,7 +160,9 @@ hadoop1 - true + + !hadoop + diff --git a/examples/pom.xml b/examples/pom.xml index 9e638c8284..3f738a3f8c 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -46,7 +46,9 @@ hadoop1 - true + + !hadoop + diff --git a/pom.xml b/pom.xml index 0e2d93c170..ea5b9c9d05 100644 --- a/pom.xml +++ b/pom.xml @@ -482,7 +482,9 @@ hadoop1 - true + + !hadoop + 1 diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index aa9895eda2..fecb01f3cd 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -71,7 +71,9 @@ hadoop1 - true + + !hadoop + hadoop1 diff --git a/repl/pom.xml b/repl/pom.xml index ba7a051310..04b2c35beb 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -73,7 +73,9 @@ hadoop1 - true + + !hadoop + hadoop1 -- cgit v1.2.3 From e4cb72da8a5428c6b9097e92ddbdf4ceee087b85 Mon Sep 17 00:00:00 2001 From: shane-huang Date: Tue, 8 Jan 2013 22:40:58 +0800 Subject: Fix an issue in ConnectionManager where sendingMessage may create too many unnecessary SendingConnections. --- core/src/main/scala/spark/network/Connection.scala | 7 +++++-- .../main/scala/spark/network/ConnectionManager.scala | 17 +++++++++-------- .../scala/spark/network/ConnectionManagerTest.scala | 18 +++++++++--------- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index 80262ab7b4..95096fd0ba 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -135,8 +135,11 @@ extends Connection(SocketChannel.open, selector_) { val chunk = message.getChunkForSending(defaultChunkSize) if (chunk.isDefined) { messages += message // this is probably incorrect, it wont work as fifo - if (!message.started) logDebug("Starting to send [" + message + "]") - message.started = true + if (!message.started) { + logDebug("Starting to send [" + message + "]") + message.started = true + message.startTime = System.currentTimeMillis + } return chunk } else { /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/ diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 642fa4b525..e7bd2d3bbd 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -43,12 +43,12 @@ private[spark] class ConnectionManager(port: Int) extends Logging { } val selector = SelectorProvider.provider.openSelector() - val handleMessageExecutor = Executors.newFixedThreadPool(4) + val handleMessageExecutor = Executors.newFixedThreadPool(20) val serverChannel = ServerSocketChannel.open() val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] val messageStatuses = new HashMap[Int, MessageStatus] - val connectionRequests = new SynchronizedQueue[SendingConnection] + val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] val sendMessageRequests = new Queue[(Message, SendingConnection)] @@ -78,11 +78,12 @@ private[spark] class ConnectionManager(port: Int) extends Logging { def run() { try { - while(!selectorThread.isInterrupted) { - while(!connectionRequests.isEmpty) { - val sendingConnection = connectionRequests.dequeue + while(!selectorThread.isInterrupted) { + for( (connectionManagerId, sendingConnection) <- connectionRequests) { + //val sendingConnection = connectionRequests.dequeue sendingConnection.connect() addConnection(sendingConnection) + connectionRequests -= connectionManagerId } sendMessageRequests.synchronized { while(!sendMessageRequests.isEmpty) { @@ -300,8 +301,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { def startNewConnection(): SendingConnection = { val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) - val newConnection = new SendingConnection(inetSocketAddress, selector) - connectionRequests += newConnection + val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, new SendingConnection(inetSocketAddress, selector)) newConnection } val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress) @@ -465,7 +465,7 @@ private[spark] object ConnectionManager { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) }).foreach(f => { - val g = Await.result(f, 1 second) + val g = Await.result(f, 10 second) if (!g.isDefined) println("Failed") }) val finishTime = System.currentTimeMillis @@ -473,6 +473,7 @@ private[spark] object ConnectionManager { val mb = size * count / 1024.0 / 1024.0 val ms = finishTime - startTime val tput = mb * 1000.0 / ms + println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") println("--------------------------") println() } diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala index 47ceaf3c07..0e79c518e0 100644 --- a/core/src/main/scala/spark/network/ConnectionManagerTest.scala +++ b/core/src/main/scala/spark/network/ConnectionManagerTest.scala @@ -13,8 +13,8 @@ import akka.util.duration._ private[spark] object ConnectionManagerTest extends Logging{ def main(args: Array[String]) { - if (args.length < 2) { - println("Usage: ConnectionManagerTest ") + if (args.length < 5) { + println("Usage: ConnectionManagerTest ") System.exit(1) } @@ -29,16 +29,16 @@ private[spark] object ConnectionManagerTest extends Logging{ /*println("Slaves")*/ /*slaves.foreach(println)*/ - - val slaveConnManagerIds = sc.parallelize(0 until slaves.length, slaves.length).map( + val tasknum = args(2).toInt + val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map( i => SparkEnv.get.connectionManager.id).collect() println("\nSlave ConnectionManagerIds") slaveConnManagerIds.foreach(println) println - val count = 10 + val count = args(4).toInt (0 until count).foreach(i => { - val resultStrs = sc.parallelize(0 until slaves.length, slaves.length).map(i => { + val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => { val connManager = SparkEnv.get.connectionManager val thisConnManagerId = connManager.id connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { @@ -46,7 +46,7 @@ private[spark] object ConnectionManagerTest extends Logging{ None }) - val size = 100 * 1024 * 1024 + val size = (args(3).toInt) * 1024 * 1024 val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) buffer.flip @@ -56,13 +56,13 @@ private[spark] object ConnectionManagerTest extends Logging{ logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) }) - val results = futures.map(f => Await.result(f, 1.second)) + val results = futures.map(f => Await.result(f, 999.second)) val finishTime = System.currentTimeMillis Thread.sleep(5000) val mb = size * results.size / 1024.0 / 1024.0 val ms = finishTime - startTime - val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s" + val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s" logInfo(resultStr) resultStr }).collect() -- cgit v1.2.3 From 8ac0f35be42765fcd6f02dcf0f070f2ef2377a85 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 8 Jan 2013 09:57:45 -0600 Subject: Add JavaRDDLike.keyBy. --- core/src/main/scala/spark/api/java/JavaRDDLike.scala | 8 ++++++++ core/src/test/scala/spark/JavaAPISuite.java | 12 ++++++++++++ 2 files changed, 20 insertions(+) diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 81d3a94466..d15f6dd02f 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -298,4 +298,12 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Save this RDD as a SequenceFile of serialized objects. */ def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path) + + /** + * Creates tuples of the elements in this RDD by applying `f`. + */ + def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = { + implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] + JavaPairRDD.fromRDD(rdd.keyBy(f)) + } } diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 0817d1146c..c61913fc82 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -629,4 +629,16 @@ public class JavaAPISuite implements Serializable { floatAccum.setValue(5.0f); Assert.assertEquals((Float) 5.0f, floatAccum.value()); } + + @Test + public void keyBy() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2)); + List> s = rdd.keyBy(new Function() { + public String call(Integer t) throws Exception { + return t.toString(); + } + }).collect(); + Assert.assertEquals(new Tuple2("1", 1), s.get(0)); + Assert.assertEquals(new Tuple2("2", 2), s.get(1)); + } } -- cgit v1.2.3 From c3f1675f9c4a1be9eebf9512795abc968ac29ba2 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 8 Jan 2013 14:44:33 -0600 Subject: Retrieve jars to a flat directory so * can be used for the classpath. --- project/SparkBuild.scala | 1 + run | 12 +++--------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 7c7c33131a..518c4130f0 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -38,6 +38,7 @@ object SparkBuild extends Build { scalacOptions := Seq(/*"-deprecation",*/ "-unchecked", "-optimize"), // -deprecation is too noisy due to usage of old Hadoop API, enable it once that's no longer an issue unmanagedJars in Compile <<= baseDirectory map { base => (base / "lib" ** "*.jar").classpath }, retrieveManaged := true, + retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", transitiveClassifiers in Scope.GlobalScope := Seq("sources"), testListeners <<= target.map(t => Seq(new eu.henkelmann.sbt.JUnitXmlTestsListener(t.getAbsolutePath))), diff --git a/run b/run index 1528f83534..6cfe9631af 100755 --- a/run +++ b/run @@ -75,16 +75,10 @@ CLASSPATH+=":$CORE_DIR/src/main/resources" CLASSPATH+=":$REPL_DIR/target/scala-$SCALA_VERSION/classes" CLASSPATH+=":$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes" if [ -e "$FWDIR/lib_managed" ]; then - for jar in `find "$FWDIR/lib_managed/jars" -name '*jar'`; do - CLASSPATH+=":$jar" - done - for jar in `find "$FWDIR/lib_managed/bundles" -name '*jar'`; do - CLASSPATH+=":$jar" - done + CLASSPATH+=":$FWDIR/lib_managed/jars/*" + CLASSPATH+=":$FWDIR/lib_managed/bundles/*" fi -for jar in `find "$REPL_DIR/lib" -name '*jar'`; do - CLASSPATH+=":$jar" -done +CLASSPATH+=":$REPL_DIR/lib/*" for jar in `find "$REPL_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do CLASSPATH+=":$jar" done -- cgit v1.2.3 From b57dd0f16024a82dfc223e69528b9908b931f068 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 8 Jan 2013 16:04:41 -0800 Subject: Add mapPartitionsWithSplit() to PySpark. --- .../main/scala/spark/api/python/PythonRDD.scala | 5 ++++ docs/python-programming-guide.md | 1 - python/pyspark/rdd.py | 33 ++++++++++++++-------- python/pyspark/worker.py | 4 ++- 4 files changed, 30 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 79d824d494..f431ef28d3 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -65,6 +65,9 @@ private[spark] class PythonRDD[T: ClassManifest]( SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) val dOut = new DataOutputStream(proc.getOutputStream) + // Split index + dOut.writeInt(split.index) + // Broadcast variables dOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { dOut.writeLong(broadcast.id) @@ -72,10 +75,12 @@ private[spark] class PythonRDD[T: ClassManifest]( dOut.write(broadcast.value) dOut.flush() } + // Serialized user code for (elem <- command) { out.println(elem) } out.flush() + // Data values for (elem <- parent.iterator(split, context)) { PythonRDD.writeAsPickle(elem, dOut) } diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index d963551296..78ef310a00 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -19,7 +19,6 @@ There are a few key differences between the Python and Scala APIs: - Accumulators - Special functions on RDDs of doubles, such as `mean` and `stdev` - `lookup` - - `mapPartitionsWithSplit` - `persist` at storage levels other than `MEMORY_ONLY` - `sample` - `sort` diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4ba417b2a2..1d36da42b0 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -55,7 +55,7 @@ class RDD(object): """ Return a new RDD containing the distinct elements in this RDD. """ - def func(iterator): return imap(f, iterator) + def func(split, iterator): return imap(f, iterator) return PipelinedRDD(self, func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -69,8 +69,8 @@ class RDD(object): >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect()) [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ - def func(iterator): return chain.from_iterable(imap(f, iterator)) - return self.mapPartitions(func, preservesPartitioning) + def func(s, iterator): return chain.from_iterable(imap(f, iterator)) + return self.mapPartitionsWithSplit(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): """ @@ -81,9 +81,20 @@ class RDD(object): >>> rdd.mapPartitions(f).collect() [3, 7] """ - return PipelinedRDD(self, f, preservesPartitioning) + def func(s, iterator): return f(iterator) + return self.mapPartitionsWithSplit(func) + + def mapPartitionsWithSplit(self, f, preservesPartitioning=False): + """ + Return a new RDD by applying a function to each partition of this RDD, + while tracking the index of the original partition. - # TODO: mapPartitionsWithSplit + >>> rdd = sc.parallelize([1, 2, 3, 4], 4) + >>> def f(splitIndex, iterator): yield splitIndex + >>> rdd.mapPartitionsWithSplit(f).sum() + 6 + """ + return PipelinedRDD(self, f, preservesPartitioning) def filter(self, f): """ @@ -362,7 +373,7 @@ class RDD(object): >>> ''.join(input(glob(tempFile.name + "/part-0000*"))) '0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n' """ - def func(iterator): + def func(split, iterator): return (str(x).encode("utf-8") for x in iterator) keyed = PipelinedRDD(self, func) keyed._bypass_serializer = True @@ -500,7 +511,7 @@ class RDD(object): # Transferring O(n) objects to Java is too expensive. Instead, we'll # form the hash buckets in Python, transferring O(numSplits) objects # to Java. Each object is a (splitNumber, [objects]) pair. - def add_shuffle_key(iterator): + def add_shuffle_key(split, iterator): buckets = defaultdict(list) for (k, v) in iterator: buckets[hashFunc(k) % numSplits].append((k, v)) @@ -653,8 +664,8 @@ class PipelinedRDD(RDD): def __init__(self, prev, func, preservesPartitioning=False): if isinstance(prev, PipelinedRDD) and not prev.is_cached: prev_func = prev.func - def pipeline_func(iterator): - return func(prev_func(iterator)) + def pipeline_func(split, iterator): + return func(split, prev_func(split, iterator)) self.func = pipeline_func self.preservesPartitioning = \ prev.preservesPartitioning and preservesPartitioning @@ -677,8 +688,8 @@ class PipelinedRDD(RDD): if not self._bypass_serializer and self.ctx.batchSize != 1: oldfunc = self.func batchSize = self.ctx.batchSize - def batched_func(iterator): - return batched(oldfunc(iterator), batchSize) + def batched_func(split, iterator): + return batched(oldfunc(split, iterator), batchSize) func = batched_func cmds = [func, self._bypass_serializer] pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 9f6b507dbd..3d792bbaa2 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -21,6 +21,7 @@ def load_obj(): def main(): + split_index = read_int(sys.stdin) num_broadcast_variables = read_int(sys.stdin) for _ in range(num_broadcast_variables): bid = read_long(sys.stdin) @@ -32,7 +33,8 @@ def main(): dumps = lambda x: x else: dumps = dump_pickle - for obj in func(read_from_pickle_file(sys.stdin)): + iterator = read_from_pickle_file(sys.stdin) + for obj in func(split_index, iterator): write_with_length(dumps(obj), old_stdout) -- cgit v1.2.3 From 9cc764f52323baa3a218ce9e301d3cc98f1e8b20 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Tue, 8 Jan 2013 22:29:57 -0800 Subject: Code style --- core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index a842afcdeb..a089b71644 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -201,8 +201,11 @@ private[spark] class TaskSetManager( val taskId = sched.newTaskId() // Figure out whether this should count as a preferred launch val preferred = isPreferredLocation(task, host) - val prefStr = if (preferred) "preferred" - else "non-preferred, not one of " + task.preferredLocations.mkString(", ") + val prefStr = if (preferred) { + "preferred" + } else { + "non-preferred, not one of " + task.preferredLocations.mkString(", ") + } logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( taskSet.id, index, taskId, slaveId, host, prefStr)) // Do various bookkeeping -- cgit v1.2.3 From 6e8c8f61c478ec5829677a38a624f17ac9609f74 Mon Sep 17 00:00:00 2001 From: Tyson Date: Wed, 9 Jan 2013 10:35:23 -0500 Subject: Added the spray implicit marshaller library Added the io.spray JSON library --- project/SparkBuild.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 2f67bb9921..f2b79d9ed8 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -133,6 +133,8 @@ object SparkBuild extends Build { "colt" % "colt" % "1.2.0", "cc.spray" % "spray-can" % "1.0-M2.1", "cc.spray" % "spray-server" % "1.0-M2.1", + "cc.spray" %% "spray-json" % "1.1.1", + "io.spray" %% "spray-json" % "1.2.3", "org.apache.mesos" % "mesos" % "0.9.0-incubating" ) ++ (if (HADOOP_MAJOR_VERSION == "2") Some("org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION) else None).toSeq, unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / ("src/hadoop" + HADOOP_MAJOR_VERSION + "/scala") } -- cgit v1.2.3 From 269fe018c73a0d4e12a3c881dbd3bd807e504891 Mon Sep 17 00:00:00 2001 From: Tyson Date: Wed, 9 Jan 2013 10:35:59 -0500 Subject: JSON object definitions --- .../src/main/scala/spark/deploy/JsonProtocol.scala | 59 ++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 core/src/main/scala/spark/deploy/JsonProtocol.scala diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala new file mode 100644 index 0000000000..dc7da85f9c --- /dev/null +++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala @@ -0,0 +1,59 @@ +package spark.deploy + +import master.{JobInfo, WorkerInfo} +import spray.json._ + +/** + * spray-json helper class containing implicit conversion to json for marshalling responses + */ +private[spark] object JsonProtocol extends DefaultJsonProtocol { + import cc.spray.json._ + + implicit object WorkerInfoJsonFormat extends RootJsonWriter[WorkerInfo] { + def write(obj: WorkerInfo) = JsObject( + "id" -> JsString(obj.id), + "host" -> JsString(obj.host), + "webuiaddress" -> JsString(obj.webUiAddress), + "cores" -> JsNumber(obj.cores), + "coresused" -> JsNumber(obj.coresUsed), + "memory" -> JsNumber(obj.memory), + "memoryused" -> JsNumber(obj.memoryUsed) + ) + } + + implicit object JobInfoJsonFormat extends RootJsonWriter[JobInfo] { + def write(obj: JobInfo) = JsObject( + "starttime" -> JsNumber(obj.startTime), + "id" -> JsString(obj.id), + "name" -> JsString(obj.desc.name), + "cores" -> JsNumber(obj.desc.cores), + "user" -> JsString(obj.desc.user), + "memoryperslave" -> JsNumber(obj.desc.memoryPerSlave), + "submitdate" -> JsString(obj.submitDate.toString)) + } + + implicit object MasterStateJsonFormat extends RootJsonWriter[MasterState] { + def write(obj: MasterState) = JsObject( + "url" -> JsString("spark://" + obj.uri), + "workers" -> JsArray(obj.workers.toList.map(_.toJson)), + "cores" -> JsNumber(obj.workers.map(_.cores).sum), + "coresused" -> JsNumber(obj.workers.map(_.coresUsed).sum), + "memory" -> JsNumber(obj.workers.map(_.memory).sum), + "memoryused" -> JsNumber(obj.workers.map(_.memoryUsed).sum), + "activejobs" -> JsArray(obj.activeJobs.toList.map(_.toJson)), + "completedjobs" -> JsArray(obj.completedJobs.toList.map(_.toJson)) + ) + } + + implicit object WorkerStateJsonFormat extends RootJsonWriter[WorkerState] { + def write(obj: WorkerState) = JsObject( + "id" -> JsString(obj.workerId), + "masterurl" -> JsString(obj.masterUrl), + "masterwebuiurl" -> JsString(obj.masterWebUiUrl), + "cores" -> JsNumber(obj.cores), + "coresused" -> JsNumber(obj.coresUsed), + "memory" -> JsNumber(obj.memory), + "memoryused" -> JsNumber(obj.memoryUsed) + ) + } +} -- cgit v1.2.3 From 0da2ff102e1e8ac50059252a153a1b9b3e74b6b8 Mon Sep 17 00:00:00 2001 From: Tyson Date: Wed, 9 Jan 2013 10:36:56 -0500 Subject: Added url query parameter json and handler --- .../main/scala/spark/deploy/master/MasterWebUI.scala | 19 ++++++++++++++----- .../main/scala/spark/deploy/worker/WorkerWebUI.scala | 20 +++++++++++++++----- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index 3cdd3721f5..dfec1d1dc5 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -9,6 +9,9 @@ import cc.spray.Directives import cc.spray.directives._ import cc.spray.typeconversion.TwirlSupport._ import spark.deploy._ +import cc.spray.http.MediaTypes +import JsonProtocol._ +import cc.spray.typeconversion.SprayJsonSupport._ private[spark] class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives { @@ -19,13 +22,19 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct val handler = { get { - path("") { - completeWith { + (path("") & parameters('json ?)) { + case Some(js) => val future = master ? RequestMasterState - future.map { - masterState => spark.deploy.master.html.index.render(masterState.asInstanceOf[MasterState]) + respondWithMediaType(MediaTypes.`application/json`) { ctx => + ctx.complete(future.mapTo[MasterState]) + } + case None => + completeWith { + val future = master ? RequestMasterState + future.map { + masterState => spark.deploy.master.html.index.render(masterState.asInstanceOf[MasterState]) + } } - } } ~ path("job") { parameter("jobId") { jobId => diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala index d06f4884ee..a168f54ca0 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala @@ -7,7 +7,10 @@ import akka.util.Timeout import akka.util.duration._ import cc.spray.Directives import cc.spray.typeconversion.TwirlSupport._ -import spark.deploy.{WorkerState, RequestWorkerState} +import spark.deploy.{JsonProtocol, WorkerState, RequestWorkerState} +import cc.spray.http.MediaTypes +import JsonProtocol._ +import cc.spray.typeconversion.SprayJsonSupport._ private[spark] class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives { @@ -18,13 +21,20 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct val handler = { get { - path("") { - completeWith{ + (path("") & parameters('json ?)) { + case Some(js) => { val future = worker ? RequestWorkerState - future.map { workerState => - spark.deploy.worker.html.index(workerState.asInstanceOf[WorkerState]) + respondWithMediaType(MediaTypes.`application/json`) { ctx => + ctx.complete(future.mapTo[WorkerState]) } } + case None => + completeWith{ + val future = worker ? RequestWorkerState + future.map { workerState => + spark.deploy.worker.html.index(workerState.asInstanceOf[WorkerState]) + } + } } ~ path("log") { parameters("jobId", "executorId", "logType") { (jobId, executorId, logType) => -- cgit v1.2.3 From bf9d9946f97782c9212420123b4a042918d7df5e Mon Sep 17 00:00:00 2001 From: Tyson Date: Wed, 9 Jan 2013 11:29:22 -0500 Subject: Query parameter reformatted to be more extensible and routing more robust --- core/src/main/scala/spark/deploy/master/MasterWebUI.scala | 6 +++--- core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index dfec1d1dc5..a96b55d6f3 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -22,13 +22,13 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct val handler = { get { - (path("") & parameters('json ?)) { - case Some(js) => + (path("") & parameters('format ?)) { + case Some(js) if js.equalsIgnoreCase("json") => val future = master ? RequestMasterState respondWithMediaType(MediaTypes.`application/json`) { ctx => ctx.complete(future.mapTo[MasterState]) } - case None => + case _ => completeWith { val future = master ? RequestMasterState future.map { diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala index a168f54ca0..84b6c16bd6 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala @@ -21,14 +21,14 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct val handler = { get { - (path("") & parameters('json ?)) { - case Some(js) => { + (path("") & parameters('format ?)) { + case Some(js) if js.equalsIgnoreCase("json") => { val future = worker ? RequestWorkerState respondWithMediaType(MediaTypes.`application/json`) { ctx => ctx.complete(future.mapTo[WorkerState]) } } - case None => + case _ => completeWith{ val future = worker ? RequestWorkerState future.map { workerState => -- cgit v1.2.3 From 549ee388a125ac7014ae3dadfb16c582e250c654 Mon Sep 17 00:00:00 2001 From: Tyson Date: Wed, 9 Jan 2013 15:12:23 -0500 Subject: Removed io.spray spray-json dependency as it is not needed. --- core/src/main/scala/spark/deploy/JsonProtocol.scala | 4 +--- project/SparkBuild.scala | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala index dc7da85f9c..f14f804b3a 100644 --- a/core/src/main/scala/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala @@ -1,14 +1,12 @@ package spark.deploy import master.{JobInfo, WorkerInfo} -import spray.json._ +import cc.spray.json._ /** * spray-json helper class containing implicit conversion to json for marshalling responses */ private[spark] object JsonProtocol extends DefaultJsonProtocol { - import cc.spray.json._ - implicit object WorkerInfoJsonFormat extends RootJsonWriter[WorkerInfo] { def write(obj: WorkerInfo) = JsObject( "id" -> JsString(obj.id), diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f2b79d9ed8..c63efbdd2a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -134,7 +134,6 @@ object SparkBuild extends Build { "cc.spray" % "spray-can" % "1.0-M2.1", "cc.spray" % "spray-server" % "1.0-M2.1", "cc.spray" %% "spray-json" % "1.1.1", - "io.spray" %% "spray-json" % "1.2.3", "org.apache.mesos" % "mesos" % "0.9.0-incubating" ) ++ (if (HADOOP_MAJOR_VERSION == "2") Some("org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION) else None).toSeq, unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / ("src/hadoop" + HADOOP_MAJOR_VERSION + "/scala") } -- cgit v1.2.3 From 156e8b47ef24cd1a54ee9f1141a91c20e26ac037 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 9 Jan 2013 12:42:10 -0800 Subject: Split Time to Time (absolute instant of time) and Duration (duration of time). --- .../main/scala/spark/streaming/Checkpoint.scala | 2 +- .../src/main/scala/spark/streaming/DStream.scala | 26 ++++----- .../main/scala/spark/streaming/DStreamGraph.scala | 8 +-- .../src/main/scala/spark/streaming/Duration.scala | 62 ++++++++++++++++++++++ .../src/main/scala/spark/streaming/Interval.scala | 30 ++++------- .../spark/streaming/PairDStreamFunctions.scala | 40 +++++++------- .../src/main/scala/spark/streaming/Scheduler.scala | 5 +- .../scala/spark/streaming/StreamingContext.scala | 12 ++--- .../src/main/scala/spark/streaming/Time.scala | 55 +++++-------------- .../spark/streaming/dstream/CoGroupedDStream.scala | 4 +- .../spark/streaming/dstream/FilteredDStream.scala | 4 +- .../streaming/dstream/FlatMapValuedDStream.scala | 4 +- .../streaming/dstream/FlatMappedDStream.scala | 4 +- .../spark/streaming/dstream/ForEachDStream.scala | 4 +- .../spark/streaming/dstream/GlommedDStream.scala | 4 +- .../spark/streaming/dstream/InputDStream.scala | 4 +- .../streaming/dstream/MapPartitionedDStream.scala | 4 +- .../spark/streaming/dstream/MapValuedDStream.scala | 4 +- .../spark/streaming/dstream/MappedDStream.scala | 4 +- .../streaming/dstream/ReducedWindowedDStream.scala | 16 +++--- .../spark/streaming/dstream/ShuffledDStream.scala | 4 +- .../spark/streaming/dstream/StateDStream.scala | 4 +- .../streaming/dstream/TransformedDStream.scala | 4 +- .../spark/streaming/dstream/UnionDStream.scala | 4 +- .../spark/streaming/dstream/WindowedDStream.scala | 14 ++--- 25 files changed, 174 insertions(+), 152 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/Duration.scala diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index 11a7232d7b..a9c6e65d62 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -17,7 +17,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val jars = ssc.sc.jars val graph = ssc.graph val checkpointDir = ssc.checkpointDir - val checkpointInterval = ssc.checkpointInterval + val checkpointInterval: Duration = ssc.checkpointInterval def validate() { assert(master != null, "Checkpoint.master is null") diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index beba9cfd4f..7611598fde 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -2,7 +2,7 @@ package spark.streaming import spark.streaming.dstream._ import StreamingContext._ -import Time._ +//import Time._ import spark.{RDD, Logging} import spark.storage.StorageLevel @@ -47,7 +47,7 @@ abstract class DStream[T: ClassManifest] ( // ======================================================================= /** Time interval after which the DStream generates a RDD */ - def slideTime: Time + def slideTime: Duration /** List of parent DStreams on which this DStream depends on */ def dependencies: List[DStream[_]] @@ -67,14 +67,14 @@ abstract class DStream[T: ClassManifest] ( protected[streaming] var zeroTime: Time = null // Duration for which the DStream will remember each RDD created - protected[streaming] var rememberDuration: Time = null + protected[streaming] var rememberDuration: Duration = null // Storage level of the RDDs in the stream protected[streaming] var storageLevel: StorageLevel = StorageLevel.NONE // Checkpoint details protected[streaming] val mustCheckpoint = false - protected[streaming] var checkpointInterval: Time = null + protected[streaming] var checkpointInterval: Duration = null protected[streaming] var checkpointData = new DStreamCheckpointData(HashMap[Time, Any]()) // Reference to whole DStream graph @@ -108,7 +108,7 @@ abstract class DStream[T: ClassManifest] ( * Enable periodic checkpointing of RDDs of this DStream * @param interval Time interval after which generated RDD will be checkpointed */ - def checkpoint(interval: Time): DStream[T] = { + def checkpoint(interval: Duration): DStream[T] = { if (isInitialized) { throw new UnsupportedOperationException( "Cannot change checkpoint interval of an DStream after streaming context has started") @@ -224,7 +224,7 @@ abstract class DStream[T: ClassManifest] ( dependencies.foreach(_.setGraph(graph)) } - protected[streaming] def remember(duration: Time) { + protected[streaming] def remember(duration: Duration) { if (duration != null && duration > rememberDuration) { rememberDuration = duration logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this) @@ -531,7 +531,7 @@ abstract class DStream[T: ClassManifest] ( * @param windowTime width of the window; must be a multiple of this DStream's interval. * @return */ - def window(windowTime: Time): DStream[T] = window(windowTime, this.slideTime) + def window(windowTime: Duration): DStream[T] = window(windowTime, this.slideTime) /** * Return a new DStream which is computed based on windowed batches of this DStream. @@ -541,7 +541,7 @@ abstract class DStream[T: ClassManifest] ( * the new DStream will generate RDDs); must be a multiple of this * DStream's interval */ - def window(windowTime: Time, slideTime: Time): DStream[T] = { + def window(windowTime: Duration, slideTime: Duration): DStream[T] = { new WindowedDStream(this, windowTime, slideTime) } @@ -550,22 +550,22 @@ abstract class DStream[T: ClassManifest] ( * This is equivalent to window(batchTime, batchTime). * @param batchTime tumbling window duration; must be a multiple of this DStream's interval */ - def tumble(batchTime: Time): DStream[T] = window(batchTime, batchTime) + def tumble(batchTime: Duration): DStream[T] = window(batchTime, batchTime) /** * Returns a new DStream in which each RDD has a single element generated by reducing all * elements in a window over this DStream. windowTime and slideTime are as defined in the * window() operation. This is equivalent to window(windowTime, slideTime).reduce(reduceFunc) */ - def reduceByWindow(reduceFunc: (T, T) => T, windowTime: Time, slideTime: Time): DStream[T] = { + def reduceByWindow(reduceFunc: (T, T) => T, windowTime: Duration, slideTime: Duration): DStream[T] = { this.window(windowTime, slideTime).reduce(reduceFunc) } def reduceByWindow( reduceFunc: (T, T) => T, invReduceFunc: (T, T) => T, - windowTime: Time, - slideTime: Time + windowTime: Duration, + slideTime: Duration ): DStream[T] = { this.map(x => (1, x)) .reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowTime, slideTime, 1) @@ -577,7 +577,7 @@ abstract class DStream[T: ClassManifest] ( * of elements in a window over this DStream. windowTime and slideTime are as defined in the * window() operation. This is equivalent to window(windowTime, slideTime).count() */ - def countByWindow(windowTime: Time, slideTime: Time): DStream[Int] = { + def countByWindow(windowTime: Duration, slideTime: Duration): DStream[Int] = { this.map(_ => 1).reduceByWindow(_ + _, _ - _, windowTime, slideTime) } diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index c72429370e..bc4a40d7bc 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -12,8 +12,8 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { private val outputStreams = new ArrayBuffer[DStream[_]]() private[streaming] var zeroTime: Time = null - private[streaming] var batchDuration: Time = null - private[streaming] var rememberDuration: Time = null + private[streaming] var batchDuration: Duration = null + private[streaming] var rememberDuration: Duration = null private[streaming] var checkpointInProgress = false private[streaming] def start(time: Time) { @@ -41,7 +41,7 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { } } - private[streaming] def setBatchDuration(duration: Time) { + private[streaming] def setBatchDuration(duration: Duration) { this.synchronized { if (batchDuration != null) { throw new Exception("Batch duration already set as " + batchDuration + @@ -51,7 +51,7 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { batchDuration = duration } - private[streaming] def remember(duration: Time) { + private[streaming] def remember(duration: Duration) { this.synchronized { if (rememberDuration != null) { throw new Exception("Batch duration already set as " + batchDuration + diff --git a/streaming/src/main/scala/spark/streaming/Duration.scala b/streaming/src/main/scala/spark/streaming/Duration.scala new file mode 100644 index 0000000000..d2728d9dca --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/Duration.scala @@ -0,0 +1,62 @@ +package spark.streaming + +class Duration (private val millis: Long) { + + def < (that: Duration): Boolean = (this.millis < that.millis) + + def <= (that: Duration): Boolean = (this.millis <= that.millis) + + def > (that: Duration): Boolean = (this.millis > that.millis) + + def >= (that: Duration): Boolean = (this.millis >= that.millis) + + def + (that: Duration): Duration = new Duration(millis + that.millis) + + def - (that: Duration): Duration = new Duration(millis - that.millis) + + def * (times: Int): Duration = new Duration(millis * times) + + def / (that: Duration): Long = millis / that.millis + + def isMultipleOf(that: Duration): Boolean = + (this.millis % that.millis == 0) + + def min(that: Duration): Duration = if (this < that) this else that + + def max(that: Duration): Duration = if (this > that) this else that + + def isZero: Boolean = (this.millis == 0) + + override def toString: String = (millis.toString + " ms") + + def toFormattedString: String = millis.toString + + def milliseconds: Long = millis +} + + +/** + * Helper object that creates instance of [[spark.streaming.Duration]] representing + * a given number of milliseconds. + */ +object Milliseconds { + def apply(milliseconds: Long) = new Duration(milliseconds) +} + +/** + * Helper object that creates instance of [[spark.streaming.Duration]] representing + * a given number of seconds. + */ +object Seconds { + def apply(seconds: Long) = new Duration(seconds * 1000) +} + +/** + * Helper object that creates instance of [[spark.streaming.Duration]] representing + * a given number of minutes. + */ +object Minutes { + def apply(minutes: Long) = new Duration(minutes * 60000) +} + + diff --git a/streaming/src/main/scala/spark/streaming/Interval.scala b/streaming/src/main/scala/spark/streaming/Interval.scala index fa0b7ce19d..dc21dfb722 100644 --- a/streaming/src/main/scala/spark/streaming/Interval.scala +++ b/streaming/src/main/scala/spark/streaming/Interval.scala @@ -1,16 +1,16 @@ package spark.streaming private[streaming] -case class Interval(beginTime: Time, endTime: Time) { - def this(beginMs: Long, endMs: Long) = this(Time(beginMs), new Time(endMs)) +class Interval(val beginTime: Time, val endTime: Time) { + def this(beginMs: Long, endMs: Long) = this(new Time(beginMs), new Time(endMs)) - def duration(): Time = endTime - beginTime + def duration(): Duration = endTime - beginTime - def + (time: Time): Interval = { + def + (time: Duration): Interval = { new Interval(beginTime + time, endTime + time) } - def - (time: Time): Interval = { + def - (time: Duration): Interval = { new Interval(beginTime - time, endTime - time) } @@ -27,24 +27,14 @@ case class Interval(beginTime: Time, endTime: Time) { def >= (that: Interval) = !(this < that) - def next(): Interval = { - this + (endTime - beginTime) - } - - def isZero = (beginTime.isZero && endTime.isZero) - - def toFormattedString = beginTime.toFormattedString + "-" + endTime.toFormattedString - - override def toString = "[" + beginTime + ", " + endTime + "]" + override def toString = "[" + beginTime + ", " + endTime + "]" } object Interval { - def zero() = new Interval (Time.zero, Time.zero) - - def currentInterval(intervalDuration: Time): Interval = { - val time = Time(System.currentTimeMillis) - val intervalBegin = time.floor(intervalDuration) - Interval(intervalBegin, intervalBegin + intervalDuration) + def currentInterval(duration: Duration): Interval = { + val time = new Time(System.currentTimeMillis) + val intervalBegin = time.floor(duration) + new Interval(intervalBegin, intervalBegin + duration) } } diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index b0a208e67f..dd64064138 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -69,21 +69,21 @@ extends Serializable { self.map(x => (x._1, 1L)).reduceByKey((x: Long, y: Long) => x + y, numPartitions) } - def groupByKeyAndWindow(windowTime: Time, slideTime: Time): DStream[(K, Seq[V])] = { + def groupByKeyAndWindow(windowTime: Duration, slideTime: Duration): DStream[(K, Seq[V])] = { groupByKeyAndWindow(windowTime, slideTime, defaultPartitioner()) } def groupByKeyAndWindow( - windowTime: Time, - slideTime: Time, + windowTime: Duration, + slideTime: Duration, numPartitions: Int ): DStream[(K, Seq[V])] = { groupByKeyAndWindow(windowTime, slideTime, defaultPartitioner(numPartitions)) } def groupByKeyAndWindow( - windowTime: Time, - slideTime: Time, + windowTime: Duration, + slideTime: Duration, partitioner: Partitioner ): DStream[(K, Seq[V])] = { self.window(windowTime, slideTime).groupByKey(partitioner) @@ -91,23 +91,23 @@ extends Serializable { def reduceByKeyAndWindow( reduceFunc: (V, V) => V, - windowTime: Time + windowTime: Duration ): DStream[(K, V)] = { reduceByKeyAndWindow(reduceFunc, windowTime, self.slideTime, defaultPartitioner()) } def reduceByKeyAndWindow( reduceFunc: (V, V) => V, - windowTime: Time, - slideTime: Time + windowTime: Duration, + slideTime: Duration ): DStream[(K, V)] = { reduceByKeyAndWindow(reduceFunc, windowTime, slideTime, defaultPartitioner()) } def reduceByKeyAndWindow( reduceFunc: (V, V) => V, - windowTime: Time, - slideTime: Time, + windowTime: Duration, + slideTime: Duration, numPartitions: Int ): DStream[(K, V)] = { reduceByKeyAndWindow(reduceFunc, windowTime, slideTime, defaultPartitioner(numPartitions)) @@ -115,8 +115,8 @@ extends Serializable { def reduceByKeyAndWindow( reduceFunc: (V, V) => V, - windowTime: Time, - slideTime: Time, + windowTime: Duration, + slideTime: Duration, partitioner: Partitioner ): DStream[(K, V)] = { val cleanedReduceFunc = ssc.sc.clean(reduceFunc) @@ -134,8 +134,8 @@ extends Serializable { def reduceByKeyAndWindow( reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, - windowTime: Time, - slideTime: Time + windowTime: Duration, + slideTime: Duration ): DStream[(K, V)] = { reduceByKeyAndWindow( @@ -145,8 +145,8 @@ extends Serializable { def reduceByKeyAndWindow( reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, - windowTime: Time, - slideTime: Time, + windowTime: Duration, + slideTime: Duration, numPartitions: Int ): DStream[(K, V)] = { @@ -157,8 +157,8 @@ extends Serializable { def reduceByKeyAndWindow( reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, - windowTime: Time, - slideTime: Time, + windowTime: Duration, + slideTime: Duration, partitioner: Partitioner ): DStream[(K, V)] = { @@ -169,8 +169,8 @@ extends Serializable { } def countByKeyAndWindow( - windowTime: Time, - slideTime: Time, + windowTime: Duration, + slideTime: Duration, numPartitions: Int = self.ssc.sc.defaultParallelism ): DStream[(K, Long)] = { diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index eb40affe6d..10845e3a5e 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -22,7 +22,8 @@ class Scheduler(ssc: StreamingContext) extends Logging { val clockClass = System.getProperty("spark.streaming.clock", "spark.streaming.util.SystemClock") val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock] - val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, generateRDDs(_)) + val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, + longTime => generateRDDs(new Time(longTime))) def start() { // If context was started from checkpoint, then restart timer such that @@ -41,7 +42,7 @@ class Scheduler(ssc: StreamingContext) extends Logging { timer.restart(graph.zeroTime.milliseconds) logInfo("Scheduler's timer restarted") } else { - val firstTime = Time(timer.start()) + val firstTime = new Time(timer.start()) graph.start(firstTime - ssc.graph.batchDuration) logInfo("Scheduler's timer started") } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 215246ba2e..ee8314df3f 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -26,7 +26,7 @@ import java.util.UUID class StreamingContext private ( sc_ : SparkContext, cp_ : Checkpoint, - batchDur_ : Time + batchDur_ : Duration ) extends Logging { /** @@ -34,7 +34,7 @@ class StreamingContext private ( * @param sparkContext Existing SparkContext * @param batchDuration The time interval at which streaming data will be divided into batches */ - def this(sparkContext: SparkContext, batchDuration: Time) = this(sparkContext, null, batchDuration) + def this(sparkContext: SparkContext, batchDuration: Duration) = this(sparkContext, null, batchDuration) /** * Creates a StreamingContext by providing the details necessary for creating a new SparkContext. @@ -42,7 +42,7 @@ class StreamingContext private ( * @param frameworkName A name for your job, to display on the cluster web UI * @param batchDuration The time interval at which streaming data will be divided into batches */ - def this(master: String, frameworkName: String, batchDuration: Time) = + def this(master: String, frameworkName: String, batchDuration: Duration) = this(StreamingContext.createNewSparkContext(master, frameworkName), null, batchDuration) /** @@ -96,7 +96,7 @@ class StreamingContext private ( } } - protected[streaming] var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null + protected[streaming] var checkpointInterval: Duration = if (isCheckpointPresent) cp_.checkpointInterval else null protected[streaming] var receiverJobThread: Thread = null protected[streaming] var scheduler: Scheduler = null @@ -107,7 +107,7 @@ class StreamingContext private ( * if the developer wishes to query old data outside the DStream computation). * @param duration Minimum duration that each DStream should remember its RDDs */ - def remember(duration: Time) { + def remember(duration: Duration) { graph.remember(duration) } @@ -117,7 +117,7 @@ class StreamingContext private ( * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored * @param interval checkpoint interval */ - def checkpoint(directory: String, interval: Time = null) { + def checkpoint(directory: String, interval: Duration = null) { if (directory != null) { sc.setCheckpointDir(StreamingContext.getSparkCheckpointDir(directory)) checkpointDir = directory diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index 3c6fd5d967..069df82e52 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -7,7 +7,7 @@ package spark.streaming * @param millis Time in UTC. */ -case class Time(private val millis: Long) { +class Time(private val millis: Long) { def < (that: Time): Boolean = (this.millis < that.millis) @@ -17,63 +17,32 @@ case class Time(private val millis: Long) { def >= (that: Time): Boolean = (this.millis >= that.millis) - def + (that: Time): Time = Time(millis + that.millis) - - def - (that: Time): Time = Time(millis - that.millis) - - def * (times: Int): Time = Time(millis * times) + def + (that: Duration): Time = new Time(millis + that.milliseconds) + + def - (that: Time): Duration = new Duration(millis - that.millis) - def / (that: Time): Long = millis / that.millis + def - (that: Duration): Time = new Time(millis - that.milliseconds) - def floor(that: Time): Time = { - val t = that.millis + def floor(that: Duration): Time = { + val t = that.milliseconds val m = math.floor(this.millis / t).toLong - Time(m * t) + new Time(m * t) } - def isMultipleOf(that: Time): Boolean = - (this.millis % that.millis == 0) + def isMultipleOf(that: Duration): Boolean = + (this.millis % that.milliseconds == 0) def min(that: Time): Time = if (this < that) this else that def max(that: Time): Time = if (this > that) this else that - def isZero: Boolean = (this.millis == 0) - override def toString: String = (millis.toString + " ms") - def toFormattedString: String = millis.toString - def milliseconds: Long = millis } -private[streaming] object Time { - val zero = Time(0) - +/*private[streaming] object Time { implicit def toTime(long: Long) = Time(long) } - -/** - * Helper object that creates instance of [[spark.streaming.Time]] representing - * a given number of milliseconds. - */ -object Milliseconds { - def apply(milliseconds: Long) = Time(milliseconds) -} - -/** - * Helper object that creates instance of [[spark.streaming.Time]] representing - * a given number of seconds. - */ -object Seconds { - def apply(seconds: Long) = Time(seconds * 1000) -} - -/** - * Helper object that creates instance of [[spark.streaming.Time]] representing - * a given number of minutes. - */ -object Minutes { - def apply(minutes: Long) = Time(minutes * 60000) -} +*/ diff --git a/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala index bc23d423d3..ca178fd384 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala @@ -2,7 +2,7 @@ package spark.streaming.dstream import spark.{RDD, Partitioner} import spark.rdd.CoGroupedRDD -import spark.streaming.{Time, DStream} +import spark.streaming.{Time, DStream, Duration} private[streaming] class CoGroupedDStream[K : ClassManifest]( @@ -24,7 +24,7 @@ class CoGroupedDStream[K : ClassManifest]( override def dependencies = parents.toList - override def slideTime = parents.head.slideTime + override def slideTime: Duration = parents.head.slideTime override def compute(validTime: Time): Option[RDD[(K, Seq[Seq[_]])]] = { val part = partitioner diff --git a/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala index 1cbb4d536e..76b9e58029 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala @@ -1,6 +1,6 @@ package spark.streaming.dstream -import spark.streaming.{DStream, Time} +import spark.streaming.{Duration, DStream, Time} import spark.RDD private[streaming] @@ -11,7 +11,7 @@ class FilteredDStream[T: ClassManifest]( override def dependencies = List(parent) - override def slideTime: Time = parent.slideTime + override def slideTime: Duration = parent.slideTime override def compute(validTime: Time): Option[RDD[T]] = { parent.getOrCompute(validTime).map(_.filter(filterFunc)) diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala index 11ed8cf317..28e9a456ac 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala @@ -1,6 +1,6 @@ package spark.streaming.dstream -import spark.streaming.{DStream, Time} +import spark.streaming.{Duration, DStream, Time} import spark.RDD import spark.SparkContext._ @@ -12,7 +12,7 @@ class FlatMapValuedDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest] override def dependencies = List(parent) - override def slideTime: Time = parent.slideTime + override def slideTime: Duration = parent.slideTime override def compute(validTime: Time): Option[RDD[(K, U)]] = { parent.getOrCompute(validTime).map(_.flatMapValues[U](flatMapValueFunc)) diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala index a13b4c9ff9..ef305b66f1 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala @@ -1,6 +1,6 @@ package spark.streaming.dstream -import spark.streaming.{DStream, Time} +import spark.streaming.{Duration, DStream, Time} import spark.RDD private[streaming] @@ -11,7 +11,7 @@ class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( override def dependencies = List(parent) - override def slideTime: Time = parent.slideTime + override def slideTime: Duration = parent.slideTime override def compute(validTime: Time): Option[RDD[U]] = { parent.getOrCompute(validTime).map(_.flatMap(flatMapFunc)) diff --git a/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala index 41c629a225..f8af0a38a7 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala @@ -1,7 +1,7 @@ package spark.streaming.dstream import spark.RDD -import spark.streaming.{DStream, Job, Time} +import spark.streaming.{Duration, DStream, Job, Time} private[streaming] class ForEachDStream[T: ClassManifest] ( @@ -11,7 +11,7 @@ class ForEachDStream[T: ClassManifest] ( override def dependencies = List(parent) - override def slideTime: Time = parent.slideTime + override def slideTime: Duration = parent.slideTime override def compute(validTime: Time): Option[RDD[Unit]] = None diff --git a/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala index 92ea503cae..19cccea735 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala @@ -1,6 +1,6 @@ package spark.streaming.dstream -import spark.streaming.{DStream, Time} +import spark.streaming.{Duration, DStream, Time} import spark.RDD private[streaming] @@ -9,7 +9,7 @@ class GlommedDStream[T: ClassManifest](parent: DStream[T]) override def dependencies = List(parent) - override def slideTime: Time = parent.slideTime + override def slideTime: Duration = parent.slideTime override def compute(validTime: Time): Option[RDD[Array[T]]] = { parent.getOrCompute(validTime).map(_.glom()) diff --git a/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala index 4959c66b06..50f0f45796 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala @@ -1,13 +1,13 @@ package spark.streaming.dstream -import spark.streaming.{StreamingContext, DStream} +import spark.streaming.{Duration, StreamingContext, DStream} abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext) extends DStream[T](ssc_) { override def dependencies = List() - override def slideTime = { + override def slideTime: Duration = { if (ssc == null) throw new Exception("ssc is null") if (ssc.graph.batchDuration == null) throw new Exception("batchDuration is null") ssc.graph.batchDuration diff --git a/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala index daf78c6893..e9ca668aa6 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala @@ -1,6 +1,6 @@ package spark.streaming.dstream -import spark.streaming.{DStream, Time} +import spark.streaming.{Duration, DStream, Time} import spark.RDD private[streaming] @@ -12,7 +12,7 @@ class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( override def dependencies = List(parent) - override def slideTime: Time = parent.slideTime + override def slideTime: Duration = parent.slideTime override def compute(validTime: Time): Option[RDD[U]] = { parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc, preservePartitioning)) diff --git a/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala index 689caeef0e..ebc7d0698b 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala @@ -1,6 +1,6 @@ package spark.streaming.dstream -import spark.streaming.{DStream, Time} +import spark.streaming.{Duration, DStream, Time} import spark.RDD import spark.SparkContext._ @@ -12,7 +12,7 @@ class MapValuedDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( override def dependencies = List(parent) - override def slideTime: Time = parent.slideTime + override def slideTime: Duration = parent.slideTime override def compute(validTime: Time): Option[RDD[(K, U)]] = { parent.getOrCompute(validTime).map(_.mapValues[U](mapValueFunc)) diff --git a/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala index 786b9966f2..3af8e7ab88 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala @@ -1,6 +1,6 @@ package spark.streaming.dstream -import spark.streaming.{DStream, Time} +import spark.streaming.{Duration, DStream, Time} import spark.RDD private[streaming] @@ -11,7 +11,7 @@ class MappedDStream[T: ClassManifest, U: ClassManifest] ( override def dependencies = List(parent) - override def slideTime: Time = parent.slideTime + override def slideTime: Duration = parent.slideTime override def compute(validTime: Time): Option[RDD[U]] = { parent.getOrCompute(validTime).map(_.map[U](mapFunc)) diff --git a/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala index d289ed2a3f..a685a778ce 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -9,15 +9,15 @@ import spark.SparkContext._ import spark.storage.StorageLevel import scala.collection.mutable.ArrayBuffer -import spark.streaming.{Interval, Time, DStream} +import spark.streaming.{Duration, Interval, Time, DStream} private[streaming] class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( parent: DStream[(K, V)], reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, - _windowTime: Time, - _slideTime: Time, + _windowTime: Duration, + _slideTime: Duration, partitioner: Partitioner ) extends DStream[(K,V)](parent.ssc) { @@ -39,15 +39,15 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( super.persist(StorageLevel.MEMORY_ONLY_SER) reducedStream.persist(StorageLevel.MEMORY_ONLY_SER) - def windowTime: Time = _windowTime + def windowTime: Duration = _windowTime override def dependencies = List(reducedStream) - override def slideTime: Time = _slideTime + override def slideTime: Duration = _slideTime override val mustCheckpoint = true - override def parentRememberDuration: Time = rememberDuration + windowTime + override def parentRememberDuration: Duration = rememberDuration + windowTime override def persist(storageLevel: StorageLevel): DStream[(K,V)] = { super.persist(storageLevel) @@ -55,7 +55,7 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( this } - override def checkpoint(interval: Time): DStream[(K, V)] = { + override def checkpoint(interval: Duration): DStream[(K, V)] = { super.checkpoint(interval) //reducedStream.checkpoint(interval) this @@ -66,7 +66,7 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( val invReduceF = invReduceFunc val currentTime = validTime - val currentWindow = Interval(currentTime - windowTime + parent.slideTime, currentTime) + val currentWindow = new Interval(currentTime - windowTime + parent.slideTime, currentTime) val previousWindow = currentWindow - slideTime logDebug("Window time = " + windowTime) diff --git a/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala index 6854bbe665..7612804b96 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala @@ -2,7 +2,7 @@ package spark.streaming.dstream import spark.{RDD, Partitioner} import spark.SparkContext._ -import spark.streaming.{DStream, Time} +import spark.streaming.{Duration, DStream, Time} private[streaming] class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( @@ -15,7 +15,7 @@ class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( override def dependencies = List(parent) - override def slideTime: Time = parent.slideTime + override def slideTime: Duration = parent.slideTime override def compute(validTime: Time): Option[RDD[(K,C)]] = { parent.getOrCompute(validTime) match { diff --git a/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala index 175b3060c1..ce4f486825 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala @@ -4,7 +4,7 @@ import spark.RDD import spark.Partitioner import spark.SparkContext._ import spark.storage.StorageLevel -import spark.streaming.{Time, DStream} +import spark.streaming.{Duration, Time, DStream} private[streaming] class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( @@ -18,7 +18,7 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife override def dependencies = List(parent) - override def slideTime = parent.slideTime + override def slideTime: Duration = parent.slideTime override val mustCheckpoint = true diff --git a/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala index 0337579514..5a2c5bc0f0 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala @@ -1,7 +1,7 @@ package spark.streaming.dstream import spark.RDD -import spark.streaming.{DStream, Time} +import spark.streaming.{Duration, DStream, Time} private[streaming] class TransformedDStream[T: ClassManifest, U: ClassManifest] ( @@ -11,7 +11,7 @@ class TransformedDStream[T: ClassManifest, U: ClassManifest] ( override def dependencies = List(parent) - override def slideTime: Time = parent.slideTime + override def slideTime: Duration = parent.slideTime override def compute(validTime: Time): Option[RDD[U]] = { parent.getOrCompute(validTime).map(transformFunc(_, validTime)) diff --git a/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala index 3bf4c2ecea..224a19842b 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala @@ -1,6 +1,6 @@ package spark.streaming.dstream -import spark.streaming.{DStream, Time} +import spark.streaming.{Duration, DStream, Time} import spark.RDD import collection.mutable.ArrayBuffer import spark.rdd.UnionRDD @@ -23,7 +23,7 @@ class UnionDStream[T: ClassManifest](parents: Array[DStream[T]]) override def dependencies = parents.toList - override def slideTime: Time = parents.head.slideTime + override def slideTime: Duration = parents.head.slideTime override def compute(validTime: Time): Option[RDD[T]] = { val rdds = new ArrayBuffer[RDD[T]]() diff --git a/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala index 7718794cbf..45689b25ce 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala @@ -3,13 +3,13 @@ package spark.streaming.dstream import spark.RDD import spark.rdd.UnionRDD import spark.storage.StorageLevel -import spark.streaming.{Interval, Time, DStream} +import spark.streaming.{Duration, Interval, Time, DStream} private[streaming] class WindowedDStream[T: ClassManifest]( parent: DStream[T], - _windowTime: Time, - _slideTime: Time) + _windowTime: Duration, + _slideTime: Duration) extends DStream[T](parent.ssc) { if (!_windowTime.isMultipleOf(parent.slideTime)) @@ -22,16 +22,16 @@ class WindowedDStream[T: ClassManifest]( parent.persist(StorageLevel.MEMORY_ONLY_SER) - def windowTime: Time = _windowTime + def windowTime: Duration = _windowTime override def dependencies = List(parent) - override def slideTime: Time = _slideTime + override def slideTime: Duration = _slideTime - override def parentRememberDuration: Time = rememberDuration + windowTime + override def parentRememberDuration: Duration = rememberDuration + windowTime override def compute(validTime: Time): Option[RDD[T]] = { - val currentWindow = Interval(validTime - windowTime + parent.slideTime, validTime) + val currentWindow = new Interval(validTime - windowTime + parent.slideTime, validTime) Some(new UnionRDD(ssc.sc, parent.slice(currentWindow))) } } -- cgit v1.2.3 From 365506fb038a76ff3810957f5bc5823f5f16af40 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 9 Jan 2013 14:29:25 -0800 Subject: Changed variable name form ***Time to ***Duration to keep things consistent. --- .../main/scala/spark/streaming/Checkpoint.scala | 2 +- .../src/main/scala/spark/streaming/DStream.scala | 90 +++++++++++----------- .../src/main/scala/spark/streaming/Duration.scala | 2 +- .../spark/streaming/PairDStreamFunctions.scala | 70 ++++++++--------- .../src/main/scala/spark/streaming/Scheduler.scala | 4 +- .../scala/spark/streaming/StreamingContext.scala | 12 +-- .../src/main/scala/spark/streaming/Time.scala | 22 ++---- .../spark/streaming/dstream/CoGroupedDStream.scala | 4 +- .../spark/streaming/dstream/FilteredDStream.scala | 2 +- .../streaming/dstream/FlatMapValuedDStream.scala | 2 +- .../streaming/dstream/FlatMappedDStream.scala | 2 +- .../spark/streaming/dstream/ForEachDStream.scala | 2 +- .../spark/streaming/dstream/GlommedDStream.scala | 2 +- .../spark/streaming/dstream/InputDStream.scala | 2 +- .../streaming/dstream/MapPartitionedDStream.scala | 2 +- .../spark/streaming/dstream/MapValuedDStream.scala | 2 +- .../spark/streaming/dstream/MappedDStream.scala | 2 +- .../streaming/dstream/ReducedWindowedDStream.scala | 34 ++++---- .../spark/streaming/dstream/ShuffledDStream.scala | 2 +- .../spark/streaming/dstream/StateDStream.scala | 4 +- .../streaming/dstream/TransformedDStream.scala | 2 +- .../spark/streaming/dstream/UnionDStream.scala | 4 +- .../spark/streaming/dstream/WindowedDStream.scala | 24 +++--- .../spark/streaming/BasicOperationsSuite.scala | 18 ++--- .../test/scala/spark/streaming/TestSuiteBase.scala | 2 +- .../spark/streaming/WindowOperationsSuite.scala | 48 ++++++------ 26 files changed, 176 insertions(+), 186 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index a9c6e65d62..2f3adb39c2 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -17,7 +17,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val jars = ssc.sc.jars val graph = ssc.graph val checkpointDir = ssc.checkpointDir - val checkpointInterval: Duration = ssc.checkpointInterval + val checkpointDuration: Duration = ssc.checkpointDuration def validate() { assert(master != null, "Checkpoint.master is null") diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 7611598fde..c89fb7723e 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -47,7 +47,7 @@ abstract class DStream[T: ClassManifest] ( // ======================================================================= /** Time interval after which the DStream generates a RDD */ - def slideTime: Duration + def slideDuration: Duration /** List of parent DStreams on which this DStream depends on */ def dependencies: List[DStream[_]] @@ -74,7 +74,7 @@ abstract class DStream[T: ClassManifest] ( // Checkpoint details protected[streaming] val mustCheckpoint = false - protected[streaming] var checkpointInterval: Duration = null + protected[streaming] var checkpointDuration: Duration = null protected[streaming] var checkpointData = new DStreamCheckpointData(HashMap[Time, Any]()) // Reference to whole DStream graph @@ -114,7 +114,7 @@ abstract class DStream[T: ClassManifest] ( "Cannot change checkpoint interval of an DStream after streaming context has started") } persist() - checkpointInterval = interval + checkpointDuration = interval this } @@ -130,16 +130,16 @@ abstract class DStream[T: ClassManifest] ( } zeroTime = time - // Set the checkpoint interval to be slideTime or 10 seconds, which ever is larger - if (mustCheckpoint && checkpointInterval == null) { - checkpointInterval = slideTime.max(Seconds(10)) - logInfo("Checkpoint interval automatically set to " + checkpointInterval) + // Set the checkpoint interval to be slideDuration or 10 seconds, which ever is larger + if (mustCheckpoint && checkpointDuration == null) { + checkpointDuration = slideDuration.max(Seconds(10)) + logInfo("Checkpoint interval automatically set to " + checkpointDuration) } // Set the minimum value of the rememberDuration if not already set - var minRememberDuration = slideTime - if (checkpointInterval != null && minRememberDuration <= checkpointInterval) { - minRememberDuration = checkpointInterval * 2 // times 2 just to be sure that the latest checkpoint is not forgetten + var minRememberDuration = slideDuration + if (checkpointDuration != null && minRememberDuration <= checkpointDuration) { + minRememberDuration = checkpointDuration * 2 // times 2 just to be sure that the latest checkpoint is not forgetten } if (rememberDuration == null || rememberDuration < minRememberDuration) { rememberDuration = minRememberDuration @@ -153,37 +153,37 @@ abstract class DStream[T: ClassManifest] ( assert(rememberDuration != null, "Remember duration is set to null") assert( - !mustCheckpoint || checkpointInterval != null, + !mustCheckpoint || checkpointDuration != null, "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set. " + " Please use DStream.checkpoint() to set the interval." ) assert( - checkpointInterval == null || checkpointInterval >= slideTime, + checkpointDuration == null || checkpointDuration >= slideDuration, "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + - checkpointInterval + " which is lower than its slide time (" + slideTime + "). " + - "Please set it to at least " + slideTime + "." + checkpointDuration + " which is lower than its slide time (" + slideDuration + "). " + + "Please set it to at least " + slideDuration + "." ) assert( - checkpointInterval == null || checkpointInterval.isMultipleOf(slideTime), + checkpointDuration == null || checkpointDuration.isMultipleOf(slideDuration), "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + - checkpointInterval + " which not a multiple of its slide time (" + slideTime + "). " + - "Please set it to a multiple " + slideTime + "." + checkpointDuration + " which not a multiple of its slide time (" + slideDuration + "). " + + "Please set it to a multiple " + slideDuration + "." ) assert( - checkpointInterval == null || storageLevel != StorageLevel.NONE, + checkpointDuration == null || storageLevel != StorageLevel.NONE, "" + this.getClass.getSimpleName + " has been marked for checkpointing but the storage " + "level has not been set to enable persisting. Please use DStream.persist() to set the " + "storage level to use memory for better checkpointing performance." ) assert( - checkpointInterval == null || rememberDuration > checkpointInterval, + checkpointDuration == null || rememberDuration > checkpointDuration, "The remember duration for " + this.getClass.getSimpleName + " has been set to " + rememberDuration + " which is not more than the checkpoint interval (" + - checkpointInterval + "). Please set it to higher than " + checkpointInterval + "." + checkpointDuration + "). Please set it to higher than " + checkpointDuration + "." ) val metadataCleanerDelay = spark.util.MetadataCleaner.getDelaySeconds @@ -200,9 +200,9 @@ abstract class DStream[T: ClassManifest] ( dependencies.foreach(_.validate()) - logInfo("Slide time = " + slideTime) + logInfo("Slide time = " + slideDuration) logInfo("Storage level = " + storageLevel) - logInfo("Checkpoint interval = " + checkpointInterval) + logInfo("Checkpoint interval = " + checkpointDuration) logInfo("Remember duration = " + rememberDuration) logInfo("Initialized and validated " + this) } @@ -232,11 +232,11 @@ abstract class DStream[T: ClassManifest] ( dependencies.foreach(_.remember(parentRememberDuration)) } - /** This method checks whether the 'time' is valid wrt slideTime for generating RDD */ + /** This method checks whether the 'time' is valid wrt slideDuration for generating RDD */ protected def isTimeValid(time: Time): Boolean = { if (!isInitialized) { throw new Exception (this + " has not been initialized") - } else if (time <= zeroTime || ! (time - zeroTime).isMultipleOf(slideTime)) { + } else if (time <= zeroTime || ! (time - zeroTime).isMultipleOf(slideDuration)) { false } else { true @@ -266,7 +266,7 @@ abstract class DStream[T: ClassManifest] ( newRDD.persist(storageLevel) logInfo("Persisting RDD " + newRDD.id + " for time " + time + " to " + storageLevel + " at time " + time) } - if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { + if (checkpointDuration != null && (time - zeroTime).isMultipleOf(checkpointDuration)) { newRDD.checkpoint() logInfo("Marking RDD " + newRDD.id + " for time " + time + " for checkpointing at time " + time) } @@ -528,21 +528,21 @@ abstract class DStream[T: ClassManifest] ( /** * Return a new DStream which is computed based on windowed batches of this DStream. * The new DStream generates RDDs with the same interval as this DStream. - * @param windowTime width of the window; must be a multiple of this DStream's interval. + * @param windowDuration width of the window; must be a multiple of this DStream's interval. * @return */ - def window(windowTime: Duration): DStream[T] = window(windowTime, this.slideTime) + def window(windowDuration: Duration): DStream[T] = window(windowDuration, this.slideDuration) /** * Return a new DStream which is computed based on windowed batches of this DStream. - * @param windowTime duration (i.e., width) of the window; + * @param windowDuration duration (i.e., width) of the window; * must be a multiple of this DStream's interval - * @param slideTime sliding interval of the window (i.e., the interval after which + * @param slideDuration sliding interval of the window (i.e., the interval after which * the new DStream will generate RDDs); must be a multiple of this * DStream's interval */ - def window(windowTime: Duration, slideTime: Duration): DStream[T] = { - new WindowedDStream(this, windowTime, slideTime) + def window(windowDuration: Duration, slideDuration: Duration): DStream[T] = { + new WindowedDStream(this, windowDuration, slideDuration) } /** @@ -554,36 +554,36 @@ abstract class DStream[T: ClassManifest] ( /** * Returns a new DStream in which each RDD has a single element generated by reducing all - * elements in a window over this DStream. windowTime and slideTime are as defined in the - * window() operation. This is equivalent to window(windowTime, slideTime).reduce(reduceFunc) + * elements in a window over this DStream. windowDuration and slideDuration are as defined in the + * window() operation. This is equivalent to window(windowDuration, slideDuration).reduce(reduceFunc) */ - def reduceByWindow(reduceFunc: (T, T) => T, windowTime: Duration, slideTime: Duration): DStream[T] = { - this.window(windowTime, slideTime).reduce(reduceFunc) + def reduceByWindow(reduceFunc: (T, T) => T, windowDuration: Duration, slideDuration: Duration): DStream[T] = { + this.window(windowDuration, slideDuration).reduce(reduceFunc) } def reduceByWindow( reduceFunc: (T, T) => T, invReduceFunc: (T, T) => T, - windowTime: Duration, - slideTime: Duration + windowDuration: Duration, + slideDuration: Duration ): DStream[T] = { this.map(x => (1, x)) - .reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowTime, slideTime, 1) + .reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration, 1) .map(_._2) } /** * Returns a new DStream in which each RDD has a single element generated by counting the number - * of elements in a window over this DStream. windowTime and slideTime are as defined in the - * window() operation. This is equivalent to window(windowTime, slideTime).count() + * of elements in a window over this DStream. windowDuration and slideDuration are as defined in the + * window() operation. This is equivalent to window(windowDuration, slideDuration).count() */ - def countByWindow(windowTime: Duration, slideTime: Duration): DStream[Int] = { - this.map(_ => 1).reduceByWindow(_ + _, _ - _, windowTime, slideTime) + def countByWindow(windowDuration: Duration, slideDuration: Duration): DStream[Int] = { + this.map(_ => 1).reduceByWindow(_ + _, _ - _, windowDuration, slideDuration) } /** * Returns a new DStream by unifying data of another DStream with this DStream. - * @param that Another DStream having the same interval (i.e., slideTime) as this DStream. + * @param that Another DStream having the same slideDuration as this DStream. */ def union(that: DStream[T]): DStream[T] = new UnionDStream[T](Array(this, that)) @@ -599,13 +599,13 @@ abstract class DStream[T: ClassManifest] ( */ def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = { val rdds = new ArrayBuffer[RDD[T]]() - var time = toTime.floor(slideTime) + var time = toTime.floor(slideDuration) while (time >= zeroTime && time >= fromTime) { getOrCompute(time) match { case Some(rdd) => rdds += rdd case None => //throw new Exception("Could not get RDD for time " + time) } - time -= slideTime + time -= slideDuration } rdds.toSeq } diff --git a/streaming/src/main/scala/spark/streaming/Duration.scala b/streaming/src/main/scala/spark/streaming/Duration.scala index d2728d9dca..e4dc579a17 100644 --- a/streaming/src/main/scala/spark/streaming/Duration.scala +++ b/streaming/src/main/scala/spark/streaming/Duration.scala @@ -1,6 +1,6 @@ package spark.streaming -class Duration (private val millis: Long) { +case class Duration (private val millis: Long) { def < (that: Duration): Boolean = (this.millis < that.millis) diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index dd64064138..482d01300d 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -21,14 +21,10 @@ extends Serializable { def ssc = self.ssc - def defaultPartitioner(numPartitions: Int = self.ssc.sc.defaultParallelism) = { + private[streaming] def defaultPartitioner(numPartitions: Int = self.ssc.sc.defaultParallelism) = { new HashPartitioner(numPartitions) } - /* ---------------------------------- */ - /* DStream operations for key-value pairs */ - /* ---------------------------------- */ - def groupByKey(): DStream[(K, Seq[V])] = { groupByKey(defaultPartitioner()) } @@ -69,59 +65,59 @@ extends Serializable { self.map(x => (x._1, 1L)).reduceByKey((x: Long, y: Long) => x + y, numPartitions) } - def groupByKeyAndWindow(windowTime: Duration, slideTime: Duration): DStream[(K, Seq[V])] = { - groupByKeyAndWindow(windowTime, slideTime, defaultPartitioner()) + def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration): DStream[(K, Seq[V])] = { + groupByKeyAndWindow(windowDuration, slideDuration, defaultPartitioner()) } def groupByKeyAndWindow( - windowTime: Duration, - slideTime: Duration, + windowDuration: Duration, + slideDuration: Duration, numPartitions: Int ): DStream[(K, Seq[V])] = { - groupByKeyAndWindow(windowTime, slideTime, defaultPartitioner(numPartitions)) + groupByKeyAndWindow(windowDuration, slideDuration, defaultPartitioner(numPartitions)) } def groupByKeyAndWindow( - windowTime: Duration, - slideTime: Duration, + windowDuration: Duration, + slideDuration: Duration, partitioner: Partitioner ): DStream[(K, Seq[V])] = { - self.window(windowTime, slideTime).groupByKey(partitioner) + self.window(windowDuration, slideDuration).groupByKey(partitioner) } def reduceByKeyAndWindow( reduceFunc: (V, V) => V, - windowTime: Duration + windowDuration: Duration ): DStream[(K, V)] = { - reduceByKeyAndWindow(reduceFunc, windowTime, self.slideTime, defaultPartitioner()) + reduceByKeyAndWindow(reduceFunc, windowDuration, self.slideDuration, defaultPartitioner()) } def reduceByKeyAndWindow( reduceFunc: (V, V) => V, - windowTime: Duration, - slideTime: Duration + windowDuration: Duration, + slideDuration: Duration ): DStream[(K, V)] = { - reduceByKeyAndWindow(reduceFunc, windowTime, slideTime, defaultPartitioner()) + reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, defaultPartitioner()) } def reduceByKeyAndWindow( reduceFunc: (V, V) => V, - windowTime: Duration, - slideTime: Duration, + windowDuration: Duration, + slideDuration: Duration, numPartitions: Int ): DStream[(K, V)] = { - reduceByKeyAndWindow(reduceFunc, windowTime, slideTime, defaultPartitioner(numPartitions)) + reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, defaultPartitioner(numPartitions)) } def reduceByKeyAndWindow( reduceFunc: (V, V) => V, - windowTime: Duration, - slideTime: Duration, + windowDuration: Duration, + slideDuration: Duration, partitioner: Partitioner ): DStream[(K, V)] = { val cleanedReduceFunc = ssc.sc.clean(reduceFunc) self.reduceByKey(cleanedReduceFunc, partitioner) - .window(windowTime, slideTime) + .window(windowDuration, slideDuration) .reduceByKey(cleanedReduceFunc, partitioner) } @@ -134,51 +130,51 @@ extends Serializable { def reduceByKeyAndWindow( reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, - windowTime: Duration, - slideTime: Duration + windowDuration: Duration, + slideDuration: Duration ): DStream[(K, V)] = { reduceByKeyAndWindow( - reduceFunc, invReduceFunc, windowTime, slideTime, defaultPartitioner()) + reduceFunc, invReduceFunc, windowDuration, slideDuration, defaultPartitioner()) } def reduceByKeyAndWindow( reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, - windowTime: Duration, - slideTime: Duration, + windowDuration: Duration, + slideDuration: Duration, numPartitions: Int ): DStream[(K, V)] = { reduceByKeyAndWindow( - reduceFunc, invReduceFunc, windowTime, slideTime, defaultPartitioner(numPartitions)) + reduceFunc, invReduceFunc, windowDuration, slideDuration, defaultPartitioner(numPartitions)) } def reduceByKeyAndWindow( reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, - windowTime: Duration, - slideTime: Duration, + windowDuration: Duration, + slideDuration: Duration, partitioner: Partitioner ): DStream[(K, V)] = { val cleanedReduceFunc = ssc.sc.clean(reduceFunc) val cleanedInvReduceFunc = ssc.sc.clean(invReduceFunc) new ReducedWindowedDStream[K, V]( - self, cleanedReduceFunc, cleanedInvReduceFunc, windowTime, slideTime, partitioner) + self, cleanedReduceFunc, cleanedInvReduceFunc, windowDuration, slideDuration, partitioner) } def countByKeyAndWindow( - windowTime: Duration, - slideTime: Duration, + windowDuration: Duration, + slideDuration: Duration, numPartitions: Int = self.ssc.sc.defaultParallelism ): DStream[(K, Long)] = { self.map(x => (x._1, 1L)).reduceByKeyAndWindow( (x: Long, y: Long) => x + y, (x: Long, y: Long) => x - y, - windowTime, - slideTime, + windowDuration, + slideDuration, numPartitions ) } diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index 10845e3a5e..c04ed37de8 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -14,7 +14,7 @@ class Scheduler(ssc: StreamingContext) extends Logging { val concurrentJobs = System.getProperty("spark.streaming.concurrentJobs", "1").toInt val jobManager = new JobManager(ssc, concurrentJobs) - val checkpointWriter = if (ssc.checkpointInterval != null && ssc.checkpointDir != null) { + val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) { new CheckpointWriter(ssc.checkpointDir) } else { null @@ -65,7 +65,7 @@ class Scheduler(ssc: StreamingContext) extends Logging { } private def doCheckpoint(time: Time) { - if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) { + if (ssc.checkpointDuration != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) { val startTime = System.currentTimeMillis() ssc.graph.updateCheckpointData(time) checkpointWriter.write(new Checkpoint(ssc, time)) diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index ee8314df3f..14500bdcb1 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -96,7 +96,7 @@ class StreamingContext private ( } } - protected[streaming] var checkpointInterval: Duration = if (isCheckpointPresent) cp_.checkpointInterval else null + protected[streaming] var checkpointDuration: Duration = if (isCheckpointPresent) cp_.checkpointDuration else null protected[streaming] var receiverJobThread: Thread = null protected[streaming] var scheduler: Scheduler = null @@ -121,10 +121,10 @@ class StreamingContext private ( if (directory != null) { sc.setCheckpointDir(StreamingContext.getSparkCheckpointDir(directory)) checkpointDir = directory - checkpointInterval = interval + checkpointDuration = interval } else { checkpointDir = null - checkpointInterval = null + checkpointDuration = null } } @@ -327,7 +327,7 @@ class StreamingContext private ( graph.validate() assert( - checkpointDir == null || checkpointInterval != null, + checkpointDir == null || checkpointDuration != null, "Checkpoint directory has been set, but the graph checkpointing interval has " + "not been set. Please use StreamingContext.checkpoint() to set the interval." ) @@ -337,8 +337,8 @@ class StreamingContext private ( * Starts the execution of the streams. */ def start() { - if (checkpointDir != null && checkpointInterval == null && graph != null) { - checkpointInterval = graph.batchDuration + if (checkpointDir != null && checkpointDuration == null && graph != null) { + checkpointDuration = graph.batchDuration } validate() diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index 069df82e52..5daeb761dd 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -1,14 +1,15 @@ package spark.streaming /** - * This is a simple class that represents time. Internally, it represents time as UTC. - * The recommended way to create instances of Time is to use helper objects - * [[spark.streaming.Milliseconds]], [[spark.streaming.Seconds]], and [[spark.streaming.Minutes]]. - * @param millis Time in UTC. + * This is a simple class that represents an absolute instant of time. + * Internally, it represents time as the difference, measured in milliseconds, between the current + * time and midnight, January 1, 1970 UTC. This is the same format as what is returned by + * System.currentTimeMillis. */ +case class Time(private val millis: Long) { + + def milliseconds: Long = millis -class Time(private val millis: Long) { - def < (that: Time): Boolean = (this.millis < that.millis) def <= (that: Time): Boolean = (this.millis <= that.millis) @@ -38,11 +39,4 @@ class Time(private val millis: Long) { override def toString: String = (millis.toString + " ms") - def milliseconds: Long = millis -} - -/*private[streaming] object Time { - implicit def toTime(long: Long) = Time(long) -} -*/ - +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala index ca178fd384..ddb1bf6b28 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala @@ -18,13 +18,13 @@ class CoGroupedDStream[K : ClassManifest]( throw new IllegalArgumentException("Array of parents have different StreamingContexts") } - if (parents.map(_.slideTime).distinct.size > 1) { + if (parents.map(_.slideDuration).distinct.size > 1) { throw new IllegalArgumentException("Array of parents have different slide times") } override def dependencies = parents.toList - override def slideTime: Duration = parents.head.slideTime + override def slideDuration: Duration = parents.head.slideDuration override def compute(validTime: Time): Option[RDD[(K, Seq[Seq[_]])]] = { val part = partitioner diff --git a/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala index 76b9e58029..e993164f99 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala @@ -11,7 +11,7 @@ class FilteredDStream[T: ClassManifest]( override def dependencies = List(parent) - override def slideTime: Duration = parent.slideTime + override def slideDuration: Duration = parent.slideDuration override def compute(validTime: Time): Option[RDD[T]] = { parent.getOrCompute(validTime).map(_.filter(filterFunc)) diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala index 28e9a456ac..cabd34f5f2 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala @@ -12,7 +12,7 @@ class FlatMapValuedDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest] override def dependencies = List(parent) - override def slideTime: Duration = parent.slideTime + override def slideDuration: Duration = parent.slideDuration override def compute(validTime: Time): Option[RDD[(K, U)]] = { parent.getOrCompute(validTime).map(_.flatMapValues[U](flatMapValueFunc)) diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala index ef305b66f1..a69af60589 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala @@ -11,7 +11,7 @@ class FlatMappedDStream[T: ClassManifest, U: ClassManifest]( override def dependencies = List(parent) - override def slideTime: Duration = parent.slideTime + override def slideDuration: Duration = parent.slideDuration override def compute(validTime: Time): Option[RDD[U]] = { parent.getOrCompute(validTime).map(_.flatMap(flatMapFunc)) diff --git a/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala index f8af0a38a7..ee69ea5177 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala @@ -11,7 +11,7 @@ class ForEachDStream[T: ClassManifest] ( override def dependencies = List(parent) - override def slideTime: Duration = parent.slideTime + override def slideDuration: Duration = parent.slideDuration override def compute(validTime: Time): Option[RDD[Unit]] = None diff --git a/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala index 19cccea735..b589cbd4d5 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala @@ -9,7 +9,7 @@ class GlommedDStream[T: ClassManifest](parent: DStream[T]) override def dependencies = List(parent) - override def slideTime: Duration = parent.slideTime + override def slideDuration: Duration = parent.slideDuration override def compute(validTime: Time): Option[RDD[Array[T]]] = { parent.getOrCompute(validTime).map(_.glom()) diff --git a/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala index 50f0f45796..980ca5177e 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala @@ -7,7 +7,7 @@ abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContex override def dependencies = List() - override def slideTime: Duration = { + override def slideDuration: Duration = { if (ssc == null) throw new Exception("ssc is null") if (ssc.graph.batchDuration == null) throw new Exception("batchDuration is null") ssc.graph.batchDuration diff --git a/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala index e9ca668aa6..848afecfad 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala @@ -12,7 +12,7 @@ class MapPartitionedDStream[T: ClassManifest, U: ClassManifest]( override def dependencies = List(parent) - override def slideTime: Duration = parent.slideTime + override def slideDuration: Duration = parent.slideDuration override def compute(validTime: Time): Option[RDD[U]] = { parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc, preservePartitioning)) diff --git a/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala index ebc7d0698b..6055aa6a05 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala @@ -12,7 +12,7 @@ class MapValuedDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest]( override def dependencies = List(parent) - override def slideTime: Duration = parent.slideTime + override def slideDuration: Duration = parent.slideDuration override def compute(validTime: Time): Option[RDD[(K, U)]] = { parent.getOrCompute(validTime).map(_.mapValues[U](mapValueFunc)) diff --git a/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala index 3af8e7ab88..20818a0cab 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala @@ -11,7 +11,7 @@ class MappedDStream[T: ClassManifest, U: ClassManifest] ( override def dependencies = List(parent) - override def slideTime: Duration = parent.slideTime + override def slideDuration: Duration = parent.slideDuration override def compute(validTime: Time): Option[RDD[U]] = { parent.getOrCompute(validTime).map(_.map[U](mapFunc)) diff --git a/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala index a685a778ce..733d5c4a25 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -16,19 +16,19 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( parent: DStream[(K, V)], reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, - _windowTime: Duration, - _slideTime: Duration, + _windowDuration: Duration, + _slideDuration: Duration, partitioner: Partitioner ) extends DStream[(K,V)](parent.ssc) { - assert(_windowTime.isMultipleOf(parent.slideTime), - "The window duration of ReducedWindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" + assert(_windowDuration.isMultipleOf(parent.slideDuration), + "The window duration of ReducedWindowedDStream (" + _slideDuration + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")" ) - assert(_slideTime.isMultipleOf(parent.slideTime), - "The slide duration of ReducedWindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" + assert(_slideDuration.isMultipleOf(parent.slideDuration), + "The slide duration of ReducedWindowedDStream (" + _slideDuration + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")" ) // Reduce each batch of data using reduceByKey which will be further reduced by window @@ -39,15 +39,15 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( super.persist(StorageLevel.MEMORY_ONLY_SER) reducedStream.persist(StorageLevel.MEMORY_ONLY_SER) - def windowTime: Duration = _windowTime + def windowDuration: Duration = _windowDuration override def dependencies = List(reducedStream) - override def slideTime: Duration = _slideTime + override def slideDuration: Duration = _slideDuration override val mustCheckpoint = true - override def parentRememberDuration: Duration = rememberDuration + windowTime + override def parentRememberDuration: Duration = rememberDuration + windowDuration override def persist(storageLevel: StorageLevel): DStream[(K,V)] = { super.persist(storageLevel) @@ -66,11 +66,11 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( val invReduceF = invReduceFunc val currentTime = validTime - val currentWindow = new Interval(currentTime - windowTime + parent.slideTime, currentTime) - val previousWindow = currentWindow - slideTime + val currentWindow = new Interval(currentTime - windowDuration + parent.slideDuration, currentTime) + val previousWindow = currentWindow - slideDuration - logDebug("Window time = " + windowTime) - logDebug("Slide time = " + slideTime) + logDebug("Window time = " + windowDuration) + logDebug("Slide time = " + slideDuration) logDebug("ZeroTime = " + zeroTime) logDebug("Current window = " + currentWindow) logDebug("Previous window = " + previousWindow) @@ -87,11 +87,11 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( // // Get the RDDs of the reduced values in "old time steps" - val oldRDDs = reducedStream.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideTime) + val oldRDDs = reducedStream.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration) logDebug("# old RDDs = " + oldRDDs.size) // Get the RDDs of the reduced values in "new time steps" - val newRDDs = reducedStream.slice(previousWindow.endTime + parent.slideTime, currentWindow.endTime) + val newRDDs = reducedStream.slice(previousWindow.endTime + parent.slideDuration, currentWindow.endTime) logDebug("# new RDDs = " + newRDDs.size) // Get the RDD of the reduced value of the previous window diff --git a/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala index 7612804b96..1f9548bfb8 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala @@ -15,7 +15,7 @@ class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest]( override def dependencies = List(parent) - override def slideTime: Duration = parent.slideTime + override def slideDuration: Duration = parent.slideDuration override def compute(validTime: Time): Option[RDD[(K,C)]] = { parent.getOrCompute(validTime) match { diff --git a/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala index ce4f486825..a1ec2f5454 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala @@ -18,14 +18,14 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife override def dependencies = List(parent) - override def slideTime: Duration = parent.slideTime + override def slideDuration: Duration = parent.slideDuration override val mustCheckpoint = true override def compute(validTime: Time): Option[RDD[(K, S)]] = { // Try to get the previous state RDD - getOrCompute(validTime - slideTime) match { + getOrCompute(validTime - slideDuration) match { case Some(prevStateRDD) => { // If previous state RDD exists diff --git a/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala index 5a2c5bc0f0..99660d9dee 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala @@ -11,7 +11,7 @@ class TransformedDStream[T: ClassManifest, U: ClassManifest] ( override def dependencies = List(parent) - override def slideTime: Duration = parent.slideTime + override def slideDuration: Duration = parent.slideDuration override def compute(validTime: Time): Option[RDD[U]] = { parent.getOrCompute(validTime).map(transformFunc(_, validTime)) diff --git a/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala index 224a19842b..00bad5da34 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala @@ -17,13 +17,13 @@ class UnionDStream[T: ClassManifest](parents: Array[DStream[T]]) throw new IllegalArgumentException("Array of parents have different StreamingContexts") } - if (parents.map(_.slideTime).distinct.size > 1) { + if (parents.map(_.slideDuration).distinct.size > 1) { throw new IllegalArgumentException("Array of parents have different slide times") } override def dependencies = parents.toList - override def slideTime: Duration = parents.head.slideTime + override def slideDuration: Duration = parents.head.slideDuration override def compute(validTime: Time): Option[RDD[T]] = { val rdds = new ArrayBuffer[RDD[T]]() diff --git a/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala index 45689b25ce..cbf0c88108 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala @@ -8,30 +8,30 @@ import spark.streaming.{Duration, Interval, Time, DStream} private[streaming] class WindowedDStream[T: ClassManifest]( parent: DStream[T], - _windowTime: Duration, - _slideTime: Duration) + _windowDuration: Duration, + _slideDuration: Duration) extends DStream[T](parent.ssc) { - if (!_windowTime.isMultipleOf(parent.slideTime)) - throw new Exception("The window duration of WindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + if (!_windowDuration.isMultipleOf(parent.slideDuration)) + throw new Exception("The window duration of WindowedDStream (" + _slideDuration + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")") - if (!_slideTime.isMultipleOf(parent.slideTime)) - throw new Exception("The slide duration of WindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + if (!_slideDuration.isMultipleOf(parent.slideDuration)) + throw new Exception("The slide duration of WindowedDStream (" + _slideDuration + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")") parent.persist(StorageLevel.MEMORY_ONLY_SER) - def windowTime: Duration = _windowTime + def windowDuration: Duration = _windowDuration override def dependencies = List(parent) - override def slideTime: Duration = _slideTime + override def slideDuration: Duration = _slideDuration - override def parentRememberDuration: Duration = rememberDuration + windowTime + override def parentRememberDuration: Duration = rememberDuration + windowDuration override def compute(validTime: Time): Option[RDD[T]] = { - val currentWindow = new Interval(validTime - windowTime + parent.slideTime, validTime) + val currentWindow = new Interval(validTime - windowDuration + parent.slideDuration, validTime) Some(new UnionRDD(ssc.sc, parent.slice(currentWindow))) } } diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala index dc38ef4912..f9e03c607d 100644 --- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala @@ -196,18 +196,18 @@ class BasicOperationsSuite extends TestSuiteBase { // MappedStream should remember till 7 seconds: 10, 9, 8, 7, 6, 5, 4, 3, 2 // WindowedStream2 - assert(windowedStream2.generatedRDDs.contains(Seconds(10))) - assert(windowedStream2.generatedRDDs.contains(Seconds(8))) - assert(!windowedStream2.generatedRDDs.contains(Seconds(6))) + assert(windowedStream2.generatedRDDs.contains(Time(10000))) + assert(windowedStream2.generatedRDDs.contains(Time(8000))) + assert(!windowedStream2.generatedRDDs.contains(Time(6000))) // WindowedStream1 - assert(windowedStream1.generatedRDDs.contains(Seconds(10))) - assert(windowedStream1.generatedRDDs.contains(Seconds(4))) - assert(!windowedStream1.generatedRDDs.contains(Seconds(3))) + assert(windowedStream1.generatedRDDs.contains(Time(10000))) + assert(windowedStream1.generatedRDDs.contains(Time(4000))) + assert(!windowedStream1.generatedRDDs.contains(Time(3000))) // MappedStream - assert(mappedStream.generatedRDDs.contains(Seconds(10))) - assert(mappedStream.generatedRDDs.contains(Seconds(2))) - assert(!mappedStream.generatedRDDs.contains(Seconds(1))) + assert(mappedStream.generatedRDDs.contains(Time(10000))) + assert(mappedStream.generatedRDDs.contains(Time(2000))) + assert(!mappedStream.generatedRDDs.contains(Time(1000))) } } diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index 28bdd53c3c..a76f61d4ad 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -26,7 +26,7 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[ def compute(validTime: Time): Option[RDD[T]] = { logInfo("Computing RDD for time " + validTime) - val index = ((validTime - zeroTime) / slideTime - 1).toInt + val index = ((validTime - zeroTime) / slideDuration - 1).toInt val selectedInput = if (index < input.size) input(index) else Seq[T]() val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) logInfo("Created RDD " + rdd.id + " with " + selectedInput) diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index 4bc5229465..fa117cfcf0 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -207,11 +207,11 @@ class WindowOperationsSuite extends TestSuiteBase { test("groupByKeyAndWindow") { val input = bigInput val expectedOutput = bigGroupByOutput.map(_.map(x => (x._1, x._2.toSet))) - val windowTime = Seconds(2) - val slideTime = Seconds(1) - val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt + val windowDuration = Seconds(2) + val slideDuration = Seconds(1) + val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt val operation = (s: DStream[(String, Int)]) => { - s.groupByKeyAndWindow(windowTime, slideTime) + s.groupByKeyAndWindow(windowDuration, slideDuration) .map(x => (x._1, x._2.toSet)) .persist() } @@ -221,21 +221,21 @@ class WindowOperationsSuite extends TestSuiteBase { test("countByWindow") { val input = Seq(Seq(1), Seq(1), Seq(1, 2), Seq(0), Seq(), Seq() ) val expectedOutput = Seq( Seq(1), Seq(2), Seq(3), Seq(3), Seq(1), Seq(0)) - val windowTime = Seconds(2) - val slideTime = Seconds(1) - val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt - val operation = (s: DStream[Int]) => s.countByWindow(windowTime, slideTime) + val windowDuration = Seconds(2) + val slideDuration = Seconds(1) + val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt + val operation = (s: DStream[Int]) => s.countByWindow(windowDuration, slideDuration) testOperation(input, operation, expectedOutput, numBatches, true) } test("countByKeyAndWindow") { val input = Seq(Seq(("a", 1)), Seq(("b", 1), ("b", 2)), Seq(("a", 10), ("b", 20))) val expectedOutput = Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 2)), Seq(("a", 1), ("b", 3))) - val windowTime = Seconds(2) - val slideTime = Seconds(1) - val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt + val windowDuration = Seconds(2) + val slideDuration = Seconds(1) + val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt val operation = (s: DStream[(String, Int)]) => { - s.countByKeyAndWindow(windowTime, slideTime).map(x => (x._1, x._2.toInt)) + s.countByKeyAndWindow(windowDuration, slideDuration).map(x => (x._1, x._2.toInt)) } testOperation(input, operation, expectedOutput, numBatches, true) } @@ -247,12 +247,12 @@ class WindowOperationsSuite extends TestSuiteBase { name: String, input: Seq[Seq[Int]], expectedOutput: Seq[Seq[Int]], - windowTime: Time = Seconds(2), - slideTime: Time = Seconds(1) + windowDuration: Duration = Seconds(2), + slideDuration: Duration = Seconds(1) ) { test("window - " + name) { - val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt - val operation = (s: DStream[Int]) => s.window(windowTime, slideTime) + val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt + val operation = (s: DStream[Int]) => s.window(windowDuration, slideDuration) testOperation(input, operation, expectedOutput, numBatches, true) } } @@ -261,13 +261,13 @@ class WindowOperationsSuite extends TestSuiteBase { name: String, input: Seq[Seq[(String, Int)]], expectedOutput: Seq[Seq[(String, Int)]], - windowTime: Time = Seconds(2), - slideTime: Time = Seconds(1) + windowDuration: Duration = Seconds(2), + slideDuration: Duration = Seconds(1) ) { test("reduceByKeyAndWindow - " + name) { - val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt + val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt val operation = (s: DStream[(String, Int)]) => { - s.reduceByKeyAndWindow(_ + _, windowTime, slideTime).persist() + s.reduceByKeyAndWindow(_ + _, windowDuration, slideDuration).persist() } testOperation(input, operation, expectedOutput, numBatches, true) } @@ -277,13 +277,13 @@ class WindowOperationsSuite extends TestSuiteBase { name: String, input: Seq[Seq[(String, Int)]], expectedOutput: Seq[Seq[(String, Int)]], - windowTime: Time = Seconds(2), - slideTime: Time = Seconds(1) + windowDuration: Duration = Seconds(2), + slideDuration: Duration = Seconds(1) ) { test("reduceByKeyAndWindowInv - " + name) { - val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt + val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt val operation = (s: DStream[(String, Int)]) => { - s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime) + s.reduceByKeyAndWindow(_ + _, _ - _, windowDuration, slideDuration) .persist() .checkpoint(Seconds(100)) // Large value to avoid effect of RDD checkpointing } -- cgit v1.2.3 From e3861ae3953d7cab66160833688c8baf84e835ad Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Wed, 9 Jan 2013 17:03:25 -0600 Subject: Provide and expose a default Hadoop Configuration. Any "hadoop.*" system properties will be passed along into configuration. --- core/src/main/scala/spark/SparkContext.scala | 18 ++++++++++++++---- .../main/scala/spark/api/java/JavaSparkContext.scala | 7 +++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index bbf8272eb3..36e0938854 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -187,6 +187,18 @@ class SparkContext( private var dagScheduler = new DAGScheduler(taskScheduler) + /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ + val hadoopConfiguration = { + val conf = new Configuration() + // Copy any "hadoop.foo=bar" system properties into conf as "foo=bar" + for (key <- System.getProperties.keys.asInstanceOf[Set[String]] if key.startsWith("hadoop.")) { + conf.set(key.substring("hadoop.".length), System.getProperty(key)) + } + val bufferSize = System.getProperty("spark.buffer.size", "65536") + conf.set("io.file.buffer.size", bufferSize) + conf + } + // Methods for creating RDDs /** Distribute a local Scala collection to form an RDD. */ @@ -231,10 +243,8 @@ class SparkContext( valueClass: Class[V], minSplits: Int = defaultMinSplits ) : RDD[(K, V)] = { - val conf = new JobConf() + val conf = new JobConf(hadoopConfiguration) FileInputFormat.setInputPaths(conf, path) - val bufferSize = System.getProperty("spark.buffer.size", "65536") - conf.set("io.file.buffer.size", bufferSize) new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits) } @@ -276,7 +286,7 @@ class SparkContext( fm.erasure.asInstanceOf[Class[F]], km.erasure.asInstanceOf[Class[K]], vm.erasure.asInstanceOf[Class[V]], - new Configuration) + hadoopConfiguration) } /** diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index 88ab2846be..12e2a0bdac 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -355,6 +355,13 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def clearFiles() { sc.clearFiles() } + + /** + * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse. + */ + def hadoopConfiguration() { + sc.hadoopConfiguration + } } object JavaSparkContext { -- cgit v1.2.3 From 1a64432ba50904c3933d8a9539a619fc94b3b30b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 9 Jan 2013 20:30:36 -0800 Subject: Indicate success/failure in PySpark test script. --- python/run-tests | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/python/run-tests b/python/run-tests index da9e24cb1f..fcdd1e27a7 100755 --- a/python/run-tests +++ b/python/run-tests @@ -3,7 +3,24 @@ # Figure out where the Scala framework is installed FWDIR="$(cd `dirname $0`; cd ../; pwd)" +FAILED=0 + $FWDIR/pyspark pyspark/rdd.py +FAILED=$(($?||$FAILED)) + $FWDIR/pyspark -m doctest pyspark/broadcast.py +FAILED=$(($?||$FAILED)) + +if [[ $FAILED != 0 ]]; then + echo -en "\033[31m" # Red + echo "Had test failures; see logs." + echo -en "\033[0m" # No color + exit -1 +else + echo -en "\033[32m" # Green + echo "Tests passed." + echo -en "\033[0m" # No color +fi # TODO: in the long-run, it would be nice to use a test runner like `nose`. +# The doctest fixtures are the current barrier to doing this. -- cgit v1.2.3 From d55f2b98822faa7d71f5fce2bfa980f8265e0610 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 9 Jan 2013 21:21:23 -0800 Subject: Use take() instead of takeSample() in PySpark kmeans example. This is a temporary change until we port takeSample(). --- python/examples/kmeans.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/examples/kmeans.py b/python/examples/kmeans.py index ad2be21178..72cf9f88c6 100644 --- a/python/examples/kmeans.py +++ b/python/examples/kmeans.py @@ -33,7 +33,9 @@ if __name__ == "__main__": K = int(sys.argv[3]) convergeDist = float(sys.argv[4]) - kPoints = data.takeSample(False, K, 34) + # TODO: change this after we port takeSample() + #kPoints = data.takeSample(False, K, 34) + kPoints = data.take(K) tempDist = 1.0 while tempDist > convergeDist: -- cgit v1.2.3 From 9930a95d217045c4c22c2575080a03e4b0fd2426 Mon Sep 17 00:00:00 2001 From: shane-huang Date: Thu, 10 Jan 2013 20:09:34 +0800 Subject: Modified Patch according to comments --- core/src/main/scala/spark/network/Connection.scala | 8 ++++---- .../main/scala/spark/network/ConnectionManager.scala | 9 ++++----- .../scala/spark/network/ConnectionManagerTest.scala | 20 ++++++++++++++------ 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index 95096fd0ba..c193bf7c8d 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -136,10 +136,10 @@ extends Connection(SocketChannel.open, selector_) { if (chunk.isDefined) { messages += message // this is probably incorrect, it wont work as fifo if (!message.started) { - logDebug("Starting to send [" + message + "]") - message.started = true - message.startTime = System.currentTimeMillis - } + logDebug("Starting to send [" + message + "]") + message.started = true + message.startTime = System.currentTimeMillis + } return chunk } else { /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/ diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index e7bd2d3bbd..36c01ad629 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -43,12 +43,12 @@ private[spark] class ConnectionManager(port: Int) extends Logging { } val selector = SelectorProvider.provider.openSelector() - val handleMessageExecutor = Executors.newFixedThreadPool(20) + val handleMessageExecutor = Executors.newFixedThreadPool(System.getProperty("spark.core.connection.handler.threads","20").toInt) val serverChannel = ServerSocketChannel.open() val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] val messageStatuses = new HashMap[Int, MessageStatus] - val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] + val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] val sendMessageRequests = new Queue[(Message, SendingConnection)] @@ -78,9 +78,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging { def run() { try { - while(!selectorThread.isInterrupted) { + while(!selectorThread.isInterrupted) { for( (connectionManagerId, sendingConnection) <- connectionRequests) { - //val sendingConnection = connectionRequests.dequeue sendingConnection.connect() addConnection(sendingConnection) connectionRequests -= connectionManagerId @@ -465,7 +464,7 @@ private[spark] object ConnectionManager { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) }).foreach(f => { - val g = Await.result(f, 10 second) + val g = Await.result(f, 1 second) if (!g.isDefined) println("Failed") }) val finishTime = System.currentTimeMillis diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala index 0e79c518e0..533e4610f3 100644 --- a/core/src/main/scala/spark/network/ConnectionManagerTest.scala +++ b/core/src/main/scala/spark/network/ConnectionManagerTest.scala @@ -13,8 +13,14 @@ import akka.util.duration._ private[spark] object ConnectionManagerTest extends Logging{ def main(args: Array[String]) { - if (args.length < 5) { - println("Usage: ConnectionManagerTest ") + // - the master URL + // - a list slaves to run connectionTest on + //[num of tasks] - the number of parallel tasks to be initiated default is number of slave hosts + //[size of msg in MB (integer)] - the size of messages to be sent in each task, default is 10 + //[count] - how many times to run, default is 3 + //[await time in seconds] : await time (in seconds), default is 600 + if (args.length < 2) { + println("Usage: ConnectionManagerTest [num of tasks] [size of msg in MB (integer)] [count] [await time in seconds)] ") System.exit(1) } @@ -29,14 +35,17 @@ private[spark] object ConnectionManagerTest extends Logging{ /*println("Slaves")*/ /*slaves.foreach(println)*/ - val tasknum = args(2).toInt + val tasknum = if (args.length > 2) args(2).toInt else slaves.length + val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024 + val count = if (args.length > 4) args(4).toInt else 3 + val awaitTime = (if (args.length > 5) args(5).toInt else 600 ).second + println("Running "+count+" rounds of test: " + "parallel tasks = " + tasknum + ", msg size = " + size/1024/1024 + " MB, awaitTime = " + awaitTime) val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map( i => SparkEnv.get.connectionManager.id).collect() println("\nSlave ConnectionManagerIds") slaveConnManagerIds.foreach(println) println - val count = args(4).toInt (0 until count).foreach(i => { val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => { val connManager = SparkEnv.get.connectionManager @@ -46,7 +55,6 @@ private[spark] object ConnectionManagerTest extends Logging{ None }) - val size = (args(3).toInt) * 1024 * 1024 val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) buffer.flip @@ -56,7 +64,7 @@ private[spark] object ConnectionManagerTest extends Logging{ logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) }) - val results = futures.map(f => Await.result(f, 999.second)) + val results = futures.map(f => Await.result(f, awaitTime)) val finishTime = System.currentTimeMillis Thread.sleep(5000) -- cgit v1.2.3 From 49c74ba2af2ab6fe5eda16dbcd35b30b46072a3a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 10 Jan 2013 07:45:12 -0800 Subject: Change PYSPARK_PYTHON_EXEC to PYSPARK_PYTHON. --- python/pyspark/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 4439356c1f..e486f206b0 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -52,7 +52,7 @@ class SparkContext(object): self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome, empty_string_array) - self.pythonExec = os.environ.get("PYSPARK_PYTHON_EXEC", 'python') + self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') # Broadcast's __reduce__ method stores Broadcast instances here. # This allows other code to determine which Broadcast instances have # been pickled, so it can determine which Java broadcast objects to -- cgit v1.2.3 From b15e8512793475eaeda7225a259db8aacd600741 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Thu, 10 Jan 2013 10:55:41 -0600 Subject: Check for AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY environment variables. For custom properties, use "spark.hadoop.*" as a prefix instead of just "hadoop.*". --- core/src/main/scala/spark/SparkContext.scala | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 36e0938854..7b11955f1e 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -190,9 +190,16 @@ class SparkContext( /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration = { val conf = new Configuration() - // Copy any "hadoop.foo=bar" system properties into conf as "foo=bar" - for (key <- System.getProperties.keys.asInstanceOf[Set[String]] if key.startsWith("hadoop.")) { - conf.set(key.substring("hadoop.".length), System.getProperty(key)) + // Explicitly check for S3 environment variables + if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) { + conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) + conf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) + conf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) + conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) + } + // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" + for (key <- System.getProperties.keys.asInstanceOf[Set[String]] if key.startsWith("spark.hadoop.")) { + conf.set(key.substring("spark.hadoop.".length), System.getProperty(key)) } val bufferSize = System.getProperty("spark.buffer.size", "65536") conf.set("io.file.buffer.size", bufferSize) -- cgit v1.2.3 From d1864052c58ff1e58980729f7ccf00e630f815b9 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Thu, 10 Jan 2013 12:16:26 -0600 Subject: Fix invalid asInstanceOf cast. --- core/src/main/scala/spark/SparkContext.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 7b11955f1e..d2a5b4757a 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -7,6 +7,7 @@ import java.net.{URI, URLClassLoader} import scala.collection.Map import scala.collection.generic.Growable import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.JavaConversions._ import akka.actor.Actor import akka.actor.Actor._ @@ -198,7 +199,7 @@ class SparkContext( conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) } // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" - for (key <- System.getProperties.keys.asInstanceOf[Set[String]] if key.startsWith("spark.hadoop.")) { + for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) { conf.set(key.substring("spark.hadoop.".length), System.getProperty(key)) } val bufferSize = System.getProperty("spark.buffer.size", "65536") -- cgit v1.2.3 From bd336f5f406386c929f2d1f9aecd7d5190a1a087 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 10 Jan 2013 17:13:04 -0800 Subject: Changed CoGroupRDD's hash map from Scala to Java. --- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index de0d9fad88..2e051c81c8 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -1,7 +1,8 @@ package spark.rdd +import java.util.{HashMap => JHashMap} import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap +import scala.collection.JavaConversions._ import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext} import spark.{Dependency, OneToOneDependency, ShuffleDependency} @@ -71,7 +72,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = { val split = s.asInstanceOf[CoGroupSplit] val numRdds = split.deps.size - val map = new HashMap[K, Seq[ArrayBuffer[Any]]] + val map = new JHashMap[K, Seq[ArrayBuffer[Any]]] def getSeq(k: K): Seq[ArrayBuffer[Any]] = { map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any])) } -- cgit v1.2.3 From 2e914d99835487e867cac6add8be1dbd80dc693f Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 10 Jan 2013 19:13:08 -0800 Subject: Formatting --- core/src/main/scala/spark/deploy/master/MasterWebUI.scala | 5 +++-- core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index a96b55d6f3..580014ef3f 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -8,11 +8,12 @@ import akka.util.duration._ import cc.spray.Directives import cc.spray.directives._ import cc.spray.typeconversion.TwirlSupport._ -import spark.deploy._ import cc.spray.http.MediaTypes -import JsonProtocol._ import cc.spray.typeconversion.SprayJsonSupport._ +import spark.deploy._ +import spark.deploy.JsonProtocol._ + private[spark] class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives { val RESOURCE_DIR = "spark/deploy/master/webui" diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala index 84b6c16bd6..f9489d99fc 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala @@ -7,11 +7,12 @@ import akka.util.Timeout import akka.util.duration._ import cc.spray.Directives import cc.spray.typeconversion.TwirlSupport._ -import spark.deploy.{JsonProtocol, WorkerState, RequestWorkerState} import cc.spray.http.MediaTypes -import JsonProtocol._ import cc.spray.typeconversion.SprayJsonSupport._ +import spark.deploy.{WorkerState, RequestWorkerState} +import spark.deploy.JsonProtocol._ + private[spark] class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives { val RESOURCE_DIR = "spark/deploy/worker/webui" -- cgit v1.2.3 From 92625223066a5c28553d7710c6b14af56f64b560 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Thu, 10 Jan 2013 22:07:34 -0800 Subject: Activate hadoop2 profile in pom.xml with -Dhadoop=2 --- bagel/pom.xml | 6 ++++++ core/pom.xml | 6 ++++++ examples/pom.xml | 6 ++++++ pom.xml | 6 ++++++ repl-bin/pom.xml | 6 ++++++ repl/pom.xml | 6 ++++++ 6 files changed, 36 insertions(+) diff --git a/bagel/pom.xml b/bagel/pom.xml index 85b2077026..c3461fb889 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -77,6 +77,12 @@ hadoop2 + + + hadoop + 2 + + org.spark-project diff --git a/core/pom.xml b/core/pom.xml index 005d8fe498..c8ff625774 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -216,6 +216,12 @@ hadoop2 + + + hadoop + 2 + + org.apache.hadoop diff --git a/examples/pom.xml b/examples/pom.xml index 3f738a3f8c..d0b1e97747 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -77,6 +77,12 @@ hadoop2 + + + hadoop + 2 + + org.spark-project diff --git a/pom.xml b/pom.xml index ea5b9c9d05..ae87813d4e 100644 --- a/pom.xml +++ b/pom.xml @@ -502,6 +502,12 @@ hadoop2 + + + hadoop + 2 + + 2 diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index fecb01f3cd..54ae20659e 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -115,6 +115,12 @@ hadoop2 + + + hadoop + 2 + + hadoop2 diff --git a/repl/pom.xml b/repl/pom.xml index 04b2c35beb..3e979b93a6 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -121,6 +121,12 @@ hadoop2 + + + hadoop + 2 + + hadoop2 -- cgit v1.2.3 From 3e6519a36e354f3623c5b968efe5217c7fcb242f Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Fri, 11 Jan 2013 11:24:20 -0600 Subject: Use hadoopConfiguration for default JobConf in PairRDDFunctions. --- core/src/main/scala/spark/PairRDDFunctions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index ce48cea903..51c15837c4 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -557,7 +557,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: OutputFormat[_, _]], - conf: JobConf = new JobConf) { + conf: JobConf = new JobConf(self.context.hadoopConfiguration)) { conf.setOutputKeyClass(keyClass) conf.setOutputValueClass(valueClass) // conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug -- cgit v1.2.3 From 5c7a1272198c88a90a843bbda0c1424f92b7c12e Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Fri, 11 Jan 2013 11:25:11 -0600 Subject: Pass a new Configuration that wraps the default hadoopConfiguration. --- core/src/main/scala/spark/SparkContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index d2a5b4757a..f6b98c41bc 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -294,7 +294,7 @@ class SparkContext( fm.erasure.asInstanceOf[Class[F]], km.erasure.asInstanceOf[Class[K]], vm.erasure.asInstanceOf[Class[V]], - hadoopConfiguration) + new Configuration(hadoopConfiguration)) } /** -- cgit v1.2.3 From 480c4139bbd2711e99f3a819c9ef164d8b3dcac0 Mon Sep 17 00:00:00 2001 From: Michael Heuer Date: Fri, 11 Jan 2013 11:24:48 -0600 Subject: add repositories section to simple job pom.xml --- docs/quick-start.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/quick-start.md b/docs/quick-start.md index 177cb14551..d46dc2da3f 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -200,6 +200,16 @@ To build the job, we also write a Maven `pom.xml` file that lists Spark as a dep Simple Project jar 1.0 + + + Spray.cc repository + http://repo.spray.cc + + + Typesafe repository + http://repo.typesafe.com/typesafe/releases + + org.spark-project -- cgit v1.2.3 From c063e8777ebaeb04056889064e9264edc019edbd Mon Sep 17 00:00:00 2001 From: Tyson Date: Fri, 11 Jan 2013 14:57:38 -0500 Subject: Added implicit json writers for JobDescription and ExecutorRunner --- .../src/main/scala/spark/deploy/JsonProtocol.scala | 23 +++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala index f14f804b3a..732fa08064 100644 --- a/core/src/main/scala/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala @@ -1,6 +1,7 @@ package spark.deploy import master.{JobInfo, WorkerInfo} +import worker.ExecutorRunner import cc.spray.json._ /** @@ -30,6 +31,24 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol { "submitdate" -> JsString(obj.submitDate.toString)) } + implicit object JobDescriptionJsonFormat extends RootJsonWriter[JobDescription] { + def write(obj: JobDescription) = JsObject( + "name" -> JsString(obj.name), + "cores" -> JsNumber(obj.cores), + "memoryperslave" -> JsNumber(obj.memoryPerSlave), + "user" -> JsString(obj.user) + ) + } + + implicit object ExecutorRunnerJsonFormat extends RootJsonWriter[ExecutorRunner] { + def write(obj: ExecutorRunner) = JsObject( + "id" -> JsNumber(obj.execId), + "memory" -> JsNumber(obj.memory), + "jobid" -> JsString(obj.jobId), + "jobdesc" -> obj.jobDesc.toJson.asJsObject + ) + } + implicit object MasterStateJsonFormat extends RootJsonWriter[MasterState] { def write(obj: MasterState) = JsObject( "url" -> JsString("spark://" + obj.uri), @@ -51,7 +70,9 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol { "cores" -> JsNumber(obj.cores), "coresused" -> JsNumber(obj.coresUsed), "memory" -> JsNumber(obj.memory), - "memoryused" -> JsNumber(obj.memoryUsed) + "memoryused" -> JsNumber(obj.memoryUsed), + "executors" -> JsArray(obj.executors.toList.map(_.toJson)), + "finishedexecutors" -> JsArray(obj.finishedExecutors.toList.map(_.toJson)) ) } } -- cgit v1.2.3 From 1731f1fed4f1369662b1a9fde850a3dcba738a59 Mon Sep 17 00:00:00 2001 From: Tyson Date: Fri, 11 Jan 2013 15:01:43 -0500 Subject: Added an optional format parameter for individual job queries and optimized the jobId query --- .../scala/spark/deploy/master/MasterWebUI.scala | 38 +++++++++++++++------- 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index 580014ef3f..458ee2d665 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -38,20 +38,36 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct } } ~ path("job") { - parameter("jobId") { jobId => - completeWith { + parameters("jobId", 'format ?) { + case (jobId, Some(js)) if (js.equalsIgnoreCase("json")) => val future = master ? RequestMasterState - future.map { state => - val masterState = state.asInstanceOf[MasterState] - - // A bit ugly an inefficient, but we won't have a number of jobs - // so large that it will make a significant difference. - (masterState.activeJobs ++ masterState.completedJobs).find(_.id == jobId) match { - case Some(job) => spark.deploy.master.html.job_details.render(job) - case _ => null + val jobInfo = for (masterState <- future.mapTo[MasterState]) yield { + masterState.activeJobs.find(_.id == jobId) match { + case Some(job) => job + case _ => masterState.completedJobs.find(_.id == jobId) match { + case Some(job) => job + case _ => null + } + } + } + respondWithMediaType(MediaTypes.`application/json`) { ctx => + ctx.complete(jobInfo.mapTo[JobInfo]) + } + case (jobId, _) => + completeWith { + val future = master ? RequestMasterState + future.map { state => + val masterState = state.asInstanceOf[MasterState] + + masterState.activeJobs.find(_.id == jobId) match { + case Some(job) => spark.deploy.master.html.job_details.render(job) + case _ => masterState.completedJobs.find(_.id == jobId) match { + case Some(job) => spark.deploy.master.html.job_details.render(job) + case _ => null + } + } } } - } } } ~ pathPrefix("static") { -- cgit v1.2.3 From 22445fbea9ed1575e49a1f9bb2251d98a57b9e4e Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Fri, 11 Jan 2013 13:30:49 -0800 Subject: attempt to sleep for more accurate time period, minor cleanup --- .../scala/spark/util/RateLimitedOutputStream.scala | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala index d11ed163ce..3050213709 100644 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -1,8 +1,10 @@ package spark.util import java.io.OutputStream +import java.util.concurrent.TimeUnit._ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream { + val SyncIntervalNs = NANOSECONDS.convert(10, SECONDS) var lastSyncTime = System.nanoTime() var bytesWrittenSinceSync: Long = 0 @@ -28,20 +30,21 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu def waitToWrite(numBytes: Int) { while (true) { - val now = System.nanoTime() - val elapsed = math.max(now - lastSyncTime, 1) - val rate = bytesWrittenSinceSync.toDouble / (elapsed / 1.0e9) + val now = System.nanoTime + val elapsedSecs = SECONDS.convert(max(now - lastSyncTime, 1), NANOSECONDS) + val rate = bytesWrittenSinceSync.toDouble / elapsedSecs if (rate < bytesPerSec) { // It's okay to write; just update some variables and return bytesWrittenSinceSync += numBytes - if (now > lastSyncTime + (1e10).toLong) { - // Ten seconds have passed since lastSyncTime; let's resync + if (now > lastSyncTime + SyncIntervalNs) { + // Sync interval has passed; let's resync lastSyncTime = now bytesWrittenSinceSync = numBytes } - return } else { - Thread.sleep(5) + // Calculate how much time we should sleep to bring ourselves to the desired rate. + val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS) + if (sleepTime > 0) Thread.sleep(sleepTime) } } } @@ -53,4 +56,4 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu override def close() { out.close() } -} \ No newline at end of file +} -- cgit v1.2.3 From ff10b3aa0970cc7224adc6bc73d99a7ffa30219f Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Fri, 11 Jan 2013 21:03:57 -0800 Subject: add missing return --- core/src/main/scala/spark/util/RateLimitedOutputStream.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala index 3050213709..ed459c2544 100644 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -41,6 +41,7 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu lastSyncTime = now bytesWrittenSinceSync = numBytes } + return } else { // Calculate how much time we should sleep to bring ourselves to the desired rate. val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS) -- cgit v1.2.3 From 0cfea7a2ec467717fbe110f9b15163bea2719575 Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Fri, 11 Jan 2013 23:48:07 -0800 Subject: add unit test --- .../scala/spark/util/RateLimitedOutputStream.scala | 2 +- .../spark/util/RateLimitedOutputStreamSuite.scala | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala index ed459c2544..16db7549b2 100644 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -31,7 +31,7 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu def waitToWrite(numBytes: Int) { while (true) { val now = System.nanoTime - val elapsedSecs = SECONDS.convert(max(now - lastSyncTime, 1), NANOSECONDS) + val elapsedSecs = SECONDS.convert(math.max(now - lastSyncTime, 1), NANOSECONDS) val rate = bytesWrittenSinceSync.toDouble / elapsedSecs if (rate < bytesPerSec) { // It's okay to write; just update some variables and return diff --git a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala new file mode 100644 index 0000000000..1dc45e0433 --- /dev/null +++ b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala @@ -0,0 +1,22 @@ +package spark.util + +import org.scalatest.FunSuite +import java.io.ByteArrayOutputStream +import java.util.concurrent.TimeUnit._ + +class RateLimitedOutputStreamSuite extends FunSuite { + + private def benchmark[U](f: => U): Long = { + val start = System.nanoTime + f + System.nanoTime - start + } + + test("write") { + val underlying = new ByteArrayOutputStream + val data = "X" * 1000 + val stream = new RateLimitedOutputStream(underlying, 100) + val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) } + assert(SECONDS.convert(elapsedNs, NANOSECONDS) == 4) + } +} -- cgit v1.2.3 From 2c77eeebb66a3d1337d45b5001be2b48724f9fd5 Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Sat, 12 Jan 2013 00:13:45 -0800 Subject: correct test params --- core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala index 1dc45e0433..b392075482 100644 --- a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala +++ b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala @@ -14,8 +14,8 @@ class RateLimitedOutputStreamSuite extends FunSuite { test("write") { val underlying = new ByteArrayOutputStream - val data = "X" * 1000 - val stream = new RateLimitedOutputStream(underlying, 100) + val data = "X" * 41000 + val stream = new RateLimitedOutputStream(underlying, 10000) val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) } assert(SECONDS.convert(elapsedNs, NANOSECONDS) == 4) } -- cgit v1.2.3 From ea20ae661888d871f70d5ed322cfe924c5a31dba Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Sat, 12 Jan 2013 09:18:00 -0800 Subject: add one extra test --- core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala index b392075482..794063fb6d 100644 --- a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala +++ b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala @@ -18,5 +18,6 @@ class RateLimitedOutputStreamSuite extends FunSuite { val stream = new RateLimitedOutputStream(underlying, 10000) val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) } assert(SECONDS.convert(elapsedNs, NANOSECONDS) == 4) + assert(underlying.toString("UTF-8") == data) } } -- cgit v1.2.3 From addff2c466d4b76043e612d4d28ab9de7f003298 Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Sat, 12 Jan 2013 09:57:29 -0800 Subject: add comment --- core/src/main/scala/spark/util/RateLimitedOutputStream.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala index 16db7549b2..ed3d2b66bb 100644 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -44,6 +44,7 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu return } else { // Calculate how much time we should sleep to bring ourselves to the desired rate. + // Based on throttler in Kafka (https://github.com/kafka-dev/kafka/blob/master/core/src/main/scala/kafka/utils/Throttler.scala) val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS) if (sleepTime > 0) Thread.sleep(sleepTime) } -- cgit v1.2.3 From bbc56d85ed4eb4c3a09b20d5457f704f4b8a70c4 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sat, 12 Jan 2013 15:24:13 -0800 Subject: Rename environment variable for hadoop profiles to hadoopVersion --- bagel/pom.xml | 4 ++-- core/pom.xml | 4 ++-- examples/pom.xml | 4 ++-- pom.xml | 5 +++-- repl-bin/pom.xml | 4 ++-- repl/pom.xml | 4 ++-- 6 files changed, 13 insertions(+), 12 deletions(-) diff --git a/bagel/pom.xml b/bagel/pom.xml index c3461fb889..5f58347204 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -47,7 +47,7 @@ hadoop1 - !hadoop + !hadoopVersion @@ -79,7 +79,7 @@ hadoop2 - hadoop + hadoopVersion 2 diff --git a/core/pom.xml b/core/pom.xml index c8ff625774..ad9fdcde2c 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -161,7 +161,7 @@ hadoop1 - !hadoop + !hadoopVersion @@ -218,7 +218,7 @@ hadoop2 - hadoop + hadoopVersion 2 diff --git a/examples/pom.xml b/examples/pom.xml index d0b1e97747..3355deb6b7 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -47,7 +47,7 @@ hadoop1 - !hadoop + !hadoopVersion @@ -79,7 +79,7 @@ hadoop2 - hadoop + hadoopVersion 2 diff --git a/pom.xml b/pom.xml index ae87813d4e..8f1af673a3 100644 --- a/pom.xml +++ b/pom.xml @@ -483,9 +483,10 @@ hadoop1 - !hadoop + !hadoopVersion + 1 @@ -504,7 +505,7 @@ hadoop2 - hadoop + hadoopVersion 2 diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index 54ae20659e..da91c0f3ab 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -72,7 +72,7 @@ hadoop1 - !hadoop + !hadoopVersion @@ -117,7 +117,7 @@ hadoop2 - hadoop + hadoopVersion 2 diff --git a/repl/pom.xml b/repl/pom.xml index 3e979b93a6..38e883c7f8 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -74,7 +74,7 @@ hadoop1 - !hadoop + !hadoopVersion @@ -123,7 +123,7 @@ hadoop2 - hadoop + hadoopVersion 2 -- cgit v1.2.3 From ba06e9c97cc3f8723ffdc3895182c529d3bb2fb3 Mon Sep 17 00:00:00 2001 From: Eric Zhang Date: Sun, 13 Jan 2013 15:33:11 +0800 Subject: Update examples/src/main/scala/spark/examples/LocalLR.scala MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix spelling mistake --- examples/src/main/scala/spark/examples/LocalLR.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/scala/spark/examples/LocalLR.scala b/examples/src/main/scala/spark/examples/LocalLR.scala index f2ac2b3e06..9553162004 100644 --- a/examples/src/main/scala/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/spark/examples/LocalLR.scala @@ -5,7 +5,7 @@ import spark.util.Vector object LocalLR { val N = 10000 // Number of data points - val D = 10 // Numer of dimensions + val D = 10 // Number of dimensions val R = 0.7 // Scaling factor val ITERATIONS = 5 val rand = new Random(42) -- cgit v1.2.3 From 88d8f11365db84d46ff456495c07f664c91d1896 Mon Sep 17 00:00:00 2001 From: Mikhail Bautin Date: Sun, 13 Jan 2013 00:45:52 -0800 Subject: Add missing dependency spray-json to Maven build --- core/pom.xml | 4 ++++ pom.xml | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/core/pom.xml b/core/pom.xml index ad9fdcde2c..862d3ec37a 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -71,6 +71,10 @@ cc.spray spray-server + + cc.spray + spray-json_${scala.version} + org.tomdz.twirl twirl-api diff --git a/pom.xml b/pom.xml index 8f1af673a3..751189a9d8 100644 --- a/pom.xml +++ b/pom.xml @@ -54,6 +54,7 @@ 0.9.0-incubating 2.0.3 1.0-M2.1 + 1.1.1 1.6.1 4.1.2 @@ -222,6 +223,11 @@ spray-server ${spray.version} + + cc.spray + spray-json_${scala.version} + ${spray.json.version} + org.tomdz.twirl twirl-api -- cgit v1.2.3 From 2305a2c1d91273a93ee6b571b0cd4bcaa1b2969d Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Sun, 13 Jan 2013 10:01:56 -0800 Subject: more code cleanup --- .../scala/spark/util/RateLimitedOutputStream.scala | 63 +++++++++++----------- 1 file changed, 32 insertions(+), 31 deletions(-) diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala index ed3d2b66bb..10790a9eee 100644 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -1,11 +1,14 @@ package spark.util +import scala.annotation.tailrec + import java.io.OutputStream import java.util.concurrent.TimeUnit._ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream { val SyncIntervalNs = NANOSECONDS.convert(10, SECONDS) - var lastSyncTime = System.nanoTime() + val ChunkSize = 8192 + var lastSyncTime = System.nanoTime var bytesWrittenSinceSync: Long = 0 override def write(b: Int) { @@ -17,37 +20,13 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu write(bytes, 0, bytes.length) } - override def write(bytes: Array[Byte], offset: Int, length: Int) { - val CHUNK_SIZE = 8192 - var pos = 0 - while (pos < length) { - val writeSize = math.min(length - pos, CHUNK_SIZE) + @tailrec + override final def write(bytes: Array[Byte], offset: Int, length: Int) { + val writeSize = math.min(length - offset, ChunkSize) + if (writeSize > 0) { waitToWrite(writeSize) - out.write(bytes, offset + pos, writeSize) - pos += writeSize - } - } - - def waitToWrite(numBytes: Int) { - while (true) { - val now = System.nanoTime - val elapsedSecs = SECONDS.convert(math.max(now - lastSyncTime, 1), NANOSECONDS) - val rate = bytesWrittenSinceSync.toDouble / elapsedSecs - if (rate < bytesPerSec) { - // It's okay to write; just update some variables and return - bytesWrittenSinceSync += numBytes - if (now > lastSyncTime + SyncIntervalNs) { - // Sync interval has passed; let's resync - lastSyncTime = now - bytesWrittenSinceSync = numBytes - } - return - } else { - // Calculate how much time we should sleep to bring ourselves to the desired rate. - // Based on throttler in Kafka (https://github.com/kafka-dev/kafka/blob/master/core/src/main/scala/kafka/utils/Throttler.scala) - val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS) - if (sleepTime > 0) Thread.sleep(sleepTime) - } + out.write(bytes, offset, writeSize) + write(bytes, offset + writeSize, length) } } @@ -58,4 +37,26 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu override def close() { out.close() } + + @tailrec + private def waitToWrite(numBytes: Int) { + val now = System.nanoTime + val elapsedSecs = SECONDS.convert(math.max(now - lastSyncTime, 1), NANOSECONDS) + val rate = bytesWrittenSinceSync.toDouble / elapsedSecs + if (rate < bytesPerSec) { + // It's okay to write; just update some variables and return + bytesWrittenSinceSync += numBytes + if (now > lastSyncTime + SyncIntervalNs) { + // Sync interval has passed; let's resync + lastSyncTime = now + bytesWrittenSinceSync = numBytes + } + } else { + // Calculate how much time we should sleep to bring ourselves to the desired rate. + // Based on throttler in Kafka (https://github.com/kafka-dev/kafka/blob/master/core/src/main/scala/kafka/utils/Throttler.scala) + val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS) + if (sleepTime > 0) Thread.sleep(sleepTime) + waitToWrite(numBytes) + } + } } -- cgit v1.2.3 From c31931af7eb01fbe2bb276bb6f428248128832b0 Mon Sep 17 00:00:00 2001 From: Ryan LeCompte Date: Sun, 13 Jan 2013 10:39:47 -0800 Subject: switch to uppercase constants --- core/src/main/scala/spark/util/RateLimitedOutputStream.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala index 10790a9eee..e3f00ea8c7 100644 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -6,8 +6,8 @@ import java.io.OutputStream import java.util.concurrent.TimeUnit._ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream { - val SyncIntervalNs = NANOSECONDS.convert(10, SECONDS) - val ChunkSize = 8192 + val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS) + val CHUNK_SIZE = 8192 var lastSyncTime = System.nanoTime var bytesWrittenSinceSync: Long = 0 @@ -22,7 +22,7 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu @tailrec override final def write(bytes: Array[Byte], offset: Int, length: Int) { - val writeSize = math.min(length - offset, ChunkSize) + val writeSize = math.min(length - offset, CHUNK_SIZE) if (writeSize > 0) { waitToWrite(writeSize) out.write(bytes, offset, writeSize) @@ -46,7 +46,7 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu if (rate < bytesPerSec) { // It's okay to write; just update some variables and return bytesWrittenSinceSync += numBytes - if (now > lastSyncTime + SyncIntervalNs) { + if (now > lastSyncTime + SYNC_INTERVAL) { // Sync interval has passed; let's resync lastSyncTime = now bytesWrittenSinceSync = numBytes -- cgit v1.2.3 From be7166146bf5692369272b85622d5316eccfd8e6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 13 Jan 2013 15:27:28 -0800 Subject: Removed the use of getOrElse to avoid Scala wrapper for every call. --- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 2e051c81c8..ce5f171911 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -1,8 +1,8 @@ package spark.rdd import java.util.{HashMap => JHashMap} +import scala.collection.JavaConversions import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext} import spark.{Dependency, OneToOneDependency, ShuffleDependency} @@ -74,7 +74,14 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) val numRdds = split.deps.size val map = new JHashMap[K, Seq[ArrayBuffer[Any]]] def getSeq(k: K): Seq[ArrayBuffer[Any]] = { - map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any])) + val seq = map.get(k) + if (seq != null) { + seq + } else { + val seq = Array.fill(numRdds)(new ArrayBuffer[Any]) + map.put(k, seq) + seq + } } for ((dep, depNum) <- split.deps.zipWithIndex) dep match { case NarrowCoGroupSplitDep(rdd, itsSplit) => { @@ -94,6 +101,6 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) fetcher.fetch[K, Seq[Any]](shuffleId, split.index).foreach(mergePair) } } - map.iterator + JavaConversions.mapAsScalaMap(map).iterator } } -- cgit v1.2.3 From 0a2e33334125cb3ae5e54f8333ea5608779399fc Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 13 Jan 2013 16:18:39 -0800 Subject: Removed stream id from the constructor of NetworkReceiver to make it easier for PluggableNetworkInputDStream. --- .../examples/twitter/TwitterInputDStream.scala | 15 ++++---- .../spark/streaming/NetworkInputTracker.scala | 34 +++++++++++++----- .../streaming/dstream/FlumeInputDStream.scala | 5 ++- .../streaming/dstream/KafkaInputDStream.scala | 6 ++-- .../streaming/dstream/NetworkInputDStream.scala | 42 ++++++++++++++++------ .../spark/streaming/dstream/RawInputDStream.scala | 6 ++-- .../streaming/dstream/SocketInputDStream.scala | 5 ++- 7 files changed, 76 insertions(+), 37 deletions(-) diff --git a/examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala b/examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala index c7e4855f3b..99ed4cdc1c 100644 --- a/examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala +++ b/examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala @@ -23,16 +23,17 @@ class TwitterInputDStream( ) extends NetworkInputDStream[Status](ssc_) { override def createReceiver(): NetworkReceiver[Status] = { - new TwitterReceiver(id, username, password, filters, storageLevel) + new TwitterReceiver(username, password, filters, storageLevel) } } -class TwitterReceiver(streamId: Int, - username: String, - password: String, - filters: Seq[String], - storageLevel: StorageLevel - ) extends NetworkReceiver[Status](streamId) { +class TwitterReceiver( + username: String, + password: String, + filters: Seq[String], + storageLevel: StorageLevel + ) extends NetworkReceiver[Status] { + var twitterStream: TwitterStream = _ lazy val blockGenerator = new BlockGenerator(storageLevel) diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index a6ab44271f..e4152f3a61 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -18,7 +18,10 @@ private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: Act private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage - +/** + * This class manages the execution of the receivers of NetworkInputDStreams. + */ +private[streaming] class NetworkInputTracker( @transient ssc: StreamingContext, @transient networkInputStreams: Array[NetworkInputDStream[_]]) @@ -32,16 +35,20 @@ class NetworkInputTracker( var currentTime: Time = null + /** Start the actor and receiver execution thread. */ def start() { ssc.env.actorSystem.actorOf(Props(new NetworkInputTrackerActor), "NetworkInputTracker") receiverExecutor.start() } + /** Stop the receiver execution thread. */ def stop() { + // TODO: stop the actor as well receiverExecutor.interrupt() receiverExecutor.stopReceivers() } + /** Return all the blocks received from a receiver. */ def getBlockIds(receiverId: Int, time: Time): Array[String] = synchronized { val queue = receivedBlockIds.synchronized { receivedBlockIds.getOrElse(receiverId, new Queue[String]()) @@ -53,6 +60,7 @@ class NetworkInputTracker( result.toArray } + /** Actor to receive messages from the receivers. */ private class NetworkInputTrackerActor extends Actor { def receive = { case RegisterReceiver(streamId, receiverActor) => { @@ -83,7 +91,8 @@ class NetworkInputTracker( } } } - + + /** This thread class runs all the receivers on the cluster. */ class ReceiverExecutor extends Thread { val env = ssc.env @@ -97,13 +106,22 @@ class NetworkInputTracker( stopReceivers() } } - + + /** + * Get the receivers from the NetworkInputDStreams, distributes them to the + * worker nodes as a parallel collection, and runs them. + */ def startReceivers() { - val receivers = networkInputStreams.map(_.createReceiver()) + val receivers = networkInputStreams.map(nis => { + val rcvr = nis.createReceiver() + rcvr.setStreamId(nis.id) + rcvr + }) // Right now, we only honor preferences if all receivers have them val hasLocationPreferences = receivers.map(_.getLocationPreference().isDefined).reduce(_ && _) + // Create the parallel collection of receivers to distributed them on the worker nodes val tempRDD = if (hasLocationPreferences) { val receiversWithPreferences = receivers.map(r => (r, Seq(r.getLocationPreference().toString))) @@ -113,21 +131,21 @@ class NetworkInputTracker( ssc.sc.makeRDD(receivers, receivers.size) } + // Function to start the receiver on the worker node val startReceiver = (iterator: Iterator[NetworkReceiver[_]]) => { if (!iterator.hasNext) { throw new Exception("Could not start receiver as details not found.") } iterator.next().start() } + // Distribute the receivers and start them ssc.sc.runJob(tempRDD, startReceiver) } + /** Stops the receivers. */ def stopReceivers() { - //implicit val ec = env.actorSystem.dispatcher + // Signal the receivers to stop receiverInfo.values.foreach(_ ! StopReceiver) - //val listOfFutures = receiverInfo.values.map(_.ask(StopReceiver)(timeout)).toList - //val futureOfList = Future.sequence(listOfFutures) - //Await.result(futureOfList, timeout) } } } diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala index ca70e72e56..efc7058480 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala @@ -26,7 +26,7 @@ class FlumeInputDStream[T: ClassManifest]( ) extends NetworkInputDStream[SparkFlumeEvent](ssc_) { override def createReceiver(): NetworkReceiver[SparkFlumeEvent] = { - new FlumeReceiver(id, host, port, storageLevel) + new FlumeReceiver(host, port, storageLevel) } } @@ -112,11 +112,10 @@ class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { * Flume Avro interface.*/ private[streaming] class FlumeReceiver( - streamId: Int, host: String, port: Int, storageLevel: StorageLevel - ) extends NetworkReceiver[SparkFlumeEvent](streamId) { + ) extends NetworkReceiver[SparkFlumeEvent] { lazy val blockGenerator = new BlockGenerator(storageLevel) diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index 25988a2ce7..2b4740bdf7 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -96,15 +96,15 @@ class KafkaInputDStream[T: ClassManifest]( } */ def createReceiver(): NetworkReceiver[T] = { - new KafkaReceiver(id, host, port, groupId, topics, initialOffsets, storageLevel) + new KafkaReceiver(host, port, groupId, topics, initialOffsets, storageLevel) .asInstanceOf[NetworkReceiver[T]] } } private[streaming] -class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, +class KafkaReceiver(host: String, port: Int, groupId: String, topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], - storageLevel: StorageLevel) extends NetworkReceiver[Any](streamId) { + storageLevel: StorageLevel) extends NetworkReceiver[Any] { // Timeout for establishing a connection to Zookeper in ms. val ZK_TIMEOUT = 10000 diff --git a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala index 18e62a0e33..aa6be95f30 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala @@ -17,6 +17,15 @@ import akka.util.duration._ import spark.streaming.util.{RecurringTimer, SystemClock} import java.util.concurrent.ArrayBlockingQueue +/** + * Abstract class for defining any InputDStream that has to start a receiver on worker + * nodes to receive external data. Specific implementations of NetworkInputDStream must + * define the createReceiver() function that creates the receiver object of type + * [[spark.streaming.dstream.NetworkReceiver]] that will be sent to the workers to receive + * data. + * @param ssc_ Streaming context that will execute this input stream + * @tparam T Class type of the object of this stream + */ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : StreamingContext) extends InputDStream[T](ssc_) { @@ -25,7 +34,7 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming val id = ssc.getNewNetworkStreamId() /** - * This method creates the receiver object that will be sent to the workers + * Creates the receiver object that will be sent to the worker nodes * to receive data. This method needs to defined by any specific implementation * of a NetworkInputDStream. */ @@ -48,7 +57,11 @@ private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverM private[streaming] case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage -abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Serializable with Logging { +/** + * Abstract class of a receiver that can be run on worker nodes to receive external data. See + * [[spark.streaming.dstream.NetworkInputDStream]] for an explanation. + */ +abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Logging { initLogging() @@ -59,17 +72,22 @@ abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Seri lazy protected val receivingThread = Thread.currentThread() - /** This method will be called to start receiving data. */ + protected var streamId: Int = -1 + + /** + * This method will be called to start receiving data. All your receiver + * starting code should be implemented by defining this function. + */ protected def onStart() /** This method will be called to stop receiving data. */ protected def onStop() - /** This method conveys a placement preference (hostname) for this receiver. */ + /** Conveys a placement preference (hostname) for this receiver. */ def getLocationPreference() : Option[String] = None /** - * This method starts the receiver. First is accesses all the lazy members to + * Starts the receiver. First is accesses all the lazy members to * materialize them. Then it calls the user-defined onStart() method to start * other threads, etc required to receiver the data. */ @@ -92,7 +110,7 @@ abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Seri } /** - * This method stops the receiver. First it interrupts the main receiving thread, + * Stops the receiver. First it interrupts the main receiving thread, * that is, the thread that called receiver.start(). Then it calls the user-defined * onStop() method to stop other threads and/or do cleanup. */ @@ -103,7 +121,7 @@ abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Seri } /** - * This method stops the receiver and reports to exception to the tracker. + * Stops the receiver and reports to exception to the tracker. * This should be called whenever an exception has happened on any thread * of the receiver. */ @@ -115,7 +133,7 @@ abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Seri /** - * This method pushes a block (as iterator of values) into the block manager. + * Pushes a block (as iterator of values) into the block manager. */ def pushBlock(blockId: String, iterator: Iterator[T], metadata: Any, level: StorageLevel) { val buffer = new ArrayBuffer[T] ++ iterator @@ -125,7 +143,7 @@ abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Seri } /** - * This method pushes a block (as bytes) into the block manager. + * Pushes a block (as bytes) into the block manager. */ def pushBlock(blockId: String, bytes: ByteBuffer, metadata: Any, level: StorageLevel) { env.blockManager.putBytes(blockId, bytes, level) @@ -157,6 +175,10 @@ abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Seri } } + protected[streaming] def setStreamId(id: Int) { + streamId = id + } + /** * Batches objects created by a [[spark.streaming.NetworkReceiver]] and puts them into * appropriately named blocks at regular intervals. This class starts two threads, @@ -202,7 +224,7 @@ abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Seri val newBlockBuffer = currentBuffer currentBuffer = new ArrayBuffer[T] if (newBlockBuffer.size > 0) { - val blockId = "input-" + NetworkReceiver.this.streamId + "- " + (time - blockInterval) + val blockId = "input-" + NetworkReceiver.this.streamId + "-" + (time - blockInterval) val newBlock = createBlock(blockId, newBlockBuffer.toIterator) blocksForPushing.add(newBlock) } diff --git a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala index aa2f31cea8..290fab1ce0 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala @@ -26,13 +26,13 @@ class RawInputDStream[T: ClassManifest]( ) extends NetworkInputDStream[T](ssc_ ) with Logging { def createReceiver(): NetworkReceiver[T] = { - new RawNetworkReceiver(id, host, port, storageLevel).asInstanceOf[NetworkReceiver[T]] + new RawNetworkReceiver(host, port, storageLevel).asInstanceOf[NetworkReceiver[T]] } } private[streaming] -class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: StorageLevel) - extends NetworkReceiver[Any](streamId) { +class RawNetworkReceiver(host: String, port: Int, storageLevel: StorageLevel) + extends NetworkReceiver[Any] { var blockPushingThread: Thread = null diff --git a/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala index 8e4b20ea4c..d42027092b 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala @@ -16,18 +16,17 @@ class SocketInputDStream[T: ClassManifest]( ) extends NetworkInputDStream[T](ssc_) { def createReceiver(): NetworkReceiver[T] = { - new SocketReceiver(id, host, port, bytesToObjects, storageLevel) + new SocketReceiver(host, port, bytesToObjects, storageLevel) } } private[streaming] class SocketReceiver[T: ClassManifest]( - streamId: Int, host: String, port: Int, bytesToObjects: InputStream => Iterator[T], storageLevel: StorageLevel - ) extends NetworkReceiver[T](streamId) { + ) extends NetworkReceiver[T] { lazy protected val blockGenerator = new BlockGenerator(storageLevel) -- cgit v1.2.3 From 72408e8dfacc24652f376d1ee4dd6f04edb54804 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 13 Jan 2013 19:34:07 -0800 Subject: Make filter preserve partitioner info, since it can --- core/src/main/scala/spark/rdd/FilteredRDD.scala | 3 ++- core/src/test/scala/spark/PartitioningSuite.scala | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala index b148da28de..d46549b8b6 100644 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala @@ -7,5 +7,6 @@ private[spark] class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) { override def splits = prev.splits override val dependencies = List(new OneToOneDependency(prev)) + override val partitioner = prev.partitioner // Since filter cannot change a partition's keys override def compute(split: Split, context: TaskContext) = prev.iterator(split, context).filter(f) -} \ No newline at end of file +} diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala index f09b602a7b..eb3c8f238f 100644 --- a/core/src/test/scala/spark/PartitioningSuite.scala +++ b/core/src/test/scala/spark/PartitioningSuite.scala @@ -106,6 +106,11 @@ class PartitioningSuite extends FunSuite with BeforeAndAfter { assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner) assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner) assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner) + + assert(grouped2.map(_ => 1).partitioner === None) + assert(grouped2.mapValues(_ => 1).partitioner === grouped2.partitioner) + assert(grouped2.flatMapValues(_ => Seq(1)).partitioner === grouped2.partitioner) + assert(grouped2.filter(_._1 > 4).partitioner === grouped2.partitioner) } test("partitioning Java arrays should fail") { -- cgit v1.2.3 From 0dbd411a562396e024c513936fde46b0d2f6d59d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 13 Jan 2013 21:08:35 -0800 Subject: Added documentation for PairDStreamFunctions. --- core/src/main/scala/spark/PairRDDFunctions.scala | 6 +- docs/streaming-programming-guide.md | 45 ++-- .../src/main/scala/spark/streaming/DStream.scala | 35 ++- .../spark/streaming/PairDStreamFunctions.scala | 293 ++++++++++++++++++++- .../scala/spark/streaming/util/RawTextHelper.scala | 2 +- 5 files changed, 331 insertions(+), 50 deletions(-) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 07ae2d647c..d95b66ad78 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -199,9 +199,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( } /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. + * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each + * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and + * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD. */ def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = { this.cogroup(other, partitioner).flatMapValues { diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 05a88ce7bd..b6da7af654 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -43,7 +43,7 @@ A complete list of input sources is available in the [StreamingContext API docum # DStream Operations -Once an input stream has been created, you can transform it using _stream operators_. Most of these operators return new DStreams which you can further transform. Eventually, you'll need to call an _output operator_, which forces evaluation of the stream by writing data out to an external source. +Once an input DStream has been created, you can transform it using _DStream operators_. Most of these operators return new DStreams which you can further transform. Eventually, you'll need to call an _output operator_, which forces evaluation of the DStream by writing data out to an external source. ## Transformations @@ -53,11 +53,11 @@ DStreams support many of the transformations available on normal Spark RDD's:
    TransformationMeaning
    map(func) Return a new stream formed by passing each element of the source through a function func. Returns a new DStream formed by passing each element of the source through a function func.
    filter(func) Return a new stream formed by selecting those elements of the source on which func returns true. Returns a new stream formed by selecting those elements of the source on which func returns true.
    flatMap(func)
    cogroup(otherStream, [numTasks]) When called on streams of type (K, V) and (K, W), returns a stream of (K, Seq[V], Seq[W]) tuples. This operation is also called groupWith. When called on DStream of type (K, V) and (K, W), returns a DStream of (K, Seq[V], Seq[W]) tuples.
    reduce(func) Create a new single-element stream by aggregating the elements of the stream using a function func (which takes two arguments and returns one). The function should be associative so that it can be computed correctly in parallel. Returns a new DStream of single-element RDDs by aggregating the elements of the stream using a function func (which takes two arguments and returns one). The function should be associative so that it can be computed correctly in parallel.
    transform(func) Returns a new DStream by applying func (a RDD-to-RDD function) to every RDD of the stream. This can be used to do arbitrary RDD operations on the DStream.
    -Spark Streaming features windowed computations, which allow you to report statistics over a sliding window of data. All window functions take a windowTime, which represents the width of the window and a slideTime, which represents the frequency during which the window is calculated. +Spark Streaming features windowed computations, which allow you to report statistics over a sliding window of data. All window functions take a windowDuration, which represents the width of the window and a slideTime, which represents the frequency during which the window is calculated. - - + - - + - - + -
    TransformationMeaning
    window(windowTime, slideTime) Return a new stream which is computed based on windowed batches of the source stream. windowTime is the width of the window and slideTime is the frequency during which the window is calculated. Both times must be multiples of the batch interval. + window(windowDuration, slideTime) Return a new stream which is computed based on windowed batches of the source stream. windowDuration is the width of the window and slideTime is the frequency during which the window is calculated. Both times must be multiples of the batch interval.
    countByWindow(windowTime, slideTime) Return a sliding count of elements in the stream. windowTime and slideTime are exactly as defined in window(). + countByWindow(windowDuration, slideTime) Return a sliding count of elements in the stream. windowDuration and slideDuration are exactly as defined in window().
    reduceByWindow(func, windowTime, slideTime) Return a new single-element stream, created by aggregating elements in the stream over a sliding interval using func. The function should be associative so that it can be computed correctly in parallel. windowTime and slideTime are exactly as defined in window(). + reduceByWindow(func, windowDuration, slideDuration) Return a new single-element stream, created by aggregating elements in the stream over a sliding interval using func. The function should be associative so that it can be computed correctly in parallel. windowDuration and slideDuration are exactly as defined in window().
    groupByKeyAndWindow(windowTime, slideTime, [numTasks]) + groupByKeyAndWindow(windowDuration, slideDuration, [numTasks]) When called on a stream of (K, V) pairs, returns a stream of (K, Seq[V]) pairs over a sliding window.
    -Note: By default, this uses only 8 parallel tasks to do the grouping. You can pass an optional numTasks argument to set a different number of tasks. windowTime and slideTime are exactly as defined in window(). +Note: By default, this uses only 8 parallel tasks to do the grouping. You can pass an optional numTasks argument to set a different number of tasks. windowDuration and slideDuration are exactly as defined in window().
    reduceByKeyAndWindow(func, [numTasks]) When called on a stream of (K, V) pairs, returns a stream of (K, V) pairs where the values for each key are aggregated using the given reduce function over batches within a sliding window. Like in groupByKeyAndWindow, the number of reduce tasks is configurable through an optional second argument. - windowTime and slideTime are exactly as defined in window(). + windowDuration and slideDuration are exactly as defined in window().
    countByKeyAndWindow([numTasks]) When called on a stream of (K, V) pairs, returns a stream of (K, Int) pairs where the values for each key are the count within a sliding window. Like in countByKeyAndWindow, the number of reduce tasks is configurable through an optional second argument. - windowTime and slideTime are exactly as defined in window(). + windowDuration and slideDuration are exactly as defined in window().
    +A complete list of DStream operations is available in the API documentation of [DStream](api/streaming/index.html#spark.streaming.DStream) and [PairDStreamFunctions](api/streaming/index.html#spark.streaming.PairDStreamFunctions). ## Output Operations When an output operator is called, it triggers the computation of a stream. Currently the following output operators are defined: @@ -144,7 +149,7 @@ When an output operator is called, it triggers the computation of a stream. Curr - + @@ -155,18 +160,18 @@ When an output operator is called, it triggers the computation of a stream. Curr - - + - +
    OperatorMeaning
    foreachRDD(func) foreach(func) The fundamental output operator. Applies a function, func, to each RDD generated from the stream. This function should have side effects, such as printing output, saving the RDD to external files, or writing it over the network to an external system.
    saveAsObjectFiles(prefix, [suffix]) Save this DStream's contents as a SequenceFile of serialized objects. The file name at each batch interval is calculated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]". + Save this DStream's contents as a SequenceFile of serialized objects. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
    saveAsTextFiles(prefix, [suffix]) Save this DStream's contents as a text files. The file name at each batch interval is calculated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]". Save this DStream's contents as a text files. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
    saveAsHadoopFiles(prefix, [suffix]) Save this DStream's contents as a Hadoop file. The file name at each batch interval is calculated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]". Save this DStream's contents as a Hadoop file. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
    diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index c89fb7723e..d94548a4f3 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -471,7 +471,7 @@ abstract class DStream[T: ClassManifest] ( * Returns a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. */ - def count(): DStream[Int] = this.map(_ => 1).reduce(_ + _) + def count(): DStream[Long] = this.map(_ => 1L).reduce(_ + _) /** * Applies a function to each RDD in this DStream. This is an output operator, so @@ -529,17 +529,16 @@ abstract class DStream[T: ClassManifest] ( * Return a new DStream which is computed based on windowed batches of this DStream. * The new DStream generates RDDs with the same interval as this DStream. * @param windowDuration width of the window; must be a multiple of this DStream's interval. - * @return */ def window(windowDuration: Duration): DStream[T] = window(windowDuration, this.slideDuration) /** * Return a new DStream which is computed based on windowed batches of this DStream. - * @param windowDuration duration (i.e., width) of the window; - * must be a multiple of this DStream's interval + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's interval + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval */ def window(windowDuration: Duration, slideDuration: Duration): DStream[T] = { new WindowedDStream(this, windowDuration, slideDuration) @@ -548,16 +547,22 @@ abstract class DStream[T: ClassManifest] ( /** * Returns a new DStream which computed based on tumbling window on this DStream. * This is equivalent to window(batchTime, batchTime). - * @param batchTime tumbling window duration; must be a multiple of this DStream's interval + * @param batchTime tumbling window duration; must be a multiple of this DStream's + * batching interval */ def tumble(batchTime: Duration): DStream[T] = window(batchTime, batchTime) /** * Returns a new DStream in which each RDD has a single element generated by reducing all - * elements in a window over this DStream. windowDuration and slideDuration are as defined in the - * window() operation. This is equivalent to window(windowDuration, slideDuration).reduce(reduceFunc) + * elements in a window over this DStream. windowDuration and slideDuration are as defined + * in the window() operation. This is equivalent to + * window(windowDuration, slideDuration).reduce(reduceFunc) */ - def reduceByWindow(reduceFunc: (T, T) => T, windowDuration: Duration, slideDuration: Duration): DStream[T] = { + def reduceByWindow( + reduceFunc: (T, T) => T, + windowDuration: Duration, + slideDuration: Duration + ): DStream[T] = { this.window(windowDuration, slideDuration).reduce(reduceFunc) } @@ -577,8 +582,8 @@ abstract class DStream[T: ClassManifest] ( * of elements in a window over this DStream. windowDuration and slideDuration are as defined in the * window() operation. This is equivalent to window(windowDuration, slideDuration).count() */ - def countByWindow(windowDuration: Duration, slideDuration: Duration): DStream[Int] = { - this.map(_ => 1).reduceByWindow(_ + _, _ - _, windowDuration, slideDuration) + def countByWindow(windowDuration: Duration, slideDuration: Duration): DStream[Long] = { + this.map(_ => 1L).reduceByWindow(_ + _, _ - _, windowDuration, slideDuration) } /** @@ -612,6 +617,8 @@ abstract class DStream[T: ClassManifest] ( /** * Saves each RDD in this DStream as a Sequence file of serialized objects. + * The file name at each batch interval is generated based on `prefix` and + * `suffix`: "prefix-TIME_IN_MS.suffix". */ def saveAsObjectFiles(prefix: String, suffix: String = "") { val saveFunc = (rdd: RDD[T], time: Time) => { @@ -622,7 +629,9 @@ abstract class DStream[T: ClassManifest] ( } /** - * Saves each RDD in this DStream as at text file, using string representation of elements. + * Saves each RDD in this DStream as at text file, using string representation + * of elements. The file name at each batch interval is generated based on + * `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ def saveAsTextFiles(prefix: String, suffix: String = "") { val saveFunc = (rdd: RDD[T], time: Time) => { diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index 482d01300d..3952457339 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -25,34 +25,76 @@ extends Serializable { new HashPartitioner(numPartitions) } + /** + * Creates a new DStream by applying `groupByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs are grouped into a + * single sequence to generate the RDDs of the new DStream. Hash partitioning is + * used to generate the RDDs with Spark's default number of partitions. + */ def groupByKey(): DStream[(K, Seq[V])] = { groupByKey(defaultPartitioner()) } + /** + * Creates a new DStream by applying `groupByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs are grouped into a + * single sequence to generate the RDDs of the new DStream. Hash partitioning is + * used to generate the RDDs with `numPartitions` partitions. + */ def groupByKey(numPartitions: Int): DStream[(K, Seq[V])] = { groupByKey(defaultPartitioner(numPartitions)) } + /** + * Creates a new DStream by applying `groupByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs are grouped into a + * single sequence to generate the RDDs of the new DStream. [[spark.Partitioner]] + * is used to control the partitioning of each RDD. + */ def groupByKey(partitioner: Partitioner): DStream[(K, Seq[V])] = { val createCombiner = (v: V) => ArrayBuffer[V](v) val mergeValue = (c: ArrayBuffer[V], v: V) => (c += v) val mergeCombiner = (c1: ArrayBuffer[V], c2: ArrayBuffer[V]) => (c1 ++ c2) - combineByKey(createCombiner, mergeValue, mergeCombiner, partitioner).asInstanceOf[DStream[(K, Seq[V])]] + combineByKey(createCombiner, mergeValue, mergeCombiner, partitioner) + .asInstanceOf[DStream[(K, Seq[V])]] } + /** + * Creates a new DStream by applying `reduceByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs is merged using the + * associative reduce function to generate the RDDs of the new DStream. + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + */ def reduceByKey(reduceFunc: (V, V) => V): DStream[(K, V)] = { reduceByKey(reduceFunc, defaultPartitioner()) } + /** + * Creates a new DStream by applying `reduceByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs is merged using the + * associative reduce function to generate the RDDs of the new DStream. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + */ def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int): DStream[(K, V)] = { reduceByKey(reduceFunc, defaultPartitioner(numPartitions)) } + /** + * Creates a new DStream by applying `reduceByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs is merged using the + * associative reduce function to generate the RDDs of the new DStream. + * [[spark.Partitioner]] is used to control the partitioning of each RDD. + */ def reduceByKey(reduceFunc: (V, V) => V, partitioner: Partitioner): DStream[(K, V)] = { val cleanedReduceFunc = ssc.sc.clean(reduceFunc) combineByKey((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, partitioner) } + /** + * Generic function to combine elements of each key in DStream's RDDs using custom function. + * This is similar to the combineByKey for RDDs. Please refer to combineByKey in + * [[spark.PairRDDFunctions]] for more information. + */ def combineByKey[C: ClassManifest]( createCombiner: V => C, mergeValue: (C, V) => C, @@ -61,14 +103,52 @@ extends Serializable { new ShuffledDStream[K, V, C](self, createCombiner, mergeValue, mergeCombiner, partitioner) } + /** + * Creates a new DStream by counting the number of values of each key in each RDD + * of `this` DStream. Hash partitioning is used to generate the RDDs with Spark's + * `numPartitions` partitions. + */ def countByKey(numPartitions: Int = self.ssc.sc.defaultParallelism): DStream[(K, Long)] = { self.map(x => (x._1, 1L)).reduceByKey((x: Long, y: Long) => x + y, numPartitions) } + /** + * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.groupByKey()` but applies it over a sliding window. + * The new DStream generates RDDs with the same interval as this DStream. + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + */ + def groupByKeyAndWindow(windowDuration: Duration): DStream[(K, Seq[V])] = { + groupByKeyAndWindow(windowDuration, self.slideDuration, defaultPartitioner()) + } + + /** + * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.groupByKey()` but applies it over a sliding window. + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration): DStream[(K, Seq[V])] = { groupByKeyAndWindow(windowDuration, slideDuration, defaultPartitioner()) } + /** + * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.groupByKey()` but applies it over a sliding window. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions Number of partitions of each RDD in the new DStream. + */ def groupByKeyAndWindow( windowDuration: Duration, slideDuration: Duration, @@ -77,6 +157,16 @@ extends Serializable { groupByKeyAndWindow(windowDuration, slideDuration, defaultPartitioner(numPartitions)) } + /** + * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.groupByKey()` but applies it over a sliding window. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + */ def groupByKeyAndWindow( windowDuration: Duration, slideDuration: Duration, @@ -85,6 +175,15 @@ extends Serializable { self.window(windowDuration, slideDuration).groupByKey(partitioner) } + /** + * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. + * The new DStream generates RDDs with the same interval as this DStream. + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, windowDuration: Duration @@ -92,6 +191,17 @@ extends Serializable { reduceByKeyAndWindow(reduceFunc, windowDuration, self.slideDuration, defaultPartitioner()) } + /** + * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, windowDuration: Duration, @@ -100,6 +210,18 @@ extends Serializable { reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, defaultPartitioner()) } + /** + * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions Number of partitions of each RDD in the new DStream. + */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, windowDuration: Duration, @@ -109,6 +231,17 @@ extends Serializable { reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, defaultPartitioner(numPartitions)) } + /** + * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, windowDuration: Duration, @@ -121,12 +254,23 @@ extends Serializable { .reduceByKey(cleanedReduceFunc, partitioner) } - // This method is the efficient sliding window reduce operation, - // which requires the specification of an inverse reduce function, - // so that new elements introduced in the window can be "added" using - // reduceFunc to the previous window's result and old elements can be - // "subtracted using invReduceFunc. - + /** + * Creates a new DStream by reducing over a window in a smarter way. + * The reduced value of over a new window is calculated incrementally by using the + * old window's reduce value : + * 1. reduce the new values that entered the window (e.g., adding new counts) + * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. + * However, it is applicable to only "invertible reduce functions". + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param reduceFunc associative reduce function + * @param invReduceFunc inverse function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, @@ -138,6 +282,24 @@ extends Serializable { reduceFunc, invReduceFunc, windowDuration, slideDuration, defaultPartitioner()) } + /** + * Creates a new DStream by reducing over a window in a smarter way. + * The reduced value of over a new window is calculated incrementally by using the + * old window's reduce value : + * 1. reduce the new values that entered the window (e.g., adding new counts) + * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. + * However, it is applicable to only "invertible reduce functions". + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param reduceFunc associative reduce function + * @param invReduceFunc inverse function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions Number of partitions of each RDD in the new DStream. + */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, @@ -150,6 +312,23 @@ extends Serializable { reduceFunc, invReduceFunc, windowDuration, slideDuration, defaultPartitioner(numPartitions)) } + /** + * Creates a new DStream by reducing over a window in a smarter way. + * The reduced value of over a new window is calculated incrementally by using the + * old window's reduce value : + * 1. reduce the new values that entered the window (e.g., adding new counts) + * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. + * However, it is applicable to only "invertible reduce functions". + * @param reduceFunc associative reduce function + * @param invReduceFunc inverse function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, @@ -164,6 +343,16 @@ extends Serializable { self, cleanedReduceFunc, cleanedInvReduceFunc, windowDuration, slideDuration, partitioner) } + /** + * Creates a new DStream by counting the number of values for each key over a window. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions Number of partitions of each RDD in the new DStream. + */ def countByKeyAndWindow( windowDuration: Duration, slideDuration: Duration, @@ -179,17 +368,30 @@ extends Serializable { ) } - // TODO: - // - // - // - // + /** + * Creates a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of the key from + * `this` DStream. Hash partitioning is used to generate the RDDs with Spark's default + * number of partitions. + * @param updateFunc State update function. If `this` function returns None, then + * corresponding state key-value pair will be eliminated. + * @tparam S State type + */ def updateStateByKey[S <: AnyRef : ClassManifest]( updateFunc: (Seq[V], Option[S]) => Option[S] ): DStream[(K, S)] = { updateStateByKey(updateFunc, defaultPartitioner()) } + /** + * Creates a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of the key from + * `this` DStream. Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param updateFunc State update function. If `this` function returns None, then + * corresponding state key-value pair will be eliminated. + * @param numPartitions Number of partitions of each RDD in the new DStream. + * @tparam S State type + */ def updateStateByKey[S <: AnyRef : ClassManifest]( updateFunc: (Seq[V], Option[S]) => Option[S], numPartitions: Int @@ -197,6 +399,15 @@ extends Serializable { updateStateByKey(updateFunc, defaultPartitioner(numPartitions)) } + /** + * Creates a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of the key from + * `this` DStream. [[spark.Partitioner]] is used to control the partitioning of each RDD. + * @param updateFunc State update function. If `this` function returns None, then + * corresponding state key-value pair will be eliminated. + * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + * @tparam S State type + */ def updateStateByKey[S <: AnyRef : ClassManifest]( updateFunc: (Seq[V], Option[S]) => Option[S], partitioner: Partitioner @@ -207,6 +418,19 @@ extends Serializable { updateStateByKey(newUpdateFunc, partitioner, true) } + /** + * Creates a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of the key from + * `this` DStream. [[spark.Partitioner]] is used to control the partitioning of each RDD. + * @param updateFunc State update function. If `this` function returns None, then + * corresponding state key-value pair will be eliminated. Note, that + * this function may generate a different a tuple with a different key + * than the input key. It is up to the developer to decide whether to + * remember the partitioner despite the key being changed. + * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + * @param rememberPartitioner Whether to remember the paritioner object in the generated RDDs. + * @tparam S State type + */ def updateStateByKey[S <: AnyRef : ClassManifest]( updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], partitioner: Partitioner, @@ -226,10 +450,24 @@ extends Serializable { new FlatMapValuedDStream[K, V, U](self, flatMapValuesFunc) } + /** + * Cogroups `this` DStream with `other` DStream. Each RDD of the new DStream will + * be generated by cogrouping RDDs from`this`and `other` DStreams. Therefore, for + * each key k in corresponding RDDs of `this` or `other` DStreams, the generated RDD + * will contains a tuple with the list of values for that key in both RDDs. + * HashPartitioner is used to partition each generated RDD into default number of partitions. + */ def cogroup[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (Seq[V], Seq[W]))] = { cogroup(other, defaultPartitioner()) } + /** + * Cogroups `this` DStream with `other` DStream. Each RDD of the new DStream will + * be generated by cogrouping RDDs from`this`and `other` DStreams. Therefore, for + * each key k in corresponding RDDs of `this` or `other` DStreams, the generated RDD + * will contains a tuple with the list of values for that key in both RDDs. + * Partitioner is used to partition each generated RDD. + */ def cogroup[W: ClassManifest]( other: DStream[(K, W)], partitioner: Partitioner @@ -249,11 +487,24 @@ extends Serializable { } } + /** + * Joins `this` DStream with `other` DStream. Each RDD of the new DStream will + * be generated by joining RDDs from `this` and `other` DStreams. HashPartitioner is used + * to partition each generated RDD into default number of partitions. + */ def join[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (V, W))] = { join[W](other, defaultPartitioner()) } - def join[W: ClassManifest](other: DStream[(K, W)], partitioner: Partitioner): DStream[(K, (V, W))] = { + /** + * Joins `this` DStream with `other` DStream, that is, each RDD of the new DStream will + * be generated by joining RDDs from `this` and other DStream. Uses the given + * Partitioner to partition each generated RDD. + */ + def join[W: ClassManifest]( + other: DStream[(K, W)], + partitioner: Partitioner + ): DStream[(K, (V, W))] = { this.cogroup(other, partitioner) .flatMapValues{ case (vs, ws) => @@ -261,6 +512,10 @@ extends Serializable { } } + /** + * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated + * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix" + */ def saveAsHadoopFiles[F <: OutputFormat[K, V]]( prefix: String, suffix: String @@ -268,6 +523,10 @@ extends Serializable { saveAsHadoopFiles(prefix, suffix, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) } + /** + * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated + * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix" + */ def saveAsHadoopFiles( prefix: String, suffix: String, @@ -283,6 +542,10 @@ extends Serializable { self.foreach(saveFunc) } + /** + * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated + * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". + */ def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]]( prefix: String, suffix: String @@ -290,6 +553,10 @@ extends Serializable { saveAsNewAPIHadoopFiles(prefix, suffix, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) } + /** + * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated + * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". + */ def saveAsNewAPIHadoopFiles( prefix: String, suffix: String, diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala index f31ae39a16..03749d4a94 100644 --- a/streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala +++ b/streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala @@ -81,7 +81,7 @@ object RawTextHelper { * before real workload starts. */ def warmUp(sc: SparkContext) { - for(i <- 0 to 4) { + for(i <- 0 to 1) { sc.parallelize(1 to 200000, 1000) .map(_ % 1331).map(_.toString) .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) -- cgit v1.2.3 From 6cc8592f26553525e11213830b596fc397243439 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 13 Jan 2013 21:20:49 -0800 Subject: Fixed bug --- streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index fa117cfcf0..f9ba1f20f0 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -224,7 +224,9 @@ class WindowOperationsSuite extends TestSuiteBase { val windowDuration = Seconds(2) val slideDuration = Seconds(1) val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt - val operation = (s: DStream[Int]) => s.countByWindow(windowDuration, slideDuration) + val operation = (s: DStream[Int]) => { + s.countByWindow(windowDuration, slideDuration).map(_.toInt) + } testOperation(input, operation, expectedOutput, numBatches, true) } -- cgit v1.2.3 From f90f794cde479f4de425e9be0158a136a57666a2 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 13 Jan 2013 21:25:57 -0800 Subject: Minor name fix --- streaming/src/main/scala/spark/streaming/DStream.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index d94548a4f3..fbe3cebd6d 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -547,10 +547,10 @@ abstract class DStream[T: ClassManifest] ( /** * Returns a new DStream which computed based on tumbling window on this DStream. * This is equivalent to window(batchTime, batchTime). - * @param batchTime tumbling window duration; must be a multiple of this DStream's + * @param batchDuration tumbling window duration; must be a multiple of this DStream's * batching interval */ - def tumble(batchTime: Duration): DStream[T] = window(batchTime, batchTime) + def tumble(batchDuration: Duration): DStream[T] = window(batchDuration, batchDuration) /** * Returns a new DStream in which each RDD has a single element generated by reducing all -- cgit v1.2.3 From 131be5d62ef6b770de5106eb268a45bca385b599 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 14 Jan 2013 03:28:25 -0800 Subject: Fixed bug in RDD checkpointing. --- core/src/main/scala/spark/rdd/CheckpointRDD.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala index 86c63ca2f4..6f00f6ac73 100644 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -80,12 +80,12 @@ private[spark] object CheckpointRDD extends Logging { val serializer = SparkEnv.get.serializer.newInstance() val serializeStream = serializer.serializeStream(fileOutputStream) serializeStream.writeAll(iterator) - fileOutputStream.close() + serializeStream.close() if (!fs.rename(tempOutputPath, finalOutputPath)) { if (!fs.delete(finalOutputPath, true)) { throw new IOException("Checkpoint failed: failed to delete earlier output of task " - + context.attemptId); + + context.attemptId) } if (!fs.rename(tempOutputPath, finalOutputPath)) { throw new IOException("Checkpoint failed: failed to save output of task: " @@ -119,7 +119,7 @@ private[spark] object CheckpointRDD extends Logging { val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000) val path = new Path(hdfsPath, "temp") val fs = path.getFileSystem(new Configuration()) - sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 10) _) + sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _) val cpRDD = new CheckpointRDD[Int](sc, path.toString) assert(cpRDD.splits.length == rdd.splits.length, "Number of splits is not the same") assert(cpRDD.collect.toList == rdd.collect.toList, "Data of splits not the same") -- cgit v1.2.3 From b607c9e9165a996289e4fb78bf7f2792121183d0 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 3 Jan 2013 18:07:39 -0800 Subject: A very rough, early cut at some Java functionality for Streaming. --- project/SparkBuild.scala | 3 +- .../main/scala/spark/streaming/JavaAPISuite.java | 64 +++++++++++++++ .../spark/streaming/api/java/JavaDStream.scala | 95 ++++++++++++++++++++++ .../streaming/api/java/JavaStreamingContext.scala | 29 +++++++ 4 files changed, 190 insertions(+), 1 deletion(-) create mode 100644 streaming/src/main/scala/spark/streaming/JavaAPISuite.java create mode 100644 streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala create mode 100644 streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index d5cda347a4..39db4be842 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -162,7 +162,8 @@ object SparkBuild extends Build { def streamingSettings = sharedSettings ++ Seq( name := "spark-streaming", libraryDependencies ++= Seq( - "com.github.sgroschupf" % "zkclient" % "0.1") + "com.github.sgroschupf" % "zkclient" % "0.1", + "junit" % "junit" % "4.8.1") ) ++ assemblySettings ++ extraAssemblySettings def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq( diff --git a/streaming/src/main/scala/spark/streaming/JavaAPISuite.java b/streaming/src/main/scala/spark/streaming/JavaAPISuite.java new file mode 100644 index 0000000000..bcaaa4fa80 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/JavaAPISuite.java @@ -0,0 +1,64 @@ +package spark.streaming; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import spark.api.java.JavaRDD; +import spark.api.java.function.Function; +import spark.api.java.function.Function2; +import spark.streaming.api.java.JavaStreamingContext; + +import java.io.Serializable; + +// The test suite itself is Serializable so that anonymous Function implementations can be +// serialized, as an alternative to converting these anonymous classes to static inner classes; +// see http://stackoverflow.com/questions/758570/. +public class JavaAPISuite implements Serializable { + private transient JavaStreamingContext sc; + + @Before + public void setUp() { + sc = new JavaStreamingContext("local[2]", "test", new Time(1000)); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port"); + } + + @Test + public void simpleTest() { + sc.textFileStream("/var/log/syslog").print(); + sc.start(); + } + + public static void main(String[] args) { + JavaStreamingContext sc = new JavaStreamingContext("local[2]", "test", new Time(1000)); + + sc.networkTextStream("localhost", 12345).map(new Function() { + @Override + public Integer call(String s) throws Exception { + return s.length(); + } + }).reduce(new Function2() { + @Override + public Integer call(Integer i1, Integer i2) throws Exception { + return i1 + i2; + } + }).foreach(new Function2, Time, Void>() { + @Override + public Void call(JavaRDD integerJavaRDD, Time t) throws Exception { + System.out.println("Contents @ " + t.toFormattedString()); + for (int i: integerJavaRDD.collect()) { + System.out.println(i + "\n"); + } + return null; + } + }); + + sc.start(); + } +} diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala new file mode 100644 index 0000000000..e9391642f8 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala @@ -0,0 +1,95 @@ +package spark.streaming.api.java + +import java.util.{List => JList} + +import scala.collection.JavaConversions._ + +import spark.streaming._ +import spark.api.java.JavaRDD +import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _} +import java.util +import spark.RDD + +class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassManifest[T]) { + def print() = dstream.print() + + // TODO move to type-specific implementations + def cache() : JavaDStream[T] = { + dstream.cache() + } + + def count() : JavaDStream[Int] = { + dstream.count() + } + + def countByWindow(windowTime: Time, slideTime: Time) : JavaDStream[Int] = { + dstream.countByWindow(windowTime, slideTime) + } + + def compute(validTime: Time): JavaRDD[T] = { + dstream.compute(validTime) match { + case Some(rdd) => new JavaRDD(rdd) + case None => null + } + } + + def context(): StreamingContext = dstream.context() + + def window(windowTime: Time) = { + dstream.window(windowTime) + } + + def window(windowTime: Time, slideTime: Time): JavaDStream[T] = { + dstream.window(windowTime, slideTime) + } + + def tumble(batchTime: Time): JavaDStream[T] = { + dstream.tumble(batchTime) + } + + def map[R](f: JFunction[T, R]): JavaDStream[R] = { + new JavaDStream(dstream.map(f)(f.returnType()))(f.returnType()) + } + + def filter(f: JFunction[T, java.lang.Boolean]): JavaDStream[T] = { + dstream.filter((x => f(x).booleanValue())) + } + + def glom(): JavaDStream[JList[T]] = { + new JavaDStream(dstream.glom().map(x => new java.util.ArrayList[T](x.toSeq))) + } + + // TODO: Other map partitions + def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = { + def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) + new JavaDStream(dstream.mapPartitions(fn)(f.elementType()))(f.elementType()) + } + + def reduce(f: JFunction2[T, T, T]): JavaDStream[T] = dstream.reduce(f) + + def reduceByWindow( + reduceFunc: JFunction2[T, T, T], + invReduceFunc: JFunction2[T, T, T], + windowTime: Time, + slideTime: Time): JavaDStream[T] = { + dstream.reduceByWindow(reduceFunc, invReduceFunc, windowTime, slideTime) + } + + def slice(fromTime: Time, toTime: Time): JList[JavaRDD[T]] = { + new util.ArrayList(dstream.slice(fromTime, toTime).map(new JavaRDD(_)).toSeq) + } + + def foreach(foreachFunc: JFunction[JavaRDD[T], Void]) = { + dstream.foreach(rdd => foreachFunc.call(new JavaRDD(rdd))) + } + + def foreach(foreachFunc: JFunction2[JavaRDD[T], Time, Void]) = { + dstream.foreach((rdd, time) => foreachFunc.call(new JavaRDD(rdd), time)) + } +} + +object JavaDStream { + implicit def fromDStream[T: ClassManifest](dstream: DStream[T]): JavaDStream[T] = + new JavaDStream[T](dstream) + +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala new file mode 100644 index 0000000000..46f8cffd0b --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -0,0 +1,29 @@ +package spark.streaming.api.java + +import scala.collection.JavaConversions._ + +import spark.streaming._ +import dstream.SparkFlumeEvent +import spark.storage.StorageLevel + +class JavaStreamingContext(val ssc: StreamingContext) { + def this(master: String, frameworkName: String, batchDuration: Time) = + this(new StreamingContext(master, frameworkName, batchDuration)) + + def textFileStream(directory: String): JavaDStream[String] = { + ssc.textFileStream(directory) + } + + def networkTextStream(hostname: String, port: Int): JavaDStream[String] = { + ssc.networkTextStream(hostname, port) + } + + def flumeStream(hostname: String, port: Int, storageLevel: StorageLevel): + JavaDStream[SparkFlumeEvent] = { + ssc.flumeStream(hostname, port, storageLevel) + } + + def start() = ssc.start() + def stop() = ssc.stop() + +} -- cgit v1.2.3 From 867a7455e27af9e8a6b95c87c882c0eebcaed0ad Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 4 Jan 2013 11:19:20 -0800 Subject: Adding some initial tests to streaming API. --- .../main/scala/spark/streaming/JavaAPISuite.java | 64 -------------- .../streaming/api/java/JavaStreamingContext.scala | 1 + streaming/src/test/scala/JavaTestUtils.scala | 42 ++++++++++ .../test/scala/spark/streaming/JavaAPISuite.java | 97 ++++++++++++++++++++++ 4 files changed, 140 insertions(+), 64 deletions(-) delete mode 100644 streaming/src/main/scala/spark/streaming/JavaAPISuite.java create mode 100644 streaming/src/test/scala/JavaTestUtils.scala create mode 100644 streaming/src/test/scala/spark/streaming/JavaAPISuite.java diff --git a/streaming/src/main/scala/spark/streaming/JavaAPISuite.java b/streaming/src/main/scala/spark/streaming/JavaAPISuite.java deleted file mode 100644 index bcaaa4fa80..0000000000 --- a/streaming/src/main/scala/spark/streaming/JavaAPISuite.java +++ /dev/null @@ -1,64 +0,0 @@ -package spark.streaming; - -import org.junit.After; -import org.junit.Before; -import org.junit.Test; -import spark.api.java.JavaRDD; -import spark.api.java.function.Function; -import spark.api.java.function.Function2; -import spark.streaming.api.java.JavaStreamingContext; - -import java.io.Serializable; - -// The test suite itself is Serializable so that anonymous Function implementations can be -// serialized, as an alternative to converting these anonymous classes to static inner classes; -// see http://stackoverflow.com/questions/758570/. -public class JavaAPISuite implements Serializable { - private transient JavaStreamingContext sc; - - @Before - public void setUp() { - sc = new JavaStreamingContext("local[2]", "test", new Time(1000)); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port"); - } - - @Test - public void simpleTest() { - sc.textFileStream("/var/log/syslog").print(); - sc.start(); - } - - public static void main(String[] args) { - JavaStreamingContext sc = new JavaStreamingContext("local[2]", "test", new Time(1000)); - - sc.networkTextStream("localhost", 12345).map(new Function() { - @Override - public Integer call(String s) throws Exception { - return s.length(); - } - }).reduce(new Function2() { - @Override - public Integer call(Integer i1, Integer i2) throws Exception { - return i1 + i2; - } - }).foreach(new Function2, Time, Void>() { - @Override - public Void call(JavaRDD integerJavaRDD, Time t) throws Exception { - System.out.println("Contents @ " + t.toFormattedString()); - for (int i: integerJavaRDD.collect()) { - System.out.println(i + "\n"); - } - return null; - } - }); - - sc.start(); - } -} diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index 46f8cffd0b..19cd032fc1 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -1,6 +1,7 @@ package spark.streaming.api.java import scala.collection.JavaConversions._ +import java.util.{List => JList} import spark.streaming._ import dstream.SparkFlumeEvent diff --git a/streaming/src/test/scala/JavaTestUtils.scala b/streaming/src/test/scala/JavaTestUtils.scala new file mode 100644 index 0000000000..776b0e6bb6 --- /dev/null +++ b/streaming/src/test/scala/JavaTestUtils.scala @@ -0,0 +1,42 @@ +package spark.streaming + +import collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import java.util.{List => JList} +import spark.streaming.api.java.{JavaDStream, JavaStreamingContext} +import spark.streaming._ +import java.util.ArrayList +import collection.JavaConversions._ + +/** Exposes core test functionality in a Java-friendly way. */ +object JavaTestUtils extends TestSuiteBase { + def attachTestInputStream[T](ssc: JavaStreamingContext, + data: JList[JList[T]], numPartitions: Int) = { + val seqData = data.map(Seq(_:_*)) + + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + val dstream = new TestInputStream[T](ssc.ssc, seqData, numPartitions) + ssc.ssc.registerInputStream(dstream) + new JavaDStream[T](dstream) + } + + def attachTestOutputStream[T](dstream: JavaDStream[T]) = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + val ostream = new TestOutputStream(dstream.dstream, + new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]]) + dstream.dstream.ssc.registerOutputStream(ostream) + } + + def runStreams[V]( + ssc: JavaStreamingContext, numBatches: Int, numExpectedOutput: Int): JList[JList[V]] = { + implicit val cm: ClassManifest[V] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]] + val res = runStreams[V](ssc.ssc, numBatches, numExpectedOutput) + val out = new ArrayList[JList[V]]() + res.map(entry => out.append(new ArrayList[V](entry))) + out + } + +} + diff --git a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java new file mode 100644 index 0000000000..5327edfd5d --- /dev/null +++ b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java @@ -0,0 +1,97 @@ +package spark.streaming; + +import org.junit.Assert; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import spark.api.java.JavaRDD; +import spark.api.java.function.Function; +import spark.api.java.function.Function2; +import spark.streaming.JavaTestUtils; +import spark.streaming.api.java.JavaDStream; +import spark.streaming.api.java.JavaStreamingContext; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +// The test suite itself is Serializable so that anonymous Function implementations can be +// serialized, as an alternative to converting these anonymous classes to static inner classes; +// see http://stackoverflow.com/questions/758570/. +public class JavaAPISuite implements Serializable { + private transient JavaStreamingContext sc; + + @Before + public void setUp() { + sc = new JavaStreamingContext("local[2]", "test", new Time(1000)); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port"); + } + + @Test + public void testCount() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3,4), Arrays.asList(3,4,5), Arrays.asList(3)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream count = stream.count(); + JavaTestUtils.attachTestOutputStream(count); + List> result = JavaTestUtils.runStreams(sc, 3, 3); + + Assert.assertTrue(result.equals( + Arrays.asList(Arrays.asList(4), Arrays.asList(3), Arrays.asList(1)))); + } + + @Test + public void testMap() { + List> inputData = Arrays.asList( + Arrays.asList("hello", "world"), Arrays.asList("goodnight", "moon")); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream letterCount = stream.map(new Function() { + @Override + public Integer call(String s) throws Exception { + return s.length(); + } + }); + JavaTestUtils.attachTestOutputStream(letterCount); + List> result = JavaTestUtils.runStreams(sc, 2, 2); + + Assert.assertTrue(result.equals( + Arrays.asList(Arrays.asList(5, 5), Arrays.asList(9, 4)))); + } + + public static void main(String[] args) { + JavaStreamingContext sc = new JavaStreamingContext("local[2]", "test", new Time(1000)); + + sc.networkTextStream("localhost", 12345).map(new Function() { + @Override + public Integer call(String s) throws Exception { + return s.length(); + } + }).reduce(new Function2() { + @Override + public Integer call(Integer i1, Integer i2) throws Exception { + return i1 + i2; + } + }).foreach(new Function2, Time, Void>() { + @Override + public Void call(JavaRDD integerJavaRDD, Time t) throws Exception { + System.out.println("Contents @ " + t.toFormattedString()); + for (int i: integerJavaRDD.collect()) { + System.out.println(i + "\n"); + } + return null; + } + }); + + sc.start(); + } +} -- cgit v1.2.3 From b0974e6c1d1d8e03d9b70660070f66977ea2e797 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 4 Jan 2013 11:21:46 -0800 Subject: Remving main method from tests --- .../test/scala/spark/streaming/JavaAPISuite.java | 27 ---------------------- 1 file changed, 27 deletions(-) diff --git a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java index 5327edfd5d..9cf0341cce 100644 --- a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java @@ -67,31 +67,4 @@ public class JavaAPISuite implements Serializable { Assert.assertTrue(result.equals( Arrays.asList(Arrays.asList(5, 5), Arrays.asList(9, 4)))); } - - public static void main(String[] args) { - JavaStreamingContext sc = new JavaStreamingContext("local[2]", "test", new Time(1000)); - - sc.networkTextStream("localhost", 12345).map(new Function() { - @Override - public Integer call(String s) throws Exception { - return s.length(); - } - }).reduce(new Function2() { - @Override - public Integer call(Integer i1, Integer i2) throws Exception { - return i1 + i2; - } - }).foreach(new Function2, Time, Void>() { - @Override - public Void call(JavaRDD integerJavaRDD, Time t) throws Exception { - System.out.println("Contents @ " + t.toFormattedString()); - for (int i: integerJavaRDD.collect()) { - System.out.println(i + "\n"); - } - return null; - } - }); - - sc.start(); - } } -- cgit v1.2.3 From 91b3d414481cd06e0b75f621550a38042f9b2ffd Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 4 Jan 2013 16:13:52 -0800 Subject: Better equality test (thanks Josh) --- streaming/src/test/scala/spark/streaming/JavaAPISuite.java | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java index 9cf0341cce..13dfd921cf 100644 --- a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java @@ -45,8 +45,8 @@ public class JavaAPISuite implements Serializable { JavaTestUtils.attachTestOutputStream(count); List> result = JavaTestUtils.runStreams(sc, 3, 3); - Assert.assertTrue(result.equals( - Arrays.asList(Arrays.asList(4), Arrays.asList(3), Arrays.asList(1)))); + Assert.assertEquals(result, + Arrays.asList(Arrays.asList(4), Arrays.asList(3), Arrays.asList(1))); } @Test @@ -64,7 +64,6 @@ public class JavaAPISuite implements Serializable { JavaTestUtils.attachTestOutputStream(letterCount); List> result = JavaTestUtils.runStreams(sc, 2, 2); - Assert.assertTrue(result.equals( - Arrays.asList(Arrays.asList(5, 5), Arrays.asList(9, 4)))); + Assert.assertEquals(result, Arrays.asList(Arrays.asList(5, 5), Arrays.asList(9, 4))); } } -- cgit v1.2.3 From 22a8c7be9aebe46c7ee332228967039be811043b Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 4 Jan 2013 18:30:48 -0800 Subject: Adding more tests --- .../spark/streaming/api/java/JavaDStream.scala | 2 +- .../test/scala/spark/streaming/JavaAPISuite.java | 181 ++++++++++++++++++++- 2 files changed, 173 insertions(+), 10 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala index e9391642f8..d0fa06ba7b 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala @@ -35,7 +35,7 @@ class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassM def context(): StreamingContext = dstream.context() - def window(windowTime: Time) = { + def window(windowTime: Time): JavaDStream[T] = { dstream.window(windowTime) } diff --git a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java index 13dfd921cf..9833478221 100644 --- a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java @@ -1,19 +1,20 @@ package spark.streaming; +import com.google.common.collect.Lists; import org.junit.Assert; import org.junit.After; import org.junit.Before; import org.junit.Test; -import spark.api.java.JavaRDD; +import spark.api.java.function.FlatMapFunction; import spark.api.java.function.Function; -import spark.api.java.function.Function2; import spark.streaming.JavaTestUtils; import spark.streaming.api.java.JavaDStream; import spark.streaming.api.java.JavaStreamingContext; import java.io.Serializable; -import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; import java.util.List; // The test suite itself is Serializable so that anonymous Function implementations can be @@ -38,21 +39,31 @@ public class JavaAPISuite implements Serializable { @Test public void testCount() { List> inputData = Arrays.asList( - Arrays.asList(1,2,3,4), Arrays.asList(3,4,5), Arrays.asList(3)); + Arrays.asList(1,2,3,4), + Arrays.asList(3,4,5), + Arrays.asList(3)); + + List> expected = Arrays.asList( + Arrays.asList(4), + Arrays.asList(3), + Arrays.asList(1)); JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); JavaDStream count = stream.count(); JavaTestUtils.attachTestOutputStream(count); List> result = JavaTestUtils.runStreams(sc, 3, 3); - - Assert.assertEquals(result, - Arrays.asList(Arrays.asList(4), Arrays.asList(3), Arrays.asList(1))); + assertOrderInvariantEquals(expected, result); } @Test public void testMap() { List> inputData = Arrays.asList( - Arrays.asList("hello", "world"), Arrays.asList("goodnight", "moon")); + Arrays.asList("hello", "world"), + Arrays.asList("goodnight", "moon")); + + List> expected = Arrays.asList( + Arrays.asList(5,5), + Arrays.asList(9,4)); JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @@ -64,6 +75,158 @@ public class JavaAPISuite implements Serializable { JavaTestUtils.attachTestOutputStream(letterCount); List> result = JavaTestUtils.runStreams(sc, 2, 2); - Assert.assertEquals(result, Arrays.asList(Arrays.asList(5, 5), Arrays.asList(9, 4))); + assertOrderInvariantEquals(expected, result); } + + @Test + public void testWindow() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9)); + + List> expected = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6,1,2,3), + Arrays.asList(7,8,9,4,5,6), + Arrays.asList(7,8,9)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream windowedRDD = stream.window(new Time(2000)); + JavaTestUtils.attachTestOutputStream(windowedRDD); + List> result = JavaTestUtils.runStreams(sc, 4, 4); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testWindowWithSlideTime() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9), + Arrays.asList(10,11,12), + Arrays.asList(13,14,15), + Arrays.asList(16,17,18)); + + List> expected = Arrays.asList( + Arrays.asList(1,2,3,4,5,6), + Arrays.asList(1,2,3,4,5,6,7,8,9,10,11,12), + Arrays.asList(7,8,9,10,11,12,13,14,15,16,17,18), + Arrays.asList(13,14,15,16,17,18)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream windowedRDD = stream.window(new Time(4000), new Time(2000)); + JavaTestUtils.attachTestOutputStream(windowedRDD); + List> result = JavaTestUtils.runStreams(sc, 8, 4); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testTumble() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9), + Arrays.asList(10,11,12), + Arrays.asList(13,14,15), + Arrays.asList(16,17,18)); + + List> expected = Arrays.asList( + Arrays.asList(1,2,3,4,5,6), + Arrays.asList(7,8,9,10,11,12), + Arrays.asList(13,14,15,16,17,18)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream windowedRDD = stream.tumble(new Time(2000)); + JavaTestUtils.attachTestOutputStream(windowedRDD); + List> result = JavaTestUtils.runStreams(sc, 6, 3); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testFilter() { + List> inputData = Arrays.asList( + Arrays.asList("giants", "dodgers"), + Arrays.asList("yankees", "red socks")); + + List> expected = Arrays.asList( + Arrays.asList("giants"), + Arrays.asList("yankees")); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream filtered = stream.filter(new Function() { + @Override + public Boolean call(String s) throws Exception { + return s.contains("a"); + } + }); + JavaTestUtils.attachTestOutputStream(filtered); + List> result = JavaTestUtils.runStreams(sc, 2, 2); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testGlom() { + List> inputData = Arrays.asList( + Arrays.asList("giants", "dodgers"), + Arrays.asList("yankees", "red socks")); + + List>> expected = Arrays.asList( + Arrays.asList(Arrays.asList("giants", "dodgers")), + Arrays.asList(Arrays.asList("yankees", "red socks"))); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream glommed = stream.glom(); + JavaTestUtils.attachTestOutputStream(glommed); + List>> result = JavaTestUtils.runStreams(sc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testMapPartitions() { + List> inputData = Arrays.asList( + Arrays.asList("giants", "dodgers"), + Arrays.asList("yankees", "red socks")); + + List> expected = Arrays.asList( + Arrays.asList("GIANTSDODGERS"), + Arrays.asList("YANKEESRED SOCKS")); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream mapped = stream.mapPartitions(new FlatMapFunction, String>() { + @Override + public Iterable call(Iterator in) { + String out = ""; + while (in.hasNext()) { + out = out + in.next().toUpperCase(); + } + return Lists.newArrayList(out); + } + }); + JavaTestUtils.attachTestOutputStream(mapped); + List>> result = JavaTestUtils.runStreams(sc, 2, 2); + + Assert.assertEquals(expected, result); + } + + /* + * Performs an order-invariant comparison of lists representing two RDD streams. This allows + * us to account for ordering variation within individual RDD's which occurs during windowing. + */ + public static void assertOrderInvariantEquals( + List> expected, List> actual) { + for (List list: expected) { + Collections.sort(list); + } + for (List list: actual) { + Collections.sort(list); + } + Assert.assertEquals(expected, actual); + } + } -- cgit v1.2.3 From 0d0bab25bd0dfefdd5a91d22a4e81d347d255cf3 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 4 Jan 2013 20:28:08 -0800 Subject: Reduce tests --- .../test/scala/spark/streaming/JavaAPISuite.java | 69 ++++++++++++++++++++-- 1 file changed, 63 insertions(+), 6 deletions(-) diff --git a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java index 9833478221..2d1b0f35f9 100644 --- a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java @@ -7,6 +7,7 @@ import org.junit.Before; import org.junit.Test; import spark.api.java.function.FlatMapFunction; import spark.api.java.function.Function; +import spark.api.java.function.Function2; import spark.streaming.JavaTestUtils; import spark.streaming.api.java.JavaDStream; import spark.streaming.api.java.JavaStreamingContext; @@ -92,8 +93,8 @@ public class JavaAPISuite implements Serializable { Arrays.asList(7,8,9)); JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); - JavaDStream windowedRDD = stream.window(new Time(2000)); - JavaTestUtils.attachTestOutputStream(windowedRDD); + JavaDStream windowed = stream.window(new Time(2000)); + JavaTestUtils.attachTestOutputStream(windowed); List> result = JavaTestUtils.runStreams(sc, 4, 4); assertOrderInvariantEquals(expected, result); @@ -116,8 +117,8 @@ public class JavaAPISuite implements Serializable { Arrays.asList(13,14,15,16,17,18)); JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); - JavaDStream windowedRDD = stream.window(new Time(4000), new Time(2000)); - JavaTestUtils.attachTestOutputStream(windowedRDD); + JavaDStream windowed = stream.window(new Time(4000), new Time(2000)); + JavaTestUtils.attachTestOutputStream(windowed); List> result = JavaTestUtils.runStreams(sc, 8, 4); assertOrderInvariantEquals(expected, result); @@ -139,8 +140,8 @@ public class JavaAPISuite implements Serializable { Arrays.asList(13,14,15,16,17,18)); JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); - JavaDStream windowedRDD = stream.tumble(new Time(2000)); - JavaTestUtils.attachTestOutputStream(windowedRDD); + JavaDStream windowed = stream.tumble(new Time(2000)); + JavaTestUtils.attachTestOutputStream(windowed); List> result = JavaTestUtils.runStreams(sc, 6, 3); assertOrderInvariantEquals(expected, result); @@ -214,6 +215,62 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(expected, result); } + private class IntegerSum extends Function2 { + @Override + public Integer call(Integer i1, Integer i2) throws Exception { + return i1 + i2; + } + } + + private class IntegerDifference extends Function2 { + @Override + public Integer call(Integer i1, Integer i2) throws Exception { + return i1 - i2; + } + } + + @Test + public void testReduce() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9)); + + List> expected = Arrays.asList( + Arrays.asList(6), + Arrays.asList(15), + Arrays.asList(24)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream reduced = stream.reduce(new IntegerSum()); + JavaTestUtils.attachTestOutputStream(reduced); + List> result = JavaTestUtils.runStreams(sc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testReduceByWindow() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9)); + + List> expected = Arrays.asList( + Arrays.asList(6), + Arrays.asList(21), + Arrays.asList(39), + Arrays.asList(24)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream reducedWindowed = stream.reduceByWindow(new IntegerSum(), + new IntegerDifference(), new Time(2000), new Time(1000)); + JavaTestUtils.attachTestOutputStream(reducedWindowed); + List> result = JavaTestUtils.runStreams(sc, 4, 4); + + Assert.assertEquals(expected, result); + } + /* * Performs an order-invariant comparison of lists representing two RDD streams. This allows * us to account for ordering variation within individual RDD's which occurs during windowing. -- cgit v1.2.3 From f144e0413a1e42d193a86fa04af769e2da9dc58b Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sat, 5 Jan 2013 15:06:20 -0800 Subject: Adding transform and union --- .../spark/streaming/api/java/JavaDStream.scala | 14 +++++ .../test/scala/spark/streaming/JavaAPISuite.java | 62 ++++++++++++++++++++-- 2 files changed, 72 insertions(+), 4 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala index d0fa06ba7b..56e54c719a 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala @@ -86,6 +86,20 @@ class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassM def foreach(foreachFunc: JFunction2[JavaRDD[T], Time, Void]) = { dstream.foreach((rdd, time) => foreachFunc.call(new JavaRDD(rdd), time)) } + + def transform[U](transformFunc: JFunction[JavaRDD[T], JavaRDD[U]]): JavaDStream[U] = { + implicit val cm: ClassManifest[U] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] + def scalaTransform (in: RDD[T]): RDD[U] = { + transformFunc.call(new JavaRDD[T](in)).rdd + } + dstream.transform(scalaTransform(_)) + } + // TODO: transform with time + + def union(that: JavaDStream[T]): JavaDStream[T] = { + dstream.union(that.dstream) + } } object JavaDStream { diff --git a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java index 2d1b0f35f9..c4629c8d97 100644 --- a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java @@ -5,6 +5,7 @@ import org.junit.Assert; import org.junit.After; import org.junit.Before; import org.junit.Test; +import spark.api.java.JavaRDD; import spark.api.java.function.FlatMapFunction; import spark.api.java.function.Function; import spark.api.java.function.Function2; @@ -13,10 +14,7 @@ import spark.streaming.api.java.JavaDStream; import spark.streaming.api.java.JavaStreamingContext; import java.io.Serializable; -import java.util.Arrays; -import java.util.Collections; -import java.util.Iterator; -import java.util.List; +import java.util.*; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -271,6 +269,62 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(expected, result); } + @Test + public void testTransform() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9)); + + List> expected = Arrays.asList( + Arrays.asList(3,4,5), + Arrays.asList(6,7,8), + Arrays.asList(9,10,11)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream transformed = stream.transform(new Function, JavaRDD>() { + @Override + public JavaRDD call(JavaRDD in) throws Exception { + return in.map(new Function() { + @Override + public Integer call(Integer i) throws Exception { + return i + 2; + } + }); + }}); + JavaTestUtils.attachTestOutputStream(transformed); + List> result = JavaTestUtils.runStreams(sc, 3, 3); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testUnion() { + List> inputData1 = Arrays.asList( + Arrays.asList(1,1), + Arrays.asList(2,2), + Arrays.asList(3,3)); + + List> inputData2 = Arrays.asList( + Arrays.asList(4,4), + Arrays.asList(5,5), + Arrays.asList(6,6)); + + List> expected = Arrays.asList( + Arrays.asList(1,1,4,4), + Arrays.asList(2,2,5,5), + Arrays.asList(3,3,6,6)); + + JavaDStream stream1 = JavaTestUtils.attachTestInputStream(sc, inputData1, 2); + JavaDStream stream2 = JavaTestUtils.attachTestInputStream(sc, inputData2, 2); + + JavaDStream unioned = stream1.union(stream2); + JavaTestUtils.attachTestOutputStream(unioned); + List> result = JavaTestUtils.runStreams(sc, 3, 3); + + assertOrderInvariantEquals(expected, result); + } + /* * Performs an order-invariant comparison of lists representing two RDD streams. This allows * us to account for ordering variation within individual RDD's which occurs during windowing. -- cgit v1.2.3 From 6e514a8d3511891a3f7221c594171477a0b5a38f Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 7 Jan 2013 11:02:03 -0800 Subject: PairDStream and DStreamLike --- .../spark/streaming/api/java/JavaDStream.scala | 102 +--------------- .../spark/streaming/api/java/JavaDStreamLike.scala | 109 +++++++++++++++++ .../spark/streaming/api/java/JavaPairDStream.scala | 134 +++++++++++++++++++++ streaming/src/test/scala/JavaTestUtils.scala | 6 +- .../test/scala/spark/streaming/JavaAPISuite.java | 108 +++++++++++++++++ 5 files changed, 359 insertions(+), 100 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala create mode 100644 streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala index 56e54c719a..9e2823d81f 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala @@ -1,109 +1,17 @@ package spark.streaming.api.java -import java.util.{List => JList} +import spark.streaming.DStream +import spark.api.java.function.{Function => JFunction} -import scala.collection.JavaConversions._ - -import spark.streaming._ -import spark.api.java.JavaRDD -import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _} -import java.util -import spark.RDD - -class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassManifest[T]) { - def print() = dstream.print() - - // TODO move to type-specific implementations - def cache() : JavaDStream[T] = { - dstream.cache() - } - - def count() : JavaDStream[Int] = { - dstream.count() - } - - def countByWindow(windowTime: Time, slideTime: Time) : JavaDStream[Int] = { - dstream.countByWindow(windowTime, slideTime) - } - - def compute(validTime: Time): JavaRDD[T] = { - dstream.compute(validTime) match { - case Some(rdd) => new JavaRDD(rdd) - case None => null - } - } - - def context(): StreamingContext = dstream.context() - - def window(windowTime: Time): JavaDStream[T] = { - dstream.window(windowTime) - } - - def window(windowTime: Time, slideTime: Time): JavaDStream[T] = { - dstream.window(windowTime, slideTime) - } - - def tumble(batchTime: Time): JavaDStream[T] = { - dstream.tumble(batchTime) - } - - def map[R](f: JFunction[T, R]): JavaDStream[R] = { - new JavaDStream(dstream.map(f)(f.returnType()))(f.returnType()) - } +class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassManifest[T]) + extends JavaDStreamLike[T, JavaDStream[T]] { def filter(f: JFunction[T, java.lang.Boolean]): JavaDStream[T] = { dstream.filter((x => f(x).booleanValue())) } - - def glom(): JavaDStream[JList[T]] = { - new JavaDStream(dstream.glom().map(x => new java.util.ArrayList[T](x.toSeq))) - } - - // TODO: Other map partitions - def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) - new JavaDStream(dstream.mapPartitions(fn)(f.elementType()))(f.elementType()) - } - - def reduce(f: JFunction2[T, T, T]): JavaDStream[T] = dstream.reduce(f) - - def reduceByWindow( - reduceFunc: JFunction2[T, T, T], - invReduceFunc: JFunction2[T, T, T], - windowTime: Time, - slideTime: Time): JavaDStream[T] = { - dstream.reduceByWindow(reduceFunc, invReduceFunc, windowTime, slideTime) - } - - def slice(fromTime: Time, toTime: Time): JList[JavaRDD[T]] = { - new util.ArrayList(dstream.slice(fromTime, toTime).map(new JavaRDD(_)).toSeq) - } - - def foreach(foreachFunc: JFunction[JavaRDD[T], Void]) = { - dstream.foreach(rdd => foreachFunc.call(new JavaRDD(rdd))) - } - - def foreach(foreachFunc: JFunction2[JavaRDD[T], Time, Void]) = { - dstream.foreach((rdd, time) => foreachFunc.call(new JavaRDD(rdd), time)) - } - - def transform[U](transformFunc: JFunction[JavaRDD[T], JavaRDD[U]]): JavaDStream[U] = { - implicit val cm: ClassManifest[U] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] - def scalaTransform (in: RDD[T]): RDD[U] = { - transformFunc.call(new JavaRDD[T](in)).rdd - } - dstream.transform(scalaTransform(_)) - } - // TODO: transform with time - - def union(that: JavaDStream[T]): JavaDStream[T] = { - dstream.union(that.dstream) - } } object JavaDStream { implicit def fromDStream[T: ClassManifest](dstream: DStream[T]): JavaDStream[T] = new JavaDStream[T](dstream) - -} \ No newline at end of file +} diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala new file mode 100644 index 0000000000..daea56f50c --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala @@ -0,0 +1,109 @@ +package spark.streaming.api.java + +import java.util.{List => JList} + +import scala.collection.JavaConversions._ + +import spark.streaming._ +import spark.api.java.JavaRDD +import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _} +import java.util +import spark.RDD +import JavaDStream._ + +trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable { + implicit val classManifest: ClassManifest[T] + + def dstream: DStream[T] + + def print() = dstream.print() + + // TODO move to type-specific implementations + def cache() : JavaDStream[T] = { + dstream.cache() + } + + def count() : JavaDStream[Int] = { + dstream.count() + } + + def countByWindow(windowTime: Time, slideTime: Time) : JavaDStream[Int] = { + dstream.countByWindow(windowTime, slideTime) + } + + def compute(validTime: Time): JavaRDD[T] = { + dstream.compute(validTime) match { + case Some(rdd) => new JavaRDD(rdd) + case None => null + } + } + + def context(): StreamingContext = dstream.context() + + def window(windowTime: Time): JavaDStream[T] = { + dstream.window(windowTime) + } + + def window(windowTime: Time, slideTime: Time): JavaDStream[T] = { + dstream.window(windowTime, slideTime) + } + + def tumble(batchTime: Time): JavaDStream[T] = { + dstream.tumble(batchTime) + } + + def map[R](f: JFunction[T, R]): JavaDStream[R] = { + new JavaDStream(dstream.map(f)(f.returnType()))(f.returnType()) + } + + def map[K, V](f: PairFunction[T, K, V]): JavaPairDStream[K, V] = { + def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]] + new JavaPairDStream(dstream.map(f)(cm))(f.keyType(), f.valueType()) + } + + def glom(): JavaDStream[JList[T]] = { + new JavaDStream(dstream.glom().map(x => new java.util.ArrayList[T](x.toSeq))) + } + + // TODO: Other map partitions + def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = { + def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) + new JavaDStream(dstream.mapPartitions(fn)(f.elementType()))(f.elementType()) + } + + def reduce(f: JFunction2[T, T, T]): JavaDStream[T] = dstream.reduce(f) + + def reduceByWindow( + reduceFunc: JFunction2[T, T, T], + invReduceFunc: JFunction2[T, T, T], + windowTime: Time, + slideTime: Time): JavaDStream[T] = { + dstream.reduceByWindow(reduceFunc, invReduceFunc, windowTime, slideTime) + } + + def slice(fromTime: Time, toTime: Time): JList[JavaRDD[T]] = { + new util.ArrayList(dstream.slice(fromTime, toTime).map(new JavaRDD(_)).toSeq) + } + + def foreach(foreachFunc: JFunction[JavaRDD[T], Void]) = { + dstream.foreach(rdd => foreachFunc.call(new JavaRDD(rdd))) + } + + def foreach(foreachFunc: JFunction2[JavaRDD[T], Time, Void]) = { + dstream.foreach((rdd, time) => foreachFunc.call(new JavaRDD(rdd), time)) + } + + def transform[U](transformFunc: JFunction[JavaRDD[T], JavaRDD[U]]): JavaDStream[U] = { + implicit val cm: ClassManifest[U] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] + def scalaTransform (in: RDD[T]): RDD[U] = { + transformFunc.call(new JavaRDD[T](in)).rdd + } + dstream.transform(scalaTransform(_)) + } + // TODO: transform with time + + def union(that: JavaDStream[T]): JavaDStream[T] = { + dstream.union(that.dstream) + } +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala new file mode 100644 index 0000000000..01dda24fde --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala @@ -0,0 +1,134 @@ +package spark.streaming.api.java + +import java.util.{List => JList} + +import scala.collection.JavaConversions._ + +import spark.streaming._ +import spark.streaming.StreamingContext._ +import spark.api.java.function.{Function => JFunction, Function2 => JFunction2} +import spark.Partitioner + +class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( + implicit val kManifiest: ClassManifest[K], + implicit val vManifest: ClassManifest[V]) + extends JavaDStreamLike[(K, V), JavaPairDStream[K, V]] { + + def filter(f: JFunction[(K, V), java.lang.Boolean]): JavaPairDStream[K, V] = { + dstream.filter((x => f(x).booleanValue())) + } + + def groupByKey(): JavaPairDStream[K, JList[V]] = { + dstream.groupByKey().mapValues(seqAsJavaList _) + } + + def groupByKey(numPartitions: Int): JavaPairDStream[K, JList[V]] = { + dstream.groupByKey(numPartitions).mapValues(seqAsJavaList _) + } + + def groupByKey(partitioner: Partitioner): JavaPairDStream[K, JList[V]] = { + dstream.groupByKey(partitioner).mapValues(seqAsJavaList _) + } + + def reduceByKey(func: JFunction2[V, V, V]): JavaPairDStream[K, V] = { + dstream.reduceByKey(func) + } + + def reduceByKey(func: JFunction2[V, V, V], numPartitions: Int): JavaPairDStream[K, V] = { + dstream.reduceByKey(func, numPartitions) + } + + // TODO: TEST BELOW + def combineByKey[C](createCombiner: Function[V, C], + mergeValue: JFunction2[C, V, C], + mergeCombiners: JFunction2[C, C, C], + partitioner: Partitioner): JavaPairDStream[K, C] = { + implicit val cm: ClassManifest[C] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]] + dstream.combineByKey(createCombiner, mergeValue, mergeCombiners, partitioner) + } + + def countByKey(numPartitions: Int): JavaPairDStream[K, Long] = { + dstream.countByKey(numPartitions); + } + + def countByKey(): JavaPairDStream[K, Long] = { + dstream.countByKey(); + } + + def groupByKeyAndWindow(windowTime: Time, slideTime: Time): JavaPairDStream[K, JList[V]] = { + dstream.groupByKeyAndWindow(windowTime, slideTime).mapValues(seqAsJavaList _) + } + + def groupByKeyAndWindow(windowTime: Time, slideTime: Time, numPartitions: Int): + JavaPairDStream[K, JList[V]] = { + dstream.groupByKeyAndWindow(windowTime, slideTime, numPartitions).mapValues(seqAsJavaList _) + } + + def groupByKeyAndWindow(windowTime: Time, slideTime: Time, partitioner: Partitioner): + JavaPairDStream[K, JList[V]] = { + dstream.groupByKeyAndWindow(windowTime, slideTime, partitioner).mapValues(seqAsJavaList _) + } + + def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], windowTime: Time): + JavaPairDStream[K, V] = { + dstream.reduceByKeyAndWindow(reduceFunc, windowTime) + } + + def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], windowTime: Time, slideTime: Time): + JavaPairDStream[K, V] = { + dstream.reduceByKeyAndWindow(reduceFunc, windowTime, slideTime) + } + + def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], windowTime: Time, slideTime: Time, + numPartitions: Int): JavaPairDStream[K, V] = { + dstream.reduceByKeyAndWindow(reduceFunc, windowTime, slideTime, numPartitions) + } + + def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], windowTime: Time, slideTime: Time, + partitioner: Partitioner): JavaPairDStream[K, V] = { + dstream.reduceByKeyAndWindow(reduceFunc, windowTime, slideTime, partitioner) + } + + def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], invReduceFunc: Function2[V, V, V], + windowTime: Time, slideTime: Time): JavaPairDStream[K, V] = { + dstream.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowTime, slideTime) + } + + def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], invReduceFunc: Function2[V, V, V], + windowTime: Time, slideTime: Time, numPartitions: Int): JavaPairDStream[K, V] = { + dstream.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowTime, slideTime, numPartitions) + } + + def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], invReduceFunc: Function2[V, V, V], + windowTime: Time, slideTime: Time, partitioner: Partitioner) + : JavaPairDStream[K, V] = { + dstream.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowTime, slideTime, partitioner) + } + + def countByKeyAndWindow(windowTime: Time, slideTime: Time): JavaPairDStream[K, Long] = { + dstream.countByKeyAndWindow(windowTime, slideTime) + } + + def countByKeyAndWindow(windowTime: Time, slideTime: Time, numPartitions: Int) + : JavaPairDStream[K, Long] = { + dstream.countByKeyAndWindow(windowTime, slideTime, numPartitions) + } + + override val classManifest: ClassManifest[(K, V)] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]] +} + +object JavaPairDStream { + implicit def fromPairDStream[K: ClassManifest, V: ClassManifest](dstream: DStream[(K, V)]): + JavaPairDStream[K, V] = + new JavaPairDStream[K, V](dstream) + + def fromJavaDStream[K, V](dstream: JavaDStream[(K, V)]): JavaPairDStream[K, V] = { + implicit val cmk: ClassManifest[K] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] + implicit val cmv: ClassManifest[V] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]] + new JavaPairDStream[K, V](dstream.dstream) + } +} diff --git a/streaming/src/test/scala/JavaTestUtils.scala b/streaming/src/test/scala/JavaTestUtils.scala index 776b0e6bb6..9f3a80df8b 100644 --- a/streaming/src/test/scala/JavaTestUtils.scala +++ b/streaming/src/test/scala/JavaTestUtils.scala @@ -2,7 +2,7 @@ package spark.streaming import collection.mutable.{SynchronizedBuffer, ArrayBuffer} import java.util.{List => JList} -import spark.streaming.api.java.{JavaDStream, JavaStreamingContext} +import api.java.{JavaPairDStream, JavaDStreamLike, JavaDStream, JavaStreamingContext} import spark.streaming._ import java.util.ArrayList import collection.JavaConversions._ @@ -20,7 +20,8 @@ object JavaTestUtils extends TestSuiteBase { new JavaDStream[T](dstream) } - def attachTestOutputStream[T](dstream: JavaDStream[T]) = { + def attachTestOutputStream[T, This <: spark.streaming.api.java.JavaDStreamLike[T,This]] + (dstream: JavaDStreamLike[T, This]) = { implicit val cm: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] val ostream = new TestOutputStream(dstream.dstream, @@ -37,6 +38,5 @@ object JavaTestUtils extends TestSuiteBase { res.map(entry => out.append(new ArrayList[V](entry))) out } - } diff --git a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java index c4629c8d97..c1373e6275 100644 --- a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java @@ -5,12 +5,15 @@ import org.junit.Assert; import org.junit.After; import org.junit.Before; import org.junit.Test; +import scala.Tuple2; import spark.api.java.JavaRDD; import spark.api.java.function.FlatMapFunction; import spark.api.java.function.Function; import spark.api.java.function.Function2; +import spark.api.java.function.PairFunction; import spark.streaming.JavaTestUtils; import spark.streaming.api.java.JavaDStream; +import spark.streaming.api.java.JavaPairDStream; import spark.streaming.api.java.JavaStreamingContext; import java.io.Serializable; @@ -340,4 +343,109 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(expected, actual); } + + // PairDStream Functions + @Test + public void testPairFilter() { + List> inputData = Arrays.asList( + Arrays.asList("giants", "dodgers"), + Arrays.asList("yankees", "red socks")); + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("giants", 6)), + Arrays.asList(new Tuple2("yankees", 7))); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaPairDStream pairStream = stream.map( + new PairFunction() { + @Override + public Tuple2 call(String in) throws Exception { + return new Tuple2(in, in.length()); + } + }); + + JavaPairDStream filtered = pairStream.filter( + new Function, Boolean>() { + @Override + public Boolean call(Tuple2 in) throws Exception { + return in._1().contains("a"); + } + }); + JavaTestUtils.attachTestOutputStream(filtered); + List>> result = JavaTestUtils.runStreams(sc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairGroupByKey() { + List>> inputData = Arrays.asList( + Arrays.asList(new Tuple2("california", "dodgers"), + new Tuple2("california", "giants"), + new Tuple2("new york", "yankees"), + new Tuple2("new york", "mets")), + Arrays.asList(new Tuple2("california", "sharks"), + new Tuple2("california", "ducks"), + new Tuple2("new york", "rangers"), + new Tuple2("new york", "islanders"))); + + + List>>> expected = Arrays.asList( + Arrays.asList( + new Tuple2>("california", Arrays.asList("dodgers", "giants")), + new Tuple2>("new york", Arrays.asList("yankees", "mets"))), + Arrays.asList( + new Tuple2>("california", Arrays.asList("sharks", "ducks")), + new Tuple2>("new york", Arrays.asList("rangers", "islanders")))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream> grouped = pairStream.groupByKey(); + JavaTestUtils.attachTestOutputStream(grouped); + List>>> result = JavaTestUtils.runStreams(sc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairReduceByKey() { + List>> inputData = Arrays.asList( + Arrays.asList( + new Tuple2("california", 1), + new Tuple2("california", 3), + new Tuple2("new york", 4), + new Tuple2("new york", 1)), + Arrays.asList( + new Tuple2("california", 5), + new Tuple2("california", 5), + new Tuple2("new york", 3), + new Tuple2("new york", 1))); + + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList( + new Tuple2("california", 10), + new Tuple2("new york", 4))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + sc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream reduced = pairStream.reduceByKey( + new Function2() { + @Override + public Integer call(Integer i1, Integer i2) throws Exception { + return i1 + i2; + } + }); + + JavaTestUtils.attachTestOutputStream(reduced); + List>> result = JavaTestUtils.runStreams(sc, 2, 2); + + Assert.assertEquals(expected, result); + } } -- cgit v1.2.3 From 056f5efc557a8fcb8871d5abbee082b6398ba78c Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 7 Jan 2013 11:26:38 -0800 Subject: More pair functions --- .../spark/streaming/api/java/JavaPairDStream.scala | 43 +++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala index 01dda24fde..cb80a2f3e7 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala @@ -6,7 +6,7 @@ import scala.collection.JavaConversions._ import spark.streaming._ import spark.streaming.StreamingContext._ -import spark.api.java.function.{Function => JFunction, Function2 => JFunction2} +import spark.api.java.function.{Function => JFunction, Function2 => JFunction2, FlatMapFunction} import spark.Partitioner class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( @@ -115,6 +115,47 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.countByKeyAndWindow(windowTime, slideTime, numPartitions) } + def mapValues[U](f: JFunction[V, U]): JavaPairDStream[K, U] = { + implicit val cm: ClassManifest[U] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] + dstream.mapValues(f) + } + + def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairDStream[K, U] = { + import scala.collection.JavaConverters._ + def fn = (x: V) => f.apply(x).asScala + implicit val cm: ClassManifest[U] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] + dstream.flatMapValues(fn) + } + + def cogroup[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (JList[V], JList[W])] = { + implicit val cm: ClassManifest[W] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]] + dstream.cogroup(other.dstream).mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2)))) + } + + def cogroup[W](other: JavaPairDStream[K, W], partitioner: Partitioner) + : JavaPairDStream[K, (JList[V], JList[W])] = { + implicit val cm: ClassManifest[W] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]] + dstream.cogroup(other.dstream, partitioner) + .mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2)))) + } + + def join[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (V, W)] = { + implicit val cm: ClassManifest[W] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]] + dstream.join(other.dstream) + } + + def join[W](other: JavaPairDStream[K, W], partitioner: Partitioner) + : JavaPairDStream[K, (V, W)] = { + implicit val cm: ClassManifest[W] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]] + dstream.join(other.dstream, partitioner) + } + override val classManifest: ClassManifest[(K, V)] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]] } -- cgit v1.2.3 From 74182010a4916c5b03ec74c54c21d89bcab36723 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 7 Jan 2013 20:19:57 -0800 Subject: Style cleanup and moving functions --- .../spark/streaming/api/java/JavaDStream.scala | 26 ++++- .../spark/streaming/api/java/JavaDStreamLike.scala | 41 +------ .../spark/streaming/api/java/JavaPairDStream.scala | 128 ++++++++++++++++----- 3 files changed, 128 insertions(+), 67 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala index 9e2823d81f..9bf595e0bc 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala @@ -1,14 +1,36 @@ package spark.streaming.api.java -import spark.streaming.DStream +import spark.streaming.{Time, DStream} import spark.api.java.function.{Function => JFunction} +import spark.api.java.JavaRDD +import java.util.{List => JList} class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassManifest[T]) extends JavaDStreamLike[T, JavaDStream[T]] { - def filter(f: JFunction[T, java.lang.Boolean]): JavaDStream[T] = { + def filter(f: JFunction[T, java.lang.Boolean]): JavaDStream[T] = dstream.filter((x => f(x).booleanValue())) + + def cache(): JavaDStream[T] = dstream.cache() + + def compute(validTime: Time): JavaRDD[T] = { + dstream.compute(validTime) match { + case Some(rdd) => new JavaRDD(rdd) + case None => null + } } + + def window(windowTime: Time): JavaDStream[T] = + dstream.window(windowTime) + + def window(windowTime: Time, slideTime: Time): JavaDStream[T] = + dstream.window(windowTime, slideTime) + + def tumble(batchTime: Time): JavaDStream[T] = + dstream.tumble(batchTime) + + def union(that: JavaDStream[T]): JavaDStream[T] = + dstream.union(that.dstream) } object JavaDStream { diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala index daea56f50c..b11859ceaf 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala @@ -18,40 +18,17 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable def print() = dstream.print() - // TODO move to type-specific implementations - def cache() : JavaDStream[T] = { - dstream.cache() - } - - def count() : JavaDStream[Int] = { - dstream.count() - } + def count(): JavaDStream[Int] = dstream.count() def countByWindow(windowTime: Time, slideTime: Time) : JavaDStream[Int] = { dstream.countByWindow(windowTime, slideTime) } - def compute(validTime: Time): JavaRDD[T] = { - dstream.compute(validTime) match { - case Some(rdd) => new JavaRDD(rdd) - case None => null - } - } + def glom(): JavaDStream[JList[T]] = + new JavaDStream(dstream.glom().map(x => new java.util.ArrayList[T](x.toSeq))) def context(): StreamingContext = dstream.context() - def window(windowTime: Time): JavaDStream[T] = { - dstream.window(windowTime) - } - - def window(windowTime: Time, slideTime: Time): JavaDStream[T] = { - dstream.window(windowTime, slideTime) - } - - def tumble(batchTime: Time): JavaDStream[T] = { - dstream.tumble(batchTime) - } - def map[R](f: JFunction[T, R]): JavaDStream[R] = { new JavaDStream(dstream.map(f)(f.returnType()))(f.returnType()) } @@ -61,10 +38,6 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable new JavaPairDStream(dstream.map(f)(cm))(f.keyType(), f.valueType()) } - def glom(): JavaDStream[JList[T]] = { - new JavaDStream(dstream.glom().map(x => new java.util.ArrayList[T](x.toSeq))) - } - // TODO: Other map partitions def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = { def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) @@ -85,11 +58,11 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable new util.ArrayList(dstream.slice(fromTime, toTime).map(new JavaRDD(_)).toSeq) } - def foreach(foreachFunc: JFunction[JavaRDD[T], Void]) = { + def foreach(foreachFunc: JFunction[JavaRDD[T], Void]) { dstream.foreach(rdd => foreachFunc.call(new JavaRDD(rdd))) } - def foreach(foreachFunc: JFunction2[JavaRDD[T], Time, Void]) = { + def foreach(foreachFunc: JFunction2[JavaRDD[T], Time, Void]) { dstream.foreach((rdd, time) => foreachFunc.call(new JavaRDD(rdd), time)) } @@ -102,8 +75,4 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable dstream.transform(scalaTransform(_)) } // TODO: transform with time - - def union(that: JavaDStream[T]): JavaDStream[T] = { - dstream.union(that.dstream) - } } \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala index cb80a2f3e7..f6dfbb2345 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala @@ -6,43 +6,64 @@ import scala.collection.JavaConversions._ import spark.streaming._ import spark.streaming.StreamingContext._ -import spark.api.java.function.{Function => JFunction, Function2 => JFunction2, FlatMapFunction} +import spark.api.java.function.{Function => JFunction, Function2 => JFunction2} import spark.Partitioner +import org.apache.hadoop.mapred.{JobConf, OutputFormat} +import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} +import org.apache.hadoop.conf.Configuration +import spark.api.java.{JavaPairRDD, JavaRDD} class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( implicit val kManifiest: ClassManifest[K], implicit val vManifest: ClassManifest[V]) extends JavaDStreamLike[(K, V), JavaPairDStream[K, V]] { - def filter(f: JFunction[(K, V), java.lang.Boolean]): JavaPairDStream[K, V] = { + // Common to all DStream's + def filter(f: JFunction[(K, V), java.lang.Boolean]): JavaPairDStream[K, V] = dstream.filter((x => f(x).booleanValue())) + + def cache(): JavaPairDStream[K, V] = dstream.cache() + + def compute(validTime: Time): JavaPairRDD[K, V] = { + dstream.compute(validTime) match { + case Some(rdd) => new JavaPairRDD(rdd) + case None => null + } } - def groupByKey(): JavaPairDStream[K, JList[V]] = { + def window(windowTime: Time): JavaPairDStream[K, V] = + dstream.window(windowTime) + + def window(windowTime: Time, slideTime: Time): JavaPairDStream[K, V] = + dstream.window(windowTime, slideTime) + + def tumble(batchTime: Time): JavaPairDStream[K, V] = + dstream.tumble(batchTime) + + def union(that: JavaPairDStream[K, V]): JavaPairDStream[K, V] = + dstream.union(that.dstream) + + // Only for PairDStreams... + def groupByKey(): JavaPairDStream[K, JList[V]] = dstream.groupByKey().mapValues(seqAsJavaList _) - } - def groupByKey(numPartitions: Int): JavaPairDStream[K, JList[V]] = { + def groupByKey(numPartitions: Int): JavaPairDStream[K, JList[V]] = dstream.groupByKey(numPartitions).mapValues(seqAsJavaList _) - } - def groupByKey(partitioner: Partitioner): JavaPairDStream[K, JList[V]] = { + def groupByKey(partitioner: Partitioner): JavaPairDStream[K, JList[V]] = dstream.groupByKey(partitioner).mapValues(seqAsJavaList _) - } - def reduceByKey(func: JFunction2[V, V, V]): JavaPairDStream[K, V] = { + def reduceByKey(func: JFunction2[V, V, V]): JavaPairDStream[K, V] = dstream.reduceByKey(func) - } - def reduceByKey(func: JFunction2[V, V, V], numPartitions: Int): JavaPairDStream[K, V] = { + def reduceByKey(func: JFunction2[V, V, V], numPartitions: Int): JavaPairDStream[K, V] = dstream.reduceByKey(func, numPartitions) - } // TODO: TEST BELOW def combineByKey[C](createCombiner: Function[V, C], - mergeValue: JFunction2[C, V, C], - mergeCombiners: JFunction2[C, C, C], - partitioner: Partitioner): JavaPairDStream[K, C] = { + mergeValue: JFunction2[C, V, C], + mergeCombiners: JFunction2[C, C, C], + partitioner: Partitioner): JavaPairDStream[K, C] = { implicit val cm: ClassManifest[C] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]] dstream.combineByKey(createCombiner, mergeValue, mergeCombiners, partitioner) @@ -60,28 +81,31 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.groupByKeyAndWindow(windowTime, slideTime).mapValues(seqAsJavaList _) } - def groupByKeyAndWindow(windowTime: Time, slideTime: Time, numPartitions: Int): - JavaPairDStream[K, JList[V]] = { + def groupByKeyAndWindow(windowTime: Time, slideTime: Time, numPartitions: Int) + :JavaPairDStream[K, JList[V]] = { dstream.groupByKeyAndWindow(windowTime, slideTime, numPartitions).mapValues(seqAsJavaList _) } - def groupByKeyAndWindow(windowTime: Time, slideTime: Time, partitioner: Partitioner): - JavaPairDStream[K, JList[V]] = { + def groupByKeyAndWindow(windowTime: Time, slideTime: Time, partitioner: Partitioner) + :JavaPairDStream[K, JList[V]] = { dstream.groupByKeyAndWindow(windowTime, slideTime, partitioner).mapValues(seqAsJavaList _) } - def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], windowTime: Time): - JavaPairDStream[K, V] = { + def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], windowTime: Time) + :JavaPairDStream[K, V] = { dstream.reduceByKeyAndWindow(reduceFunc, windowTime) } - def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], windowTime: Time, slideTime: Time): - JavaPairDStream[K, V] = { + def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], windowTime: Time, slideTime: Time) + :JavaPairDStream[K, V] = { dstream.reduceByKeyAndWindow(reduceFunc, windowTime, slideTime) } - def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], windowTime: Time, slideTime: Time, - numPartitions: Int): JavaPairDStream[K, V] = { + def reduceByKeyAndWindow( + reduceFunc: Function2[V, V, V], + windowTime: Time, + slideTime: Time, + numPartitions: Int): JavaPairDStream[K, V] = { dstream.reduceByKeyAndWindow(reduceFunc, windowTime, slideTime, numPartitions) } @@ -136,7 +160,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } def cogroup[W](other: JavaPairDStream[K, W], partitioner: Partitioner) - : JavaPairDStream[K, (JList[V], JList[W])] = { + : JavaPairDStream[K, (JList[V], JList[W])] = { implicit val cm: ClassManifest[W] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]] dstream.cogroup(other.dstream, partitioner) @@ -150,19 +174,65 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } def join[W](other: JavaPairDStream[K, W], partitioner: Partitioner) - : JavaPairDStream[K, (V, W)] = { + : JavaPairDStream[K, (V, W)] = { implicit val cm: ClassManifest[W] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]] dstream.join(other.dstream, partitioner) } + def saveAsHadoopFiles[F <: OutputFormat[K, V]](prefix: String, suffix: String) { + dstream.saveAsHadoopFiles(prefix, suffix) + } + + def saveAsHadoopFiles( + prefix: String, + suffix: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: OutputFormat[_, _]]) { + dstream.saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass) + } + + def saveAsHadoopFiles( + prefix: String, + suffix: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: OutputFormat[_, _]], + conf: JobConf) { + dstream.saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf) + } + + def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]](prefix: String, suffix: String) { + dstream.saveAsNewAPIHadoopFiles(prefix, suffix) + } + + def saveAsNewAPIHadoopFiles( + prefix: String, + suffix: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) { + dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass) + } + + def saveAsNewAPIHadoopFiles( + prefix: String, + suffix: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: NewOutputFormat[_, _]], + conf: Configuration = new Configuration) { + dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf) + } + override val classManifest: ClassManifest[(K, V)] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]] } object JavaPairDStream { - implicit def fromPairDStream[K: ClassManifest, V: ClassManifest](dstream: DStream[(K, V)]): - JavaPairDStream[K, V] = + implicit def fromPairDStream[K: ClassManifest, V: ClassManifest](dstream: DStream[(K, V)]) + :JavaPairDStream[K, V] = new JavaPairDStream[K, V](dstream) def fromJavaDStream[K, V](dstream: JavaDStream[(K, V)]): JavaPairDStream[K, V] = { -- cgit v1.2.3 From 7e1049d8f1b155a4bd742e84927c4cc83bb71cb6 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 7 Jan 2013 21:13:55 -0800 Subject: Squashing a few TODOs --- .../spark/streaming/api/java/JavaDStreamLike.scala | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala index b11859ceaf..05d89918b2 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala @@ -38,12 +38,17 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable new JavaPairDStream(dstream.map(f)(cm))(f.keyType(), f.valueType()) } - // TODO: Other map partitions def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = { def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) new JavaDStream(dstream.mapPartitions(fn)(f.elementType()))(f.elementType()) } + def mapPartitions[K, V](f: PairFlatMapFunction[java.util.Iterator[T], K, V]) + : JavaPairDStream[K, V] = { + def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) + new JavaPairDStream(dstream.mapPartitions(fn))(f.keyType(), f.valueType()) + } + def reduce(f: JFunction2[T, T, T]): JavaDStream[T] = dstream.reduce(f) def reduceByWindow( @@ -69,10 +74,16 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable def transform[U](transformFunc: JFunction[JavaRDD[T], JavaRDD[U]]): JavaDStream[U] = { implicit val cm: ClassManifest[U] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] - def scalaTransform (in: RDD[T]): RDD[U] = { + def scalaTransform (in: RDD[T]): RDD[U] = transformFunc.call(new JavaRDD[T](in)).rdd - } dstream.transform(scalaTransform(_)) } - // TODO: transform with time + + def transform[U](transformFunc: JFunction2[JavaRDD[T], Time, JavaRDD[U]]): JavaDStream[U] = { + implicit val cm: ClassManifest[U] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] + def scalaTransform (in: RDD[T], time: Time): RDD[U] = + transformFunc.call(new JavaRDD[T](in), time).rdd + dstream.transform(scalaTransform(_, _)) + } } \ No newline at end of file -- cgit v1.2.3 From 560c312c6060914c9c38cb98d3587685f10f7311 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 9 Jan 2013 19:27:32 -0800 Subject: Docs, some tests, and work on StreamingContext --- .../spark/streaming/api/java/JavaDStream.scala | 33 ++ .../spark/streaming/api/java/JavaDStreamLike.scala | 59 ++++ .../spark/streaming/api/java/JavaPairDStream.scala | 46 ++- .../streaming/api/java/JavaStreamingContext.scala | 152 ++++++++- .../test/scala/spark/streaming/JavaAPISuite.java | 359 +++++++++++++++++++-- 5 files changed, 617 insertions(+), 32 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala index 9bf595e0bc..1e5c279e2c 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala @@ -4,15 +4,25 @@ import spark.streaming.{Time, DStream} import spark.api.java.function.{Function => JFunction} import spark.api.java.JavaRDD import java.util.{List => JList} +import spark.storage.StorageLevel class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassManifest[T]) extends JavaDStreamLike[T, JavaDStream[T]] { + /** Returns a new DStream containing only the elements that satisfy a predicate. */ def filter(f: JFunction[T, java.lang.Boolean]): JavaDStream[T] = dstream.filter((x => f(x).booleanValue())) + /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ def cache(): JavaDStream[T] = dstream.cache() + /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ + def persist(): JavaDStream[T] = dstream.cache() + + /** Persists the RDDs of this DStream with the given storage level */ + def persist(storageLevel: StorageLevel): JavaDStream[T] = dstream.persist(storageLevel) + + /** Method that generates a RDD for the given time */ def compute(validTime: Time): JavaRDD[T] = { dstream.compute(validTime) match { case Some(rdd) => new JavaRDD(rdd) @@ -20,15 +30,38 @@ class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassM } } + /** + * Return a new DStream which is computed based on windowed batches of this DStream. + * The new DStream generates RDDs with the same interval as this DStream. + * @param windowTime width of the window; must be a multiple of this DStream's interval. + * @return + */ def window(windowTime: Time): JavaDStream[T] = dstream.window(windowTime) + /** + * Return a new DStream which is computed based on windowed batches of this DStream. + * @param windowTime duration (i.e., width) of the window; + * must be a multiple of this DStream's interval + * @param slideTime sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's interval + */ def window(windowTime: Time, slideTime: Time): JavaDStream[T] = dstream.window(windowTime, slideTime) + /** + * Returns a new DStream which computed based on tumbling window on this DStream. + * This is equivalent to window(batchTime, batchTime). + * @param batchTime tumbling window duration; must be a multiple of this DStream's interval + */ def tumble(batchTime: Time): JavaDStream[T] = dstream.tumble(batchTime) + /** + * Returns a new DStream by unifying data of another DStream with this DStream. + * @param that Another DStream having the same interval (i.e., slideTime) as this DStream. + */ def union(that: JavaDStream[T]): JavaDStream[T] = dstream.union(that.dstream) } diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala index 05d89918b2..23a0aaaefd 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala @@ -16,41 +16,81 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable def dstream: DStream[T] + /** + * Prints the first ten elements of each RDD generated in this DStream. This is an output + * operator, so this DStream will be registered as an output stream and there materialized. + */ def print() = dstream.print() + /** + * Returns a new DStream in which each RDD has a single element generated by counting each RDD + * of this DStream. + */ def count(): JavaDStream[Int] = dstream.count() + /** + * Returns a new DStream in which each RDD has a single element generated by counting the number + * of elements in a window over this DStream. windowTime and slideTime are as defined in the + * window() operation. This is equivalent to window(windowTime, slideTime).count() + */ def countByWindow(windowTime: Time, slideTime: Time) : JavaDStream[Int] = { dstream.countByWindow(windowTime, slideTime) } + /** + * Return a new DStream in which each RDD is generated by applying glom() to each RDD of + * this DStream. Applying glom() to an RDD coalesces all elements within each partition into + * an array. + */ def glom(): JavaDStream[JList[T]] = new JavaDStream(dstream.glom().map(x => new java.util.ArrayList[T](x.toSeq))) + /** Returns the StreamingContext associated with this DStream */ def context(): StreamingContext = dstream.context() + /** Returns a new DStream by applying a function to all elements of this DStream. */ def map[R](f: JFunction[T, R]): JavaDStream[R] = { new JavaDStream(dstream.map(f)(f.returnType()))(f.returnType()) } + /** Returns a new DStream by applying a function to all elements of this DStream. */ def map[K, V](f: PairFunction[T, K, V]): JavaPairDStream[K, V] = { def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]] new JavaPairDStream(dstream.map(f)(cm))(f.keyType(), f.valueType()) } + /** + * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs + * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition + * of the RDD. + */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = { def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) new JavaDStream(dstream.mapPartitions(fn)(f.elementType()))(f.elementType()) } + /** + * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs + * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition + * of the RDD. + */ def mapPartitions[K, V](f: PairFlatMapFunction[java.util.Iterator[T], K, V]) : JavaPairDStream[K, V] = { def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) new JavaPairDStream(dstream.mapPartitions(fn))(f.keyType(), f.valueType()) } + /** + * Returns a new DStream in which each RDD has a single element generated by reducing each RDD + * of this DStream. + */ def reduce(f: JFunction2[T, T, T]): JavaDStream[T] = dstream.reduce(f) + /** + * Returns a new DStream in which each RDD has a single element generated by reducing all + * elements in a window over this DStream. windowTime and slideTime are as defined in the + * window() operation. This is equivalent to window(windowTime, slideTime).reduce(reduceFunc) + */ def reduceByWindow( reduceFunc: JFunction2[T, T, T], invReduceFunc: JFunction2[T, T, T], @@ -59,18 +99,33 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable dstream.reduceByWindow(reduceFunc, invReduceFunc, windowTime, slideTime) } + /** + * Returns all the RDDs between 'fromTime' to 'toTime' (both included) + */ def slice(fromTime: Time, toTime: Time): JList[JavaRDD[T]] = { new util.ArrayList(dstream.slice(fromTime, toTime).map(new JavaRDD(_)).toSeq) } + /** + * Applies a function to each RDD in this DStream. This is an output operator, so + * this DStream will be registered as an output stream and therefore materialized. + */ def foreach(foreachFunc: JFunction[JavaRDD[T], Void]) { dstream.foreach(rdd => foreachFunc.call(new JavaRDD(rdd))) } + /** + * Applies a function to each RDD in this DStream. This is an output operator, so + * this DStream will be registered as an output stream and therefore materialized. + */ def foreach(foreachFunc: JFunction2[JavaRDD[T], Time, Void]) { dstream.foreach((rdd, time) => foreachFunc.call(new JavaRDD(rdd), time)) } + /** + * Returns a new DStream in which each RDD is generated by applying a function + * on each RDD of this DStream. + */ def transform[U](transformFunc: JFunction[JavaRDD[T], JavaRDD[U]]): JavaDStream[U] = { implicit val cm: ClassManifest[U] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] @@ -79,6 +134,10 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable dstream.transform(scalaTransform(_)) } + /** + * Returns a new DStream in which each RDD is generated by applying a function + * on each RDD of this DStream. + */ def transform[U](transformFunc: JFunction2[JavaRDD[T], Time, JavaRDD[U]]): JavaDStream[U] = { implicit val cm: ClassManifest[U] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala index f6dfbb2345..f36b870046 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala @@ -12,18 +12,31 @@ import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} import org.apache.hadoop.conf.Configuration import spark.api.java.{JavaPairRDD, JavaRDD} +import spark.storage.StorageLevel class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( implicit val kManifiest: ClassManifest[K], implicit val vManifest: ClassManifest[V]) extends JavaDStreamLike[(K, V), JavaPairDStream[K, V]] { - // Common to all DStream's + // ======================================================================= + // Methods common to all DStream's + // ======================================================================= + + /** Returns a new DStream containing only the elements that satisfy a predicate. */ def filter(f: JFunction[(K, V), java.lang.Boolean]): JavaPairDStream[K, V] = dstream.filter((x => f(x).booleanValue())) + /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ def cache(): JavaPairDStream[K, V] = dstream.cache() + /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ + def persist(): JavaPairDStream[K, V] = dstream.cache() + + /** Persists the RDDs of this DStream with the given storage level */ + def persist(storageLevel: StorageLevel): JavaPairDStream[K, V] = dstream.persist(storageLevel) + + /** Method that generates a RDD for the given time */ def compute(validTime: Time): JavaPairRDD[K, V] = { dstream.compute(validTime) match { case Some(rdd) => new JavaPairRDD(rdd) @@ -31,19 +44,45 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } } + /** + * Return a new DStream which is computed based on windowed batches of this DStream. + * The new DStream generates RDDs with the same interval as this DStream. + * @param windowTime width of the window; must be a multiple of this DStream's interval. + * @return + */ def window(windowTime: Time): JavaPairDStream[K, V] = dstream.window(windowTime) + /** + * Return a new DStream which is computed based on windowed batches of this DStream. + * @param windowTime duration (i.e., width) of the window; + * must be a multiple of this DStream's interval + * @param slideTime sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's interval + */ def window(windowTime: Time, slideTime: Time): JavaPairDStream[K, V] = dstream.window(windowTime, slideTime) + /** + * Returns a new DStream which computed based on tumbling window on this DStream. + * This is equivalent to window(batchTime, batchTime). + * @param batchTime tumbling window duration; must be a multiple of this DStream's interval + */ def tumble(batchTime: Time): JavaPairDStream[K, V] = dstream.tumble(batchTime) + /** + * Returns a new DStream by unifying data of another DStream with this DStream. + * @param that Another DStream having the same interval (i.e., slideTime) as this DStream. + */ def union(that: JavaPairDStream[K, V]): JavaPairDStream[K, V] = dstream.union(that.dstream) - // Only for PairDStreams... + // ======================================================================= + // Methods only for PairDStream's + // ======================================================================= + def groupByKey(): JavaPairDStream[K, JList[V]] = dstream.groupByKey().mapValues(seqAsJavaList _) @@ -59,8 +98,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( def reduceByKey(func: JFunction2[V, V, V], numPartitions: Int): JavaPairDStream[K, V] = dstream.reduceByKey(func, numPartitions) - // TODO: TEST BELOW - def combineByKey[C](createCombiner: Function[V, C], + def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], mergeCombiners: JFunction2[C, C, C], partitioner: Partitioner): JavaPairDStream[K, C] = { diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index 19cd032fc1..f96b4fbd7d 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -4,27 +4,173 @@ import scala.collection.JavaConversions._ import java.util.{List => JList} import spark.streaming._ -import dstream.SparkFlumeEvent +import dstream._ import spark.storage.StorageLevel +import spark.api.java.function.{Function => JFunction, Function2 => JFunction2} +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import java.io.InputStream class JavaStreamingContext(val ssc: StreamingContext) { def this(master: String, frameworkName: String, batchDuration: Time) = this(new StreamingContext(master, frameworkName, batchDuration)) - def textFileStream(directory: String): JavaDStream[String] = { - ssc.textFileStream(directory) + // TODOs: + // - Test StreamingContext functions + // - Test to/from Hadoop functions + // - Add checkpoint()/remember() + // - Support creating your own streams + // - Add Kafka Stream + + /** + * Create a input stream from network source hostname:port. Data is received using + * a TCP socket and the receive bytes is interpreted as UTF8 encoded \n delimited + * lines. + * @param hostname Hostname to connect to for receiving data + * @param port Port to connect to for receiving data + * @param storageLevel Storage level to use for storing the received objects + * (default: StorageLevel.MEMORY_AND_DISK_SER_2) + */ + def networkTextStream(hostname: String, port: Int, storageLevel: StorageLevel) + : JavaDStream[String] = { + ssc.networkTextStream(hostname, port, storageLevel) } + /** + * Create a input stream from network source hostname:port. Data is received using + * a TCP socket and the receive bytes is interpreted as UTF8 encoded \n delimited + * lines. + * @param hostname Hostname to connect to for receiving data + * @param port Port to connect to for receiving data + */ def networkTextStream(hostname: String, port: Int): JavaDStream[String] = { ssc.networkTextStream(hostname, port) } + /** + * Create a input stream from network source hostname:port. Data is received using + * a TCP socket and the receive bytes it interepreted as object using the given + * converter. + * @param hostname Hostname to connect to for receiving data + * @param port Port to connect to for receiving data + * @param converter Function to convert the byte stream to objects + * @param storageLevel Storage level to use for storing the received objects + * @tparam T Type of the objects received (after converting bytes to objects) + */ + def networkStream[T]( + hostname: String, + port: Int, + converter: JFunction[InputStream, java.lang.Iterable[T]], + storageLevel: StorageLevel) + : JavaDStream[T] = { + import scala.collection.JavaConverters._ + def fn = (x: InputStream) => converter.apply(x).toIterator + implicit val cmt: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + ssc.networkStream(hostname, port, fn, storageLevel) + } + + /** + * Creates a input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them as text files (using key as LongWritable, value + * as Text and input format as TextInputFormat). File names starting with . are ignored. + * @param directory HDFS directory to monitor for new file + */ + def textFileStream(directory: String): JavaDStream[String] = { + ssc.textFileStream(directory) + } + + /** + * Create a input stream from network source hostname:port, where data is received + * as serialized blocks (serialized using the Spark's serializer) that can be directly + * pushed into the block manager without deserializing them. This is the most efficient + * way to receive data. + * @param hostname Hostname to connect to for receiving data + * @param port Port to connect to for receiving data + * @param storageLevel Storage level to use for storing the received objects + * @tparam T Type of the objects in the received blocks + */ + def rawNetworkStream[T]( + hostname: String, + port: Int, + storageLevel: StorageLevel): JavaDStream[T] = { + implicit val cmt: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + JavaDStream.fromDStream(ssc.rawNetworkStream(hostname, port, storageLevel)) + } + + /** + * Create a input stream from network source hostname:port, where data is received + * as serialized blocks (serialized using the Spark's serializer) that can be directly + * pushed into the block manager without deserializing them. This is the most efficient + * way to receive data. + * @param hostname Hostname to connect to for receiving data + * @param port Port to connect to for receiving data + * @tparam T Type of the objects in the received blocks + */ + def rawNetworkStream[T](hostname: String, port: Int): JavaDStream[T] = { + implicit val cmt: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + JavaDStream.fromDStream(ssc.rawNetworkStream(hostname, port)) + } + + /** + * Creates a input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them using the given key-value types and input format. + * File names starting with . are ignored. + * @param directory HDFS directory to monitor for new file + * @tparam K Key type for reading HDFS file + * @tparam V Value type for reading HDFS file + * @tparam F Input format for reading HDFS file + */ + def fileStream[K, V, F <: NewInputFormat[K, V]](directory: String): JavaPairDStream[K, V] = { + implicit val cmk: ClassManifest[K] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] + implicit val cmv: ClassManifest[V] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]] + implicit val cmf: ClassManifest[F] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[F]] + ssc.fileStream[K, V, F](directory); + } + + /** + * Creates a input stream from a Flume source. + * @param hostname Hostname of the slave machine to which the flume data will be sent + * @param port Port of the slave machine to which the flume data will be sent + * @param storageLevel Storage level to use for storing the received objects + */ def flumeStream(hostname: String, port: Int, storageLevel: StorageLevel): JavaDStream[SparkFlumeEvent] = { ssc.flumeStream(hostname, port, storageLevel) } + + /** + * Creates a input stream from a Flume source. + * @param hostname Hostname of the slave machine to which the flume data will be sent + * @param port Port of the slave machine to which the flume data will be sent + */ + def flumeStream(hostname: String, port: Int): + JavaDStream[SparkFlumeEvent] = { + ssc.flumeStream(hostname, port) + } + + // NOT SUPPORTED: registerInputStream + + /** + * Registers an output stream that will be computed every interval + */ + def registerOutputStream(outputStream: JavaDStreamLike[_, _]) { + ssc.registerOutputStream(outputStream.dstream) + } + + /** + * Starts the execution of the streams. + */ def start() = ssc.start() + + /** + * Sstops the execution of the streams. + */ def stop() = ssc.stop() } diff --git a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java index c1373e6275..fa3a5801dd 100644 --- a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java @@ -6,6 +6,7 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; import scala.Tuple2; +import spark.HashPartitioner; import spark.api.java.JavaRDD; import spark.api.java.function.FlatMapFunction; import spark.api.java.function.Function; @@ -377,18 +378,31 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(expected, result); } + List>> stringStringKVStream = Arrays.asList( + Arrays.asList(new Tuple2("california", "dodgers"), + new Tuple2("california", "giants"), + new Tuple2("new york", "yankees"), + new Tuple2("new york", "mets")), + Arrays.asList(new Tuple2("california", "sharks"), + new Tuple2("california", "ducks"), + new Tuple2("new york", "rangers"), + new Tuple2("new york", "islanders"))); + + List>> stringIntKVStream = Arrays.asList( + Arrays.asList( + new Tuple2("california", 1), + new Tuple2("california", 3), + new Tuple2("new york", 4), + new Tuple2("new york", 1)), + Arrays.asList( + new Tuple2("california", 5), + new Tuple2("california", 5), + new Tuple2("new york", 3), + new Tuple2("new york", 1))); + @Test public void testPairGroupByKey() { - List>> inputData = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("california", "giants"), - new Tuple2("new york", "yankees"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("california", "ducks"), - new Tuple2("new york", "rangers"), - new Tuple2("new york", "islanders"))); - + List>> inputData = stringStringKVStream; List>>> expected = Arrays.asList( Arrays.asList( @@ -410,18 +424,31 @@ public class JavaAPISuite implements Serializable { @Test public void testPairReduceByKey() { - List>> inputData = Arrays.asList( + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("california", 1), - new Tuple2("california", 3), - new Tuple2("new york", 4), - new Tuple2("new york", 1)), + new Tuple2("california", 4), + new Tuple2("new york", 5)), Arrays.asList( - new Tuple2("california", 5), - new Tuple2("california", 5), - new Tuple2("new york", 3), - new Tuple2("new york", 1))); + new Tuple2("california", 10), + new Tuple2("new york", 4))); + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + sc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream reduced = pairStream.reduceByKey(new IntegerSum()); + + JavaTestUtils.attachTestOutputStream(reduced); + List>> result = JavaTestUtils.runStreams(sc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testCombineByKey() { + List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( Arrays.asList( @@ -435,17 +462,299 @@ public class JavaAPISuite implements Serializable { sc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream reduced = pairStream.reduceByKey( - new Function2() { + JavaPairDStream combined = pairStream.combineByKey( + new Function() { @Override - public Integer call(Integer i1, Integer i2) throws Exception { - return i1 + i2; + public Integer call(Integer i) throws Exception { + return i; } - }); + }, new IntegerSum(), new IntegerSum(), new HashPartitioner(2)); - JavaTestUtils.attachTestOutputStream(reduced); + JavaTestUtils.attachTestOutputStream(combined); List>> result = JavaTestUtils.runStreams(sc, 2, 2); Assert.assertEquals(expected, result); } + + @Test + public void testCountByKey() { + List>> inputData = stringStringKVStream; + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2("california", 2L), + new Tuple2("new york", 2L)), + Arrays.asList( + new Tuple2("california", 2L), + new Tuple2("new york", 2L))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + sc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + // TODO: Below fails with compile error with ... wtf? + JavaPairDStream counted = pairStream.countByKey(); + JavaTestUtils.attachTestOutputStream(counted); + List>> result = JavaTestUtils.runStreams(sc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testGroupByKeyAndWindow() { + List>> inputData = stringStringKVStream; + + List>>> expected = Arrays.asList( + Arrays.asList(new Tuple2>("california", Arrays.asList("dodgers", "giants")), + new Tuple2>("new york", Arrays.asList("yankees", "mets"))), + Arrays.asList(new Tuple2>("california", + Arrays.asList("sharks", "ducks", "dodgers", "giants")), + new Tuple2>("new york", Arrays.asList("rangers", "islanders", "yankees", "mets"))), + Arrays.asList(new Tuple2>("california", Arrays.asList("sharks", "ducks")), + new Tuple2>("new york", Arrays.asList("rangers", "islanders")))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream> groupWindowed = + pairStream.groupByKeyAndWindow(new Time(2000), new Time(1000)); + JavaTestUtils.attachTestOutputStream(groupWindowed); + List>>> result = JavaTestUtils.runStreams(sc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testReduceByKeyAndWindow() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList(new Tuple2("california", 14), + new Tuple2("new york", 9)), + Arrays.asList(new Tuple2("california", 10), + new Tuple2("new york", 4))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream reduceWindowed = + pairStream.reduceByKeyAndWindow(new IntegerSum(), new Time(2000), new Time(1000)); + JavaTestUtils.attachTestOutputStream(reduceWindowed); + List>> result = JavaTestUtils.runStreams(sc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testReduceByKeyAndWindowWithInverse() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList(new Tuple2("california", 14), + new Tuple2("new york", 9)), + Arrays.asList(new Tuple2("california", 10), + new Tuple2("new york", 4))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream reduceWindowed = + pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), new Time(2000), new Time(1000)); + JavaTestUtils.attachTestOutputStream(reduceWindowed); + List>> result = JavaTestUtils.runStreams(sc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testCountByKeyAndWindow() { + List>> inputData = stringStringKVStream; + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2("california", 2L), + new Tuple2("new york", 2L)), + Arrays.asList( + new Tuple2("california", 4L), + new Tuple2("new york", 4L)), + Arrays.asList( + new Tuple2("california", 2L), + new Tuple2("new york", 2L))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + sc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + // TODO: Below fails with compile error with ... wtf? + JavaPairDStream counted = pairStream.countByKeyAndWindow(new Time(2000), new Time(1000)); + JavaTestUtils.attachTestOutputStream(counted); + List>> result = JavaTestUtils.runStreams(sc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testMapValues() { + List>> inputData = stringStringKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", "DODGERS"), + new Tuple2("california", "GIANTS"), + new Tuple2("new york", "YANKEES"), + new Tuple2("new york", "METS")), + Arrays.asList(new Tuple2("california", "SHARKS"), + new Tuple2("california", "DUCKS"), + new Tuple2("new york", "RANGERS"), + new Tuple2("new york", "ISLANDERS"))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + sc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream mapped = pairStream.mapValues(new Function() { + @Override + public String call(String s) throws Exception { + return s.toUpperCase(); + } + }); + + JavaTestUtils.attachTestOutputStream(mapped); + List>> result = JavaTestUtils.runStreams(sc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testFlatMapValues() { + List>> inputData = stringStringKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", "dodgers1"), + new Tuple2("california", "dodgers2"), + new Tuple2("california", "giants1"), + new Tuple2("california", "giants2"), + new Tuple2("new york", "yankees1"), + new Tuple2("new york", "yankees2"), + new Tuple2("new york", "mets1"), + new Tuple2("new york", "mets2")), + Arrays.asList(new Tuple2("california", "sharks1"), + new Tuple2("california", "sharks2"), + new Tuple2("california", "ducks1"), + new Tuple2("california", "ducks2"), + new Tuple2("new york", "rangers1"), + new Tuple2("new york", "rangers2"), + new Tuple2("new york", "islanders1"), + new Tuple2("new york", "islanders2"))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + sc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + + JavaPairDStream flatMapped = pairStream.flatMapValues( + new Function>() { + @Override + public Iterable call(String in) { + List out = new ArrayList(); + out.add(in + "1"); + out.add(in + "2"); + return out; + } + }); + + JavaTestUtils.attachTestOutputStream(flatMapped); + List>> result = JavaTestUtils.runStreams(sc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testCoGroup() { + List>> stringStringKVStream1 = Arrays.asList( + Arrays.asList(new Tuple2("california", "dodgers"), + new Tuple2("new york", "yankees")), + Arrays.asList(new Tuple2("california", "sharks"), + new Tuple2("new york", "rangers"))); + + List>> stringStringKVStream2 = Arrays.asList( + Arrays.asList(new Tuple2("california", "giants"), + new Tuple2("new york", "mets")), + Arrays.asList(new Tuple2("california", "ducks"), + new Tuple2("new york", "islanders"))); + + + List, List>>>> expected = Arrays.asList( + Arrays.asList( + new Tuple2, List>>("california", + new Tuple2, List>(Arrays.asList("dodgers"), Arrays.asList("giants"))), + new Tuple2, List>>("new york", + new Tuple2, List>(Arrays.asList("yankees"), Arrays.asList("mets")))), + Arrays.asList( + new Tuple2, List>>("california", + new Tuple2, List>(Arrays.asList("sharks"), Arrays.asList("ducks"))), + new Tuple2, List>>("new york", + new Tuple2, List>(Arrays.asList("rangers"), Arrays.asList("islanders"))))); + + + JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( + sc, stringStringKVStream1, 1); + JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream(stream1); + + JavaDStream> stream2 = JavaTestUtils.attachTestInputStream( + sc, stringStringKVStream2, 1); + JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); + + JavaPairDStream, List>> grouped = pairStream1.cogroup(pairStream2); + JavaTestUtils.attachTestOutputStream(grouped); + List>> result = JavaTestUtils.runStreams(sc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testJoin() { + List>> stringStringKVStream1 = Arrays.asList( + Arrays.asList(new Tuple2("california", "dodgers"), + new Tuple2("new york", "yankees")), + Arrays.asList(new Tuple2("california", "sharks"), + new Tuple2("new york", "rangers"))); + + List>> stringStringKVStream2 = Arrays.asList( + Arrays.asList(new Tuple2("california", "giants"), + new Tuple2("new york", "mets")), + Arrays.asList(new Tuple2("california", "ducks"), + new Tuple2("new york", "islanders"))); + + + List>>> expected = Arrays.asList( + Arrays.asList( + new Tuple2>("california", + new Tuple2("dodgers", "giants")), + new Tuple2>("new york", + new Tuple2("yankees", "mets"))), + Arrays.asList( + new Tuple2>("california", + new Tuple2("sharks", "ducks")), + new Tuple2>("new york", + new Tuple2("rangers", "islanders")))); + + + JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( + sc, stringStringKVStream1, 1); + JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream(stream1); + + JavaDStream> stream2 = JavaTestUtils.attachTestInputStream( + sc, stringStringKVStream2, 1); + JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); + + JavaPairDStream> joined = pairStream1.join(pairStream2); + JavaTestUtils.attachTestOutputStream(joined); + List>> result = JavaTestUtils.runStreams(sc, 2, 2); + + Assert.assertEquals(expected, result); + } } -- cgit v1.2.3 From 2fe39a4468798b5b125c4c3436ee1180b3a7b470 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 9 Jan 2013 21:59:06 -0800 Subject: Some docs for the JavaTestUtils --- streaming/src/test/scala/JavaTestUtils.scala | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/streaming/src/test/scala/JavaTestUtils.scala b/streaming/src/test/scala/JavaTestUtils.scala index 9f3a80df8b..24ebc15e38 100644 --- a/streaming/src/test/scala/JavaTestUtils.scala +++ b/streaming/src/test/scala/JavaTestUtils.scala @@ -2,15 +2,22 @@ package spark.streaming import collection.mutable.{SynchronizedBuffer, ArrayBuffer} import java.util.{List => JList} -import api.java.{JavaPairDStream, JavaDStreamLike, JavaDStream, JavaStreamingContext} +import spark.streaming.api.java.{JavaPairDStream, JavaDStreamLike, JavaDStream, JavaStreamingContext} import spark.streaming._ import java.util.ArrayList import collection.JavaConversions._ -/** Exposes core test functionality in a Java-friendly way. */ +/** Exposes streaming test functionality in a Java-friendly way. */ object JavaTestUtils extends TestSuiteBase { - def attachTestInputStream[T](ssc: JavaStreamingContext, - data: JList[JList[T]], numPartitions: Int) = { + + /** + * Create a [[spark.streaming.TestInputStream]] and attach it to the supplied context. + * The stream will be derived from the supplied lists of Java objects. + **/ + def attachTestInputStream[T]( + ssc: JavaStreamingContext, + data: JList[JList[T]], + numPartitions: Int) = { val seqData = data.map(Seq(_:_*)) implicit val cm: ClassManifest[T] = @@ -20,8 +27,12 @@ object JavaTestUtils extends TestSuiteBase { new JavaDStream[T](dstream) } - def attachTestOutputStream[T, This <: spark.streaming.api.java.JavaDStreamLike[T,This]] - (dstream: JavaDStreamLike[T, This]) = { + /** + * Attach a provided stream to it's associated StreamingContext as a + * [[spark.streaming.TestOutputStream]]. + **/ + def attachTestOutputStream[T, This <: spark.streaming.api.java.JavaDStreamLike[T,This]]( + dstream: JavaDStreamLike[T, This]) = { implicit val cm: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] val ostream = new TestOutputStream(dstream.dstream, @@ -29,6 +40,11 @@ object JavaTestUtils extends TestSuiteBase { dstream.dstream.ssc.registerOutputStream(ostream) } + /** + * Process all registered streams for a numBatches batches, failing if + * numExpectedOutput RDD's are not generated. Generated RDD's are collected + * and returned, represented as a list for each batch interval. + */ def runStreams[V]( ssc: JavaStreamingContext, numBatches: Int, numExpectedOutput: Int): JList[JList[V]] = { implicit val cm: ClassManifest[V] = -- cgit v1.2.3 From 5004eec37c01db3b96d665b0d9606002af209eda Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 9 Jan 2013 22:02:07 -0800 Subject: Import Cleanup --- .../src/main/scala/spark/streaming/api/java/JavaPairDStream.scala | 2 +- .../main/scala/spark/streaming/api/java/JavaStreamingContext.scala | 1 - streaming/src/test/scala/spark/streaming/JavaAPISuite.java | 4 ++-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala index f36b870046..a19a476724 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala @@ -11,7 +11,7 @@ import spark.Partitioner import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} import org.apache.hadoop.conf.Configuration -import spark.api.java.{JavaPairRDD, JavaRDD} +import spark.api.java.JavaPairRDD import spark.storage.StorageLevel class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index f96b4fbd7d..37ce037d5c 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -62,7 +62,6 @@ class JavaStreamingContext(val ssc: StreamingContext) { converter: JFunction[InputStream, java.lang.Iterable[T]], storageLevel: StorageLevel) : JavaDStream[T] = { - import scala.collection.JavaConverters._ def fn = (x: InputStream) => converter.apply(x).toIterator implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] diff --git a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java index fa3a5801dd..9e8438d04c 100644 --- a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java @@ -1,8 +1,8 @@ package spark.streaming; import com.google.common.collect.Lists; -import org.junit.Assert; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; import scala.Tuple2; @@ -12,10 +12,10 @@ import spark.api.java.function.FlatMapFunction; import spark.api.java.function.Function; import spark.api.java.function.Function2; import spark.api.java.function.PairFunction; -import spark.streaming.JavaTestUtils; import spark.streaming.api.java.JavaDStream; import spark.streaming.api.java.JavaPairDStream; import spark.streaming.api.java.JavaStreamingContext; +import spark.streaming.JavaTestUtils; import java.io.Serializable; import java.util.*; -- cgit v1.2.3 From b36c4f7cce53446753ecc0ce6f9bdccb12b3350b Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 10 Jan 2013 19:29:22 -0800 Subject: More work on StreamingContext --- .../streaming/api/java/JavaStreamingContext.scala | 47 +++++++++++++++++++++- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index 37ce037d5c..e8cd03847a 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -9,6 +9,7 @@ import spark.storage.StorageLevel import spark.api.java.function.{Function => JFunction, Function2 => JFunction2} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import java.io.InputStream +import java.util.{Map => JMap} class JavaStreamingContext(val ssc: StreamingContext) { def this(master: String, frameworkName: String, batchDuration: Time) = @@ -17,10 +18,31 @@ class JavaStreamingContext(val ssc: StreamingContext) { // TODOs: // - Test StreamingContext functions // - Test to/from Hadoop functions - // - Add checkpoint()/remember() - // - Support creating your own streams + // - Support registering InputStreams // - Add Kafka Stream + + /** + * Create an input stream that pulls messages form a Kafka Broker. + * @param hostname Zookeper hostname. + * @param port Zookeper port. + * @param groupId The group id for this consumer. + * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * in its own thread. + * @param initialOffsets Optional initial offsets for each of the partitions to consume. + * By default the value is pulled from zookeper. + * @param storageLevel RDD storage level. Defaults to memory-only. + */ + def kafkaStream[T]( + hostname: String, + port: Int, + groupId: String, + topics: JMap[String, Int]) + : DStream[T] = { + implicit val cmt: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + ssc.kafkaStream(hostname, port, groupId, Map(topics.toSeq: _*)) + } /** * Create a input stream from network source hostname:port. Data is received using * a TCP socket and the receive bytes is interpreted as UTF8 encoded \n delimited @@ -162,6 +184,27 @@ class JavaStreamingContext(val ssc: StreamingContext) { ssc.registerOutputStream(outputStream.dstream) } + /** + * Sets the context to periodically checkpoint the DStream operations for master + * fault-tolerance. By default, the graph will be checkpointed every batch interval. + * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored + * @param interval checkpoint interval + */ + def checkpoint(directory: String, interval: Time = null) { + ssc.checkpoint(directory, interval) + } + + /** + * Sets each DStreams in this context to remember RDDs it generated in the last given duration. + * DStreams remember RDDs only for a limited duration of time and releases them for garbage + * collection. This method allows the developer to specify how to long to remember the RDDs ( + * if the developer wishes to query old data outside the DStream computation). + * @param duration Minimum duration that each DStream should remember its RDDs + */ + def remember(duration: Time) { + ssc.remember(duration) + } + /** * Starts the execution of the streams. */ -- cgit v1.2.3 From c2537057f9ed8723d2c33a1636edf9c9547cdc66 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 10 Jan 2013 19:29:33 -0800 Subject: Fixing issue with types --- .../spark/streaming/api/java/JavaPairDStream.scala | 18 ++++++++++++------ .../src/test/scala/spark/streaming/JavaAPISuite.java | 8 ++++---- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala index a19a476724..fa46ca9267 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala @@ -1,6 +1,7 @@ package spark.streaming.api.java import java.util.{List => JList} +import java.lang.{Long => JLong} import scala.collection.JavaConversions._ @@ -13,6 +14,7 @@ import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} import org.apache.hadoop.conf.Configuration import spark.api.java.JavaPairRDD import spark.storage.StorageLevel +import java.lang class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( implicit val kManifiest: ClassManifest[K], @@ -107,12 +109,12 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.combineByKey(createCombiner, mergeValue, mergeCombiners, partitioner) } - def countByKey(numPartitions: Int): JavaPairDStream[K, Long] = { - dstream.countByKey(numPartitions); + def countByKey(numPartitions: Int): JavaPairDStream[K, JLong] = { + JavaPairDStream.scalaToJavaLong(dstream.countByKey(numPartitions)); } - def countByKey(): JavaPairDStream[K, Long] = { - dstream.countByKey(); + def countByKey(): JavaPairDStream[K, JLong] = { + JavaPairDStream.scalaToJavaLong(dstream.countByKey()); } def groupByKeyAndWindow(windowTime: Time, slideTime: Time): JavaPairDStream[K, JList[V]] = { @@ -168,8 +170,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowTime, slideTime, partitioner) } - def countByKeyAndWindow(windowTime: Time, slideTime: Time): JavaPairDStream[K, Long] = { - dstream.countByKeyAndWindow(windowTime, slideTime) + def countByKeyAndWindow(windowTime: Time, slideTime: Time): JavaPairDStream[K, JLong] = { + JavaPairDStream.scalaToJavaLong(dstream.countByKeyAndWindow(windowTime, slideTime)) } def countByKeyAndWindow(windowTime: Time, slideTime: Time, numPartitions: Int) @@ -280,4 +282,8 @@ object JavaPairDStream { implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]] new JavaPairDStream[K, V](dstream.dstream) } + + def scalaToJavaLong[K: ClassManifest](dstream: JavaPairDStream[K, Long]): JavaPairDStream[K, JLong] = { + StreamingContext.toPairDStreamFunctions(dstream.dstream).mapValues(new JLong(_)) + } } diff --git a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java index 9e8438d04c..6584d861ed 100644 --- a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java @@ -1,6 +1,7 @@ package spark.streaming; import com.google.common.collect.Lists; +import com.google.common.collect.Maps; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -492,8 +493,7 @@ public class JavaAPISuite implements Serializable { sc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - // TODO: Below fails with compile error with ... wtf? - JavaPairDStream counted = pairStream.countByKey(); + JavaPairDStream counted = pairStream.countByKey(); JavaTestUtils.attachTestOutputStream(counted); List>> result = JavaTestUtils.runStreams(sc, 2, 2); @@ -589,8 +589,8 @@ public class JavaAPISuite implements Serializable { sc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - // TODO: Below fails with compile error with ... wtf? - JavaPairDStream counted = pairStream.countByKeyAndWindow(new Time(2000), new Time(1000)); + JavaPairDStream counted = + pairStream.countByKeyAndWindow(new Time(2000), new Time(1000)); JavaTestUtils.attachTestOutputStream(counted); List>> result = JavaTestUtils.runStreams(sc, 3, 3); -- cgit v1.2.3 From 280b6d018691810bbb3dd3155f059132b4475995 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 10 Jan 2013 19:52:45 -0800 Subject: Porting to new Duration class --- .../spark/streaming/api/java/JavaDStream.scala | 28 ++++---- .../spark/streaming/api/java/JavaDStreamLike.scala | 24 +++---- .../spark/streaming/api/java/JavaPairDStream.scala | 76 +++++++++++----------- .../streaming/api/java/JavaStreamingContext.scala | 8 +-- .../test/scala/spark/streaming/JavaAPISuite.java | 20 +++--- 5 files changed, 78 insertions(+), 78 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala index 1e5c279e2c..f85864df5d 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala @@ -1,6 +1,6 @@ package spark.streaming.api.java -import spark.streaming.{Time, DStream} +import spark.streaming.{Duration, Time, DStream} import spark.api.java.function.{Function => JFunction} import spark.api.java.JavaRDD import java.util.{List => JList} @@ -22,7 +22,7 @@ class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassM /** Persists the RDDs of this DStream with the given storage level */ def persist(storageLevel: StorageLevel): JavaDStream[T] = dstream.persist(storageLevel) - /** Method that generates a RDD for the given time */ + /** Method that generates a RDD for the given duration */ def compute(validTime: Time): JavaRDD[T] = { dstream.compute(validTime) match { case Some(rdd) => new JavaRDD(rdd) @@ -33,34 +33,34 @@ class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassM /** * Return a new DStream which is computed based on windowed batches of this DStream. * The new DStream generates RDDs with the same interval as this DStream. - * @param windowTime width of the window; must be a multiple of this DStream's interval. + * @param windowDuration width of the window; must be a multiple of this DStream's interval. * @return */ - def window(windowTime: Time): JavaDStream[T] = - dstream.window(windowTime) + def window(windowDuration: Duration): JavaDStream[T] = + dstream.window(windowDuration) /** * Return a new DStream which is computed based on windowed batches of this DStream. - * @param windowTime duration (i.e., width) of the window; + * @param windowDuration duration (i.e., width) of the window; * must be a multiple of this DStream's interval - * @param slideTime sliding interval of the window (i.e., the interval after which + * @param slideDuration sliding interval of the window (i.e., the interval after which * the new DStream will generate RDDs); must be a multiple of this * DStream's interval */ - def window(windowTime: Time, slideTime: Time): JavaDStream[T] = - dstream.window(windowTime, slideTime) + def window(windowDuration: Duration, slideDuration: Duration): JavaDStream[T] = + dstream.window(windowDuration, slideDuration) /** * Returns a new DStream which computed based on tumbling window on this DStream. - * This is equivalent to window(batchTime, batchTime). - * @param batchTime tumbling window duration; must be a multiple of this DStream's interval + * This is equivalent to window(batchDuration, batchDuration). + * @param batchDuration tumbling window duration; must be a multiple of this DStream's interval */ - def tumble(batchTime: Time): JavaDStream[T] = - dstream.tumble(batchTime) + def tumble(batchDuration: Duration): JavaDStream[T] = + dstream.tumble(batchDuration) /** * Returns a new DStream by unifying data of another DStream with this DStream. - * @param that Another DStream having the same interval (i.e., slideTime) as this DStream. + * @param that Another DStream having the same interval (i.e., slideDuration) as this DStream. */ def union(that: JavaDStream[T]): JavaDStream[T] = dstream.union(that.dstream) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala index 23a0aaaefd..cb58c1351d 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala @@ -30,11 +30,11 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable /** * Returns a new DStream in which each RDD has a single element generated by counting the number - * of elements in a window over this DStream. windowTime and slideTime are as defined in the - * window() operation. This is equivalent to window(windowTime, slideTime).count() + * of elements in a window over this DStream. windowDuration and slideDuration are as defined in the + * window() operation. This is equivalent to window(windowDuration, slideDuration).count() */ - def countByWindow(windowTime: Time, slideTime: Time) : JavaDStream[Int] = { - dstream.countByWindow(windowTime, slideTime) + def countByWindow(windowDuration: Duration, slideDuration: Duration) : JavaDStream[Int] = { + dstream.countByWindow(windowDuration, slideDuration) } /** @@ -88,22 +88,22 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable /** * Returns a new DStream in which each RDD has a single element generated by reducing all - * elements in a window over this DStream. windowTime and slideTime are as defined in the - * window() operation. This is equivalent to window(windowTime, slideTime).reduce(reduceFunc) + * elements in a window over this DStream. windowDuration and slideDuration are as defined in the + * window() operation. This is equivalent to window(windowDuration, slideDuration).reduce(reduceFunc) */ def reduceByWindow( reduceFunc: JFunction2[T, T, T], invReduceFunc: JFunction2[T, T, T], - windowTime: Time, - slideTime: Time): JavaDStream[T] = { - dstream.reduceByWindow(reduceFunc, invReduceFunc, windowTime, slideTime) + windowDuration: Duration, + slideDuration: Duration): JavaDStream[T] = { + dstream.reduceByWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration) } /** - * Returns all the RDDs between 'fromTime' to 'toTime' (both included) + * Returns all the RDDs between 'fromDuration' to 'toDuration' (both included) */ - def slice(fromTime: Time, toTime: Time): JList[JavaRDD[T]] = { - new util.ArrayList(dstream.slice(fromTime, toTime).map(new JavaRDD(_)).toSeq) + def slice(fromDuration: Duration, toDuration: Duration): JList[JavaRDD[T]] = { + new util.ArrayList(dstream.slice(fromDuration, toDuration).map(new JavaRDD(_)).toSeq) } /** diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala index fa46ca9267..03336d040d 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala @@ -38,7 +38,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( /** Persists the RDDs of this DStream with the given storage level */ def persist(storageLevel: StorageLevel): JavaPairDStream[K, V] = dstream.persist(storageLevel) - /** Method that generates a RDD for the given time */ + /** Method that generates a RDD for the given Duration */ def compute(validTime: Time): JavaPairRDD[K, V] = { dstream.compute(validTime) match { case Some(rdd) => new JavaPairRDD(rdd) @@ -49,34 +49,34 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( /** * Return a new DStream which is computed based on windowed batches of this DStream. * The new DStream generates RDDs with the same interval as this DStream. - * @param windowTime width of the window; must be a multiple of this DStream's interval. + * @param windowDuration width of the window; must be a multiple of this DStream's interval. * @return */ - def window(windowTime: Time): JavaPairDStream[K, V] = - dstream.window(windowTime) + def window(windowDuration: Duration): JavaPairDStream[K, V] = + dstream.window(windowDuration) /** * Return a new DStream which is computed based on windowed batches of this DStream. - * @param windowTime duration (i.e., width) of the window; + * @param windowDuration duration (i.e., width) of the window; * must be a multiple of this DStream's interval - * @param slideTime sliding interval of the window (i.e., the interval after which + * @param slideDuration sliding interval of the window (i.e., the interval after which * the new DStream will generate RDDs); must be a multiple of this * DStream's interval */ - def window(windowTime: Time, slideTime: Time): JavaPairDStream[K, V] = - dstream.window(windowTime, slideTime) + def window(windowDuration: Duration, slideDuration: Duration): JavaPairDStream[K, V] = + dstream.window(windowDuration, slideDuration) /** * Returns a new DStream which computed based on tumbling window on this DStream. - * This is equivalent to window(batchTime, batchTime). - * @param batchTime tumbling window duration; must be a multiple of this DStream's interval + * This is equivalent to window(batchDuration, batchDuration). + * @param batchDuration tumbling window duration; must be a multiple of this DStream's interval */ - def tumble(batchTime: Time): JavaPairDStream[K, V] = - dstream.tumble(batchTime) + def tumble(batchDuration: Duration): JavaPairDStream[K, V] = + dstream.tumble(batchDuration) /** * Returns a new DStream by unifying data of another DStream with this DStream. - * @param that Another DStream having the same interval (i.e., slideTime) as this DStream. + * @param that Another DStream having the same interval (i.e., slideDuration) as this DStream. */ def union(that: JavaPairDStream[K, V]): JavaPairDStream[K, V] = dstream.union(that.dstream) @@ -117,66 +117,66 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( JavaPairDStream.scalaToJavaLong(dstream.countByKey()); } - def groupByKeyAndWindow(windowTime: Time, slideTime: Time): JavaPairDStream[K, JList[V]] = { - dstream.groupByKeyAndWindow(windowTime, slideTime).mapValues(seqAsJavaList _) + def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration): JavaPairDStream[K, JList[V]] = { + dstream.groupByKeyAndWindow(windowDuration, slideDuration).mapValues(seqAsJavaList _) } - def groupByKeyAndWindow(windowTime: Time, slideTime: Time, numPartitions: Int) + def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) :JavaPairDStream[K, JList[V]] = { - dstream.groupByKeyAndWindow(windowTime, slideTime, numPartitions).mapValues(seqAsJavaList _) + dstream.groupByKeyAndWindow(windowDuration, slideDuration, numPartitions).mapValues(seqAsJavaList _) } - def groupByKeyAndWindow(windowTime: Time, slideTime: Time, partitioner: Partitioner) + def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, partitioner: Partitioner) :JavaPairDStream[K, JList[V]] = { - dstream.groupByKeyAndWindow(windowTime, slideTime, partitioner).mapValues(seqAsJavaList _) + dstream.groupByKeyAndWindow(windowDuration, slideDuration, partitioner).mapValues(seqAsJavaList _) } - def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], windowTime: Time) + def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], windowDuration: Duration) :JavaPairDStream[K, V] = { - dstream.reduceByKeyAndWindow(reduceFunc, windowTime) + dstream.reduceByKeyAndWindow(reduceFunc, windowDuration) } - def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], windowTime: Time, slideTime: Time) + def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], windowDuration: Duration, slideDuration: Duration) :JavaPairDStream[K, V] = { - dstream.reduceByKeyAndWindow(reduceFunc, windowTime, slideTime) + dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration) } def reduceByKeyAndWindow( reduceFunc: Function2[V, V, V], - windowTime: Time, - slideTime: Time, + windowDuration: Duration, + slideDuration: Duration, numPartitions: Int): JavaPairDStream[K, V] = { - dstream.reduceByKeyAndWindow(reduceFunc, windowTime, slideTime, numPartitions) + dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, numPartitions) } - def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], windowTime: Time, slideTime: Time, + def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], windowDuration: Duration, slideDuration: Duration, partitioner: Partitioner): JavaPairDStream[K, V] = { - dstream.reduceByKeyAndWindow(reduceFunc, windowTime, slideTime, partitioner) + dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, partitioner) } def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], invReduceFunc: Function2[V, V, V], - windowTime: Time, slideTime: Time): JavaPairDStream[K, V] = { - dstream.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowTime, slideTime) + windowDuration: Duration, slideDuration: Duration): JavaPairDStream[K, V] = { + dstream.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration) } def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], invReduceFunc: Function2[V, V, V], - windowTime: Time, slideTime: Time, numPartitions: Int): JavaPairDStream[K, V] = { - dstream.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowTime, slideTime, numPartitions) + windowDuration: Duration, slideDuration: Duration, numPartitions: Int): JavaPairDStream[K, V] = { + dstream.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration, numPartitions) } def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], invReduceFunc: Function2[V, V, V], - windowTime: Time, slideTime: Time, partitioner: Partitioner) + windowDuration: Duration, slideDuration: Duration, partitioner: Partitioner) : JavaPairDStream[K, V] = { - dstream.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowTime, slideTime, partitioner) + dstream.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration, partitioner) } - def countByKeyAndWindow(windowTime: Time, slideTime: Time): JavaPairDStream[K, JLong] = { - JavaPairDStream.scalaToJavaLong(dstream.countByKeyAndWindow(windowTime, slideTime)) + def countByKeyAndWindow(windowDuration: Duration, slideDuration: Duration): JavaPairDStream[K, JLong] = { + JavaPairDStream.scalaToJavaLong(dstream.countByKeyAndWindow(windowDuration, slideDuration)) } - def countByKeyAndWindow(windowTime: Time, slideTime: Time, numPartitions: Int) + def countByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) : JavaPairDStream[K, Long] = { - dstream.countByKeyAndWindow(windowTime, slideTime, numPartitions) + dstream.countByKeyAndWindow(windowDuration, slideDuration, numPartitions) } def mapValues[U](f: JFunction[V, U]): JavaPairDStream[K, U] = { diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index e8cd03847a..5a712d18c7 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -12,7 +12,7 @@ import java.io.InputStream import java.util.{Map => JMap} class JavaStreamingContext(val ssc: StreamingContext) { - def this(master: String, frameworkName: String, batchDuration: Time) = + def this(master: String, frameworkName: String, batchDuration: Duration) = this(new StreamingContext(master, frameworkName, batchDuration)) // TODOs: @@ -190,18 +190,18 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored * @param interval checkpoint interval */ - def checkpoint(directory: String, interval: Time = null) { + def checkpoint(directory: String, interval: Duration = null) { ssc.checkpoint(directory, interval) } /** * Sets each DStreams in this context to remember RDDs it generated in the last given duration. - * DStreams remember RDDs only for a limited duration of time and releases them for garbage + * DStreams remember RDDs only for a limited duration of duration and releases them for garbage * collection. This method allows the developer to specify how to long to remember the RDDs ( * if the developer wishes to query old data outside the DStream computation). * @param duration Minimum duration that each DStream should remember its RDDs */ - def remember(duration: Time) { + def remember(duration: Duration) { ssc.remember(duration) } diff --git a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java index 6584d861ed..26ff5b1ccd 100644 --- a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java @@ -29,7 +29,7 @@ public class JavaAPISuite implements Serializable { @Before public void setUp() { - sc = new JavaStreamingContext("local[2]", "test", new Time(1000)); + sc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); } @After @@ -96,7 +96,7 @@ public class JavaAPISuite implements Serializable { Arrays.asList(7,8,9)); JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); - JavaDStream windowed = stream.window(new Time(2000)); + JavaDStream windowed = stream.window(new Duration(2000)); JavaTestUtils.attachTestOutputStream(windowed); List> result = JavaTestUtils.runStreams(sc, 4, 4); @@ -104,7 +104,7 @@ public class JavaAPISuite implements Serializable { } @Test - public void testWindowWithSlideTime() { + public void testWindowWithSlideDuration() { List> inputData = Arrays.asList( Arrays.asList(1,2,3), Arrays.asList(4,5,6), @@ -120,7 +120,7 @@ public class JavaAPISuite implements Serializable { Arrays.asList(13,14,15,16,17,18)); JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); - JavaDStream windowed = stream.window(new Time(4000), new Time(2000)); + JavaDStream windowed = stream.window(new Duration(4000), new Duration(2000)); JavaTestUtils.attachTestOutputStream(windowed); List> result = JavaTestUtils.runStreams(sc, 8, 4); @@ -143,7 +143,7 @@ public class JavaAPISuite implements Serializable { Arrays.asList(13,14,15,16,17,18)); JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); - JavaDStream windowed = stream.tumble(new Time(2000)); + JavaDStream windowed = stream.tumble(new Duration(2000)); JavaTestUtils.attachTestOutputStream(windowed); List> result = JavaTestUtils.runStreams(sc, 6, 3); @@ -267,7 +267,7 @@ public class JavaAPISuite implements Serializable { JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); JavaDStream reducedWindowed = stream.reduceByWindow(new IntegerSum(), - new IntegerDifference(), new Time(2000), new Time(1000)); + new IntegerDifference(), new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reducedWindowed); List> result = JavaTestUtils.runStreams(sc, 4, 4); @@ -517,7 +517,7 @@ public class JavaAPISuite implements Serializable { JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream> groupWindowed = - pairStream.groupByKeyAndWindow(new Time(2000), new Time(1000)); + pairStream.groupByKeyAndWindow(new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(groupWindowed); List>>> result = JavaTestUtils.runStreams(sc, 3, 3); @@ -540,7 +540,7 @@ public class JavaAPISuite implements Serializable { JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reduceWindowed = - pairStream.reduceByKeyAndWindow(new IntegerSum(), new Time(2000), new Time(1000)); + pairStream.reduceByKeyAndWindow(new IntegerSum(), new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); List>> result = JavaTestUtils.runStreams(sc, 3, 3); @@ -563,7 +563,7 @@ public class JavaAPISuite implements Serializable { JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reduceWindowed = - pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), new Time(2000), new Time(1000)); + pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); List>> result = JavaTestUtils.runStreams(sc, 3, 3); @@ -590,7 +590,7 @@ public class JavaAPISuite implements Serializable { JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream counted = - pairStream.countByKeyAndWindow(new Time(2000), new Time(1000)); + pairStream.countByKeyAndWindow(new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(counted); List>> result = JavaTestUtils.runStreams(sc, 3, 3); -- cgit v1.2.3 From 5bcb048167fe0b90f749910233342c09fff3fce7 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 11 Jan 2013 11:18:06 -0800 Subject: More work on InputStreams --- .../spark/streaming/api/java/JavaPairDStream.scala | 3 +- .../streaming/api/java/JavaStreamingContext.scala | 68 ++++++++++++++++++--- .../test/scala/spark/streaming/JavaAPISuite.java | 69 +++++++++++++++++++++- 3 files changed, 130 insertions(+), 10 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala index 03336d040d..b25a3f109c 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala @@ -283,7 +283,8 @@ object JavaPairDStream { new JavaPairDStream[K, V](dstream.dstream) } - def scalaToJavaLong[K: ClassManifest](dstream: JavaPairDStream[K, Long]): JavaPairDStream[K, JLong] = { + def scalaToJavaLong[K: ClassManifest](dstream: JavaPairDStream[K, Long]) + : JavaPairDStream[K, JLong] = { StreamingContext.toPairDStreamFunctions(dstream.dstream).mapValues(new JLong(_)) } } diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index 5a712d18c7..2833793b94 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -2,6 +2,7 @@ package spark.streaming.api.java import scala.collection.JavaConversions._ import java.util.{List => JList} +import java.lang.{Long => JLong, Integer => JInt} import spark.streaming._ import dstream._ @@ -18,9 +19,53 @@ class JavaStreamingContext(val ssc: StreamingContext) { // TODOs: // - Test StreamingContext functions // - Test to/from Hadoop functions - // - Support registering InputStreams - // - Add Kafka Stream + // - Support creating and registering InputStreams + /** + * Create an input stream that pulls messages form a Kafka Broker. + * @param hostname Zookeper hostname. + * @param port Zookeper port. + * @param groupId The group id for this consumer. + * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * in its own thread. + */ + def kafkaStream[T]( + hostname: String, + port: Int, + groupId: String, + topics: JMap[String, JInt]) + : JavaDStream[T] = { + implicit val cmt: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + ssc.kafkaStream[T](hostname, port, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*)) + } + + /** + * Create an input stream that pulls messages form a Kafka Broker. + * @param hostname Zookeper hostname. + * @param port Zookeper port. + * @param groupId The group id for this consumer. + * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * in its own thread. + * @param initialOffsets Optional initial offsets for each of the partitions to consume. + * By default the value is pulled from zookeper. + */ + def kafkaStream[T]( + hostname: String, + port: Int, + groupId: String, + topics: JMap[String, JInt], + initialOffsets: JMap[KafkaPartitionKey, JLong]) + : JavaDStream[T] = { + implicit val cmt: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + ssc.kafkaStream[T]( + hostname, + port, + groupId, + Map(topics.mapValues(_.intValue()).toSeq: _*), + Map(initialOffsets.mapValues(_.longValue()).toSeq: _*)) + } /** * Create an input stream that pulls messages form a Kafka Broker. @@ -31,18 +76,27 @@ class JavaStreamingContext(val ssc: StreamingContext) { * in its own thread. * @param initialOffsets Optional initial offsets for each of the partitions to consume. * By default the value is pulled from zookeper. - * @param storageLevel RDD storage level. Defaults to memory-only. + * @param storageLevel RDD storage level. Defaults to memory-only */ def kafkaStream[T]( hostname: String, port: Int, groupId: String, - topics: JMap[String, Int]) - : DStream[T] = { + topics: JMap[String, JInt], + initialOffsets: JMap[KafkaPartitionKey, JLong], + storageLevel: StorageLevel) + : JavaDStream[T] = { implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - ssc.kafkaStream(hostname, port, groupId, Map(topics.toSeq: _*)) + ssc.kafkaStream[T]( + hostname, + port, + groupId, + Map(topics.mapValues(_.intValue()).toSeq: _*), + Map(initialOffsets.mapValues(_.longValue()).toSeq: _*), + storageLevel) } + /** * Create a input stream from network source hostname:port. Data is received using * a TCP socket and the receive bytes is interpreted as UTF8 encoded \n delimited @@ -175,8 +229,6 @@ class JavaStreamingContext(val ssc: StreamingContext) { ssc.flumeStream(hostname, port) } - // NOT SUPPORTED: registerInputStream - /** * Registers an output stream that will be computed every interval */ diff --git a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java index 26ff5b1ccd..7475b9536b 100644 --- a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java @@ -2,6 +2,7 @@ package spark.streaming; import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -13,12 +14,15 @@ import spark.api.java.function.FlatMapFunction; import spark.api.java.function.Function; import spark.api.java.function.Function2; import spark.api.java.function.PairFunction; +import spark.storage.StorageLevel; import spark.streaming.api.java.JavaDStream; import spark.streaming.api.java.JavaPairDStream; import spark.streaming.api.java.JavaStreamingContext; import spark.streaming.JavaTestUtils; +import spark.streaming.dstream.KafkaPartitionKey; +import sun.org.mozilla.javascript.annotations.JSFunction; -import java.io.Serializable; +import java.io.*; import java.util.*; // The test suite itself is Serializable so that anonymous Function implementations can be @@ -757,4 +761,67 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(expected, result); } + + // Input stream tests. These mostly just test that we can instantiate a given InputStream with + // Java arguments and assign it to a JavaDStream without producing type errors. Testing of the + // InputStream functionality is deferred to the existing Scala tests. + @Test + public void testKafkaStream() { + HashMap topics = Maps.newHashMap(); + HashMap offsets = Maps.newHashMap(); + JavaDStream test1 = sc.kafkaStream("localhost", 12345, "group", topics); + JavaDStream test2 = sc.kafkaStream("localhost", 12345, "group", topics, offsets); + JavaDStream test3 = sc.kafkaStream("localhost", 12345, "group", topics, offsets, + StorageLevel.MEMORY_AND_DISK()); + } + + @Test + public void testNetworkTextStream() { + JavaDStream test = sc.networkTextStream("localhost", 12345); + } + + @Test + public void testNetworkString() { + class Converter extends Function> { + public Iterable call(InputStream in) { + BufferedReader reader = new BufferedReader(new InputStreamReader(in)); + List out = new ArrayList(); + try { + while (true) { + String line = reader.readLine(); + if (line == null) { break; } + out.add(line); + } + } catch (IOException e) { } + return out; + } + } + + JavaDStream test = sc.networkStream( + "localhost", + 12345, + new Converter(), + StorageLevel.MEMORY_ONLY()); + } + + @Test + public void testTextFileStream() { + JavaDStream test = sc.textFileStream("/tmp/foo"); + } + + @Test + public void testRawNetworkStream() { + JavaDStream test = sc.rawNetworkStream("localhost", 12345); + } + + @Test + public void testFlumeStream() { + JavaDStream test = sc.flumeStream("localhost", 12345); + } + + @Test + public void testFileStream() { + JavaPairDStream foo = + sc.fileStream("/tmp/foo"); + } } -- cgit v1.2.3 From 3461cd99b7b680be9c9dc263382b42f30c9edd7d Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 11 Jan 2013 12:05:04 -0800 Subject: Flume example and bug fix --- .../streaming/examples/JavaFlumeEventCount.java | 50 ++++++++++++++++++++++ .../spark/streaming/api/java/JavaDStreamLike.scala | 9 +++- 2 files changed, 57 insertions(+), 2 deletions(-) create mode 100644 examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java diff --git a/examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java b/examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java new file mode 100644 index 0000000000..6592d9bc2e --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java @@ -0,0 +1,50 @@ +package spark.streaming.examples; + +import spark.api.java.function.Function; +import spark.streaming.*; +import spark.streaming.api.java.*; +import spark.streaming.dstream.SparkFlumeEvent; + +/** + * Produces a count of events received from Flume. + * + * This should be used in conjunction with an AvroSink in Flume. It will start + * an Avro server on at the request host:port address and listen for requests. + * Your Flume AvroSink should be pointed to this address. + * + * Usage: FlumeEventCount + * + * is a Spark master URL + * is the host the Flume receiver will be started on - a receiver + * creates a server and listens for flume events. + * is the port the Flume receiver will listen on. + */ +public class JavaFlumeEventCount { + public static void main(String[] args) { + if (args.length != 3) { + System.err.println("Usage: JavaFlumeEventCount "); + System.exit(1); + } + + String master = args[0]; + String host = args[1]; + int port = Integer.parseInt(args[2]); + + Duration batchInterval = new Duration(2000); + + JavaStreamingContext sc = new JavaStreamingContext(master, "FlumeEventCount", batchInterval); + + JavaDStream flumeStream = sc.flumeStream("localhost", port); + + flumeStream.count(); + + flumeStream.count().map(new Function() { + @Override + public String call(Integer in) { + return "Received " + in + " flume events."; + } + }).print(); + + sc.start(); + } +} diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala index cb58c1351d..91bcca9afa 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala @@ -1,6 +1,7 @@ package spark.streaming.api.java import java.util.{List => JList} +import java.lang.{Integer => JInt} import scala.collection.JavaConversions._ @@ -16,6 +17,10 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable def dstream: DStream[T] + implicit def scalaIntToJavaInteger(in: DStream[Int]): JavaDStream[JInt] = { + in.map(new JInt(_)) + } + /** * Prints the first ten elements of each RDD generated in this DStream. This is an output * operator, so this DStream will be registered as an output stream and there materialized. @@ -26,14 +31,14 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable * Returns a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. */ - def count(): JavaDStream[Int] = dstream.count() + def count(): JavaDStream[JInt] = dstream.count() /** * Returns a new DStream in which each RDD has a single element generated by counting the number * of elements in a window over this DStream. windowDuration and slideDuration are as defined in the * window() operation. This is equivalent to window(windowDuration, slideDuration).count() */ - def countByWindow(windowDuration: Duration, slideDuration: Duration) : JavaDStream[Int] = { + def countByWindow(windowDuration: Duration, slideDuration: Duration) : JavaDStream[JInt] = { dstream.countByWindow(windowDuration, slideDuration) } -- cgit v1.2.3 From a292ed8d8af069ee1318cdf7c00d3db8d3ba8db9 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 11 Jan 2013 12:34:36 -0800 Subject: Some style cleanup --- .../spark/streaming/api/java/JavaDStreamLike.scala | 9 +- .../spark/streaming/api/java/JavaPairDStream.scala | 137 +++++++++++++-------- 2 files changed, 93 insertions(+), 53 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala index 91bcca9afa..80d8865725 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala @@ -97,10 +97,11 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable * window() operation. This is equivalent to window(windowDuration, slideDuration).reduce(reduceFunc) */ def reduceByWindow( - reduceFunc: JFunction2[T, T, T], - invReduceFunc: JFunction2[T, T, T], - windowDuration: Duration, - slideDuration: Duration): JavaDStream[T] = { + reduceFunc: JFunction2[T, T, T], + invReduceFunc: JFunction2[T, T, T], + windowDuration: Duration, + slideDuration: Duration + ): JavaDStream[T] = { dstream.reduceByWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration) } diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala index b25a3f109c..eeb1f07939 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala @@ -101,9 +101,10 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.reduceByKey(func, numPartitions) def combineByKey[C](createCombiner: JFunction[V, C], - mergeValue: JFunction2[C, V, C], - mergeCombiners: JFunction2[C, C, C], - partitioner: Partitioner): JavaPairDStream[K, C] = { + mergeValue: JFunction2[C, V, C], + mergeCombiners: JFunction2[C, C, C], + partitioner: Partitioner + ): JavaPairDStream[K, C] = { implicit val cm: ClassManifest[C] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]] dstream.combineByKey(createCombiner, mergeValue, mergeCombiners, partitioner) @@ -117,18 +118,24 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( JavaPairDStream.scalaToJavaLong(dstream.countByKey()); } - def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration): JavaPairDStream[K, JList[V]] = { + def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration) + : JavaPairDStream[K, JList[V]] = { dstream.groupByKeyAndWindow(windowDuration, slideDuration).mapValues(seqAsJavaList _) } def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) :JavaPairDStream[K, JList[V]] = { - dstream.groupByKeyAndWindow(windowDuration, slideDuration, numPartitions).mapValues(seqAsJavaList _) + dstream.groupByKeyAndWindow(windowDuration, slideDuration, numPartitions) + .mapValues(seqAsJavaList _) } - def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, partitioner: Partitioner) - :JavaPairDStream[K, JList[V]] = { - dstream.groupByKeyAndWindow(windowDuration, slideDuration, partitioner).mapValues(seqAsJavaList _) + def groupByKeyAndWindow( + windowDuration: Duration, + slideDuration: Duration, + partitioner: Partitioner + ):JavaPairDStream[K, JList[V]] = { + dstream.groupByKeyAndWindow(windowDuration, slideDuration, partitioner) + .mapValues(seqAsJavaList _) } def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], windowDuration: Duration) @@ -136,46 +143,78 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.reduceByKeyAndWindow(reduceFunc, windowDuration) } - def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], windowDuration: Duration, slideDuration: Duration) - :JavaPairDStream[K, V] = { + def reduceByKeyAndWindow( + reduceFunc: Function2[V, V, V], + windowDuration: Duration, + slideDuration: Duration + ):JavaPairDStream[K, V] = { dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration) } def reduceByKeyAndWindow( - reduceFunc: Function2[V, V, V], - windowDuration: Duration, - slideDuration: Duration, - numPartitions: Int): JavaPairDStream[K, V] = { + reduceFunc: Function2[V, V, V], + windowDuration: Duration, + slideDuration: Duration, + numPartitions: Int + ): JavaPairDStream[K, V] = { dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, numPartitions) } - def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], windowDuration: Duration, slideDuration: Duration, - partitioner: Partitioner): JavaPairDStream[K, V] = { + def reduceByKeyAndWindow( + reduceFunc: Function2[V, V, V], + windowDuration: Duration, + slideDuration: Duration, + partitioner: Partitioner + ): JavaPairDStream[K, V] = { dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, partitioner) } - def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], invReduceFunc: Function2[V, V, V], - windowDuration: Duration, slideDuration: Duration): JavaPairDStream[K, V] = { + def reduceByKeyAndWindow( + reduceFunc: Function2[V, V, V], + invReduceFunc: Function2[V, V, V], + windowDuration: Duration, + slideDuration: Duration + ): JavaPairDStream[K, V] = { dstream.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration) } - def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], invReduceFunc: Function2[V, V, V], - windowDuration: Duration, slideDuration: Duration, numPartitions: Int): JavaPairDStream[K, V] = { - dstream.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration, numPartitions) + def reduceByKeyAndWindow( + reduceFunc: Function2[V, V, V], + invReduceFunc: Function2[V, V, V], + windowDuration: Duration, + slideDuration: Duration, + numPartitions: Int + ): JavaPairDStream[K, V] = { + dstream.reduceByKeyAndWindow( + reduceFunc, + invReduceFunc, + windowDuration, + slideDuration, + numPartitions) } - def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], invReduceFunc: Function2[V, V, V], - windowDuration: Duration, slideDuration: Duration, partitioner: Partitioner) - : JavaPairDStream[K, V] = { - dstream.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration, partitioner) + def reduceByKeyAndWindow( + reduceFunc: Function2[V, V, V], + invReduceFunc: Function2[V, V, V], + windowDuration: Duration, + slideDuration: Duration, + partitioner: Partitioner + ): JavaPairDStream[K, V] = { + dstream.reduceByKeyAndWindow( + reduceFunc, + invReduceFunc, + windowDuration, + slideDuration, + partitioner) } - def countByKeyAndWindow(windowDuration: Duration, slideDuration: Duration): JavaPairDStream[K, JLong] = { + def countByKeyAndWindow(windowDuration: Duration, slideDuration: Duration) + : JavaPairDStream[K, JLong] = { JavaPairDStream.scalaToJavaLong(dstream.countByKeyAndWindow(windowDuration, slideDuration)) } def countByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) - : JavaPairDStream[K, Long] = { + : JavaPairDStream[K, Long] = { dstream.countByKeyAndWindow(windowDuration, slideDuration, numPartitions) } @@ -225,21 +264,21 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } def saveAsHadoopFiles( - prefix: String, - suffix: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: OutputFormat[_, _]]) { + prefix: String, + suffix: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: OutputFormat[_, _]]) { dstream.saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass) } def saveAsHadoopFiles( - prefix: String, - suffix: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: OutputFormat[_, _]], - conf: JobConf) { + prefix: String, + suffix: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: OutputFormat[_, _]], + conf: JobConf) { dstream.saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf) } @@ -248,21 +287,21 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } def saveAsNewAPIHadoopFiles( - prefix: String, - suffix: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) { + prefix: String, + suffix: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) { dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass) } def saveAsNewAPIHadoopFiles( - prefix: String, - suffix: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: NewOutputFormat[_, _]], - conf: Configuration = new Configuration) { + prefix: String, + suffix: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: NewOutputFormat[_, _]], + conf: Configuration = new Configuration) { dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf) } -- cgit v1.2.3 From d182a57cae6455804773db23d9498d2dcdd02172 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 14 Jan 2013 10:03:55 -0800 Subject: Two changes: - Updating countByX() types based on bug fix - Porting new documentation to Java --- .../streaming/examples/JavaFlumeEventCount.java | 4 +- .../spark/streaming/api/java/JavaDStreamLike.scala | 10 +- .../spark/streaming/api/java/JavaPairDStream.scala | 264 +++++++++++++++++++++ 3 files changed, 271 insertions(+), 7 deletions(-) diff --git a/examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java b/examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java index 6592d9bc2e..151b71eb81 100644 --- a/examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java +++ b/examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java @@ -38,9 +38,9 @@ public class JavaFlumeEventCount { flumeStream.count(); - flumeStream.count().map(new Function() { + flumeStream.count().map(new Function() { @Override - public String call(Integer in) { + public String call(Long in) { return "Received " + in + " flume events."; } }).print(); diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala index 80d8865725..4257ecd583 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala @@ -1,7 +1,7 @@ package spark.streaming.api.java import java.util.{List => JList} -import java.lang.{Integer => JInt} +import java.lang.{Long => JLong} import scala.collection.JavaConversions._ @@ -17,8 +17,8 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable def dstream: DStream[T] - implicit def scalaIntToJavaInteger(in: DStream[Int]): JavaDStream[JInt] = { - in.map(new JInt(_)) + implicit def scalaIntToJavaLong(in: DStream[Long]): JavaDStream[JLong] = { + in.map(new JLong(_)) } /** @@ -31,14 +31,14 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable * Returns a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. */ - def count(): JavaDStream[JInt] = dstream.count() + def count(): JavaDStream[JLong] = dstream.count() /** * Returns a new DStream in which each RDD has a single element generated by counting the number * of elements in a window over this DStream. windowDuration and slideDuration are as defined in the * window() operation. This is equivalent to window(windowDuration, slideDuration).count() */ - def countByWindow(windowDuration: Duration, slideDuration: Duration) : JavaDStream[JInt] = { + def countByWindow(windowDuration: Duration, slideDuration: Duration) : JavaDStream[JLong] = { dstream.countByWindow(windowDuration, slideDuration) } diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala index eeb1f07939..c761fdd3bd 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala @@ -85,21 +85,66 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( // Methods only for PairDStream's // ======================================================================= + /** + * Creates a new DStream by applying `groupByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs are grouped into a + * single sequence to generate the RDDs of the new DStream. Hash partitioning is + * used to generate the RDDs with Spark's default number of partitions. + */ def groupByKey(): JavaPairDStream[K, JList[V]] = dstream.groupByKey().mapValues(seqAsJavaList _) + /** + * Creates a new DStream by applying `groupByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs are grouped into a + * single sequence to generate the RDDs of the new DStream. Hash partitioning is + * used to generate the RDDs with `numPartitions` partitions. + */ def groupByKey(numPartitions: Int): JavaPairDStream[K, JList[V]] = dstream.groupByKey(numPartitions).mapValues(seqAsJavaList _) + /** + * Creates a new DStream by applying `groupByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs are grouped into a + * single sequence to generate the RDDs of the new DStream. [[spark.Partitioner]] + * is used to control the partitioning of each RDD. + */ def groupByKey(partitioner: Partitioner): JavaPairDStream[K, JList[V]] = dstream.groupByKey(partitioner).mapValues(seqAsJavaList _) + /** + * Creates a new DStream by applying `reduceByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs is merged using the + * associative reduce function to generate the RDDs of the new DStream. + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + */ def reduceByKey(func: JFunction2[V, V, V]): JavaPairDStream[K, V] = dstream.reduceByKey(func) + /** + * Creates a new DStream by applying `reduceByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs is merged using the + * associative reduce function to generate the RDDs of the new DStream. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + */ def reduceByKey(func: JFunction2[V, V, V], numPartitions: Int): JavaPairDStream[K, V] = dstream.reduceByKey(func, numPartitions) + /** + * Creates a new DStream by applying `reduceByKey` on each RDD of `this` DStream. + * Therefore, the values for each key in `this` DStream's RDDs is merged using the + * associative reduce function to generate the RDDs of the new DStream. + * [[spark.Partitioner]] is used to control the partitioning of each RDD. + */ + def reduceByKey(func: JFunction2[V, V, V], partitioner: Partitioner): JavaPairDStream[K, V] = { + dstream.reduceByKey(func, partitioner) + } + + /** + * Generic function to combine elements of each key in DStream's RDDs using custom function. + * This is similar to the combineByKey for RDDs. Please refer to combineByKey in + * [[spark.PairRDDFunctions]] for more information. + */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], mergeCombiners: JFunction2[C, C, C], @@ -110,25 +155,78 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.combineByKey(createCombiner, mergeValue, mergeCombiners, partitioner) } + /** + * Creates a new DStream by counting the number of values of each key in each RDD + * of `this` DStream. Hash partitioning is used to generate the RDDs. + */ def countByKey(numPartitions: Int): JavaPairDStream[K, JLong] = { JavaPairDStream.scalaToJavaLong(dstream.countByKey(numPartitions)); } + + /** + * Creates a new DStream by counting the number of values of each key in each RDD + * of `this` DStream. Hash partitioning is used to generate the RDDs with Spark's + * `numPartitions` partitions. + */ def countByKey(): JavaPairDStream[K, JLong] = { JavaPairDStream.scalaToJavaLong(dstream.countByKey()); } + /** + * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.groupByKey()` but applies it over a sliding window. + * The new DStream generates RDDs with the same interval as this DStream. + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + */ + def groupByKeyAndWindow(windowDuration: Duration): JavaPairDStream[K, JList[V]] = { + dstream.groupByKeyAndWindow(windowDuration).mapValues(seqAsJavaList _) + } + + /** + * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.groupByKey()` but applies it over a sliding window. + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration) : JavaPairDStream[K, JList[V]] = { dstream.groupByKeyAndWindow(windowDuration, slideDuration).mapValues(seqAsJavaList _) } + /** + * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.groupByKey()` but applies it over a sliding window. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions Number of partitions of each RDD in the new DStream. + */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) :JavaPairDStream[K, JList[V]] = { dstream.groupByKeyAndWindow(windowDuration, slideDuration, numPartitions) .mapValues(seqAsJavaList _) } + /** + * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.groupByKey()` but applies it over a sliding window. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + */ def groupByKeyAndWindow( windowDuration: Duration, slideDuration: Duration, @@ -138,11 +236,31 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( .mapValues(seqAsJavaList _) } + /** + * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. + * The new DStream generates RDDs with the same interval as this DStream. + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + */ def reduceByKeyAndWindow(reduceFunc: Function2[V, V, V], windowDuration: Duration) :JavaPairDStream[K, V] = { dstream.reduceByKeyAndWindow(reduceFunc, windowDuration) } + /** + * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ def reduceByKeyAndWindow( reduceFunc: Function2[V, V, V], windowDuration: Duration, @@ -151,6 +269,18 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration) } + /** + * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions Number of partitions of each RDD in the new DStream. + */ def reduceByKeyAndWindow( reduceFunc: Function2[V, V, V], windowDuration: Duration, @@ -160,6 +290,17 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, numPartitions) } + /** + * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. + * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + */ def reduceByKeyAndWindow( reduceFunc: Function2[V, V, V], windowDuration: Duration, @@ -169,6 +310,24 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, partitioner) } + + /** + * Creates a new DStream by reducing over a window in a smarter way. + * The reduced value of over a new window is calculated incrementally by using the + * old window's reduce value : + * 1. reduce the new values that entered the window (e.g., adding new counts) + * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. + * However, it is applicable to only "invertible reduce functions". + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param reduceFunc associative reduce function + * @param invReduceFunc inverse function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ def reduceByKeyAndWindow( reduceFunc: Function2[V, V, V], invReduceFunc: Function2[V, V, V], @@ -178,6 +337,24 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration) } + /** + * Creates a new DStream by reducing over a window in a smarter way. + * The reduced value of over a new window is calculated incrementally by using the + * old window's reduce value : + * 1. reduce the new values that entered the window (e.g., adding new counts) + * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. + * However, it is applicable to only "invertible reduce functions". + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param reduceFunc associative reduce function + * @param invReduceFunc inverse function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions Number of partitions of each RDD in the new DStream. + */ def reduceByKeyAndWindow( reduceFunc: Function2[V, V, V], invReduceFunc: Function2[V, V, V], @@ -193,6 +370,23 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( numPartitions) } + /** + * Creates a new DStream by reducing over a window in a smarter way. + * The reduced value of over a new window is calculated incrementally by using the + * old window's reduce value : + * 1. reduce the new values that entered the window (e.g., adding new counts) + * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. + * However, it is applicable to only "invertible reduce functions". + * @param reduceFunc associative reduce function + * @param invReduceFunc inverse function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + */ def reduceByKeyAndWindow( reduceFunc: Function2[V, V, V], invReduceFunc: Function2[V, V, V], @@ -208,16 +402,38 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( partitioner) } + /** + * Creates a new DStream by counting the number of values for each key over a window. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ def countByKeyAndWindow(windowDuration: Duration, slideDuration: Duration) : JavaPairDStream[K, JLong] = { JavaPairDStream.scalaToJavaLong(dstream.countByKeyAndWindow(windowDuration, slideDuration)) } + /** + * Creates a new DStream by counting the number of values for each key over a window. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions Number of partitions of each RDD in the new DStream. + */ def countByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) : JavaPairDStream[K, Long] = { dstream.countByKeyAndWindow(windowDuration, slideDuration, numPartitions) } + + // TODO: Update State + def mapValues[U](f: JFunction[V, U]): JavaPairDStream[K, U] = { implicit val cm: ClassManifest[U] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] @@ -232,12 +448,26 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.flatMapValues(fn) } + /** + * Cogroups `this` DStream with `other` DStream. Each RDD of the new DStream will + * be generated by cogrouping RDDs from`this`and `other` DStreams. Therefore, for + * each key k in corresponding RDDs of `this` or `other` DStreams, the generated RDD + * will contains a tuple with the list of values for that key in both RDDs. + * HashPartitioner is used to partition each generated RDD into default number of partitions. + */ def cogroup[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (JList[V], JList[W])] = { implicit val cm: ClassManifest[W] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]] dstream.cogroup(other.dstream).mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2)))) } + /** + * Cogroups `this` DStream with `other` DStream. Each RDD of the new DStream will + * be generated by cogrouping RDDs from`this`and `other` DStreams. Therefore, for + * each key k in corresponding RDDs of `this` or `other` DStreams, the generated RDD + * will contains a tuple with the list of values for that key in both RDDs. + * Partitioner is used to partition each generated RDD. + */ def cogroup[W](other: JavaPairDStream[K, W], partitioner: Partitioner) : JavaPairDStream[K, (JList[V], JList[W])] = { implicit val cm: ClassManifest[W] = @@ -246,12 +476,22 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( .mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2)))) } + /** + * Joins `this` DStream with `other` DStream. Each RDD of the new DStream will + * be generated by joining RDDs from `this` and `other` DStreams. HashPartitioner is used + * to partition each generated RDD into default number of partitions. + */ def join[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (V, W)] = { implicit val cm: ClassManifest[W] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]] dstream.join(other.dstream) } + /** + * Joins `this` DStream with `other` DStream, that is, each RDD of the new DStream will + * be generated by joining RDDs from `this` and other DStream. Uses the given + * Partitioner to partition each generated RDD. + */ def join[W](other: JavaPairDStream[K, W], partitioner: Partitioner) : JavaPairDStream[K, (V, W)] = { implicit val cm: ClassManifest[W] = @@ -259,10 +499,18 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.join(other.dstream, partitioner) } + /** + * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated + * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix" + */ def saveAsHadoopFiles[F <: OutputFormat[K, V]](prefix: String, suffix: String) { dstream.saveAsHadoopFiles(prefix, suffix) } + /** + * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated + * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix" + */ def saveAsHadoopFiles( prefix: String, suffix: String, @@ -272,6 +520,10 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass) } + /** + * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated + * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix" + */ def saveAsHadoopFiles( prefix: String, suffix: String, @@ -282,10 +534,18 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf) } + /** + * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated + * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". + */ def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]](prefix: String, suffix: String) { dstream.saveAsNewAPIHadoopFiles(prefix, suffix) } + /** + * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated + * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". + */ def saveAsNewAPIHadoopFiles( prefix: String, suffix: String, @@ -295,6 +555,10 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass) } + /** + * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated + * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". + */ def saveAsNewAPIHadoopFiles( prefix: String, suffix: String, -- cgit v1.2.3 From 6069446356d1daf28054b87ff1a3bf724a22df03 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 14 Jan 2013 10:34:13 -0800 Subject: Making comments consistent w/ Spark style --- .../src/main/scala/spark/streaming/DStream.scala | 52 +++---- .../spark/streaming/PairDStreamFunctions.scala | 170 +++++++++------------ .../spark/streaming/api/java/JavaDStream.scala | 14 +- .../spark/streaming/api/java/JavaDStreamLike.scala | 26 ++-- .../spark/streaming/api/java/JavaPairDStream.scala | 166 +++++++++----------- 5 files changed, 196 insertions(+), 232 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index fbe3cebd6d..036763fe2f 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -98,10 +98,10 @@ abstract class DStream[T: ClassManifest] ( this } - /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ + /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ def persist(): DStream[T] = persist(StorageLevel.MEMORY_ONLY_SER) - /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ + /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ def cache(): DStream[T] = persist() /** @@ -119,7 +119,7 @@ abstract class DStream[T: ClassManifest] ( } /** - * This method initializes the DStream by setting the "zero" time, based on which + * Initialize the DStream by setting the "zero" time, based on which * the validity of future times is calculated. This method also recursively initializes * its parent DStreams. */ @@ -244,7 +244,7 @@ abstract class DStream[T: ClassManifest] ( } /** - * Retrieves a precomputed RDD of this DStream, or computes the RDD. This is an internal + * Retrieve a precomputed RDD of this DStream, or computes the RDD. This is an internal * method that should not be called directly. */ protected[streaming] def getOrCompute(time: Time): Option[RDD[T]] = { @@ -283,7 +283,7 @@ abstract class DStream[T: ClassManifest] ( } /** - * Generates a SparkStreaming job for the given time. This is an internal method that + * Generate a SparkStreaming job for the given time. This is an internal method that * should not be called directly. This default implementation creates a job * that materializes the corresponding RDD. Subclasses of DStream may override this * (eg. ForEachDStream). @@ -302,7 +302,7 @@ abstract class DStream[T: ClassManifest] ( } /** - * Dereferences RDDs that are older than rememberDuration. + * Dereference RDDs that are older than rememberDuration. */ protected[streaming] def forgetOldRDDs(time: Time) { val keys = generatedRDDs.keys @@ -328,7 +328,7 @@ abstract class DStream[T: ClassManifest] ( } /** - * Refreshes the list of checkpointed RDDs that will be saved along with checkpoint of + * Refresh the list of checkpointed RDDs that will be saved along with checkpoint of * this stream. This is an internal method that should not be called directly. This is * a default implementation that saves only the file names of the checkpointed RDDs to * checkpointData. Subclasses of DStream (especially those of InputDStream) may override @@ -373,7 +373,7 @@ abstract class DStream[T: ClassManifest] ( } /** - * Restores the RDDs in generatedRDDs from the checkpointData. This is an internal method + * Restore the RDDs in generatedRDDs from the checkpointData. This is an internal method * that should not be called directly. This is a default implementation that recreates RDDs * from the checkpoint file names stored in checkpointData. Subclasses of DStream that * override the updateCheckpointData() method would also need to override this method. @@ -425,20 +425,20 @@ abstract class DStream[T: ClassManifest] ( // DStream operations // ======================================================================= - /** Returns a new DStream by applying a function to all elements of this DStream. */ + /** Return a new DStream by applying a function to all elements of this DStream. */ def map[U: ClassManifest](mapFunc: T => U): DStream[U] = { new MappedDStream(this, ssc.sc.clean(mapFunc)) } /** - * Returns a new DStream by applying a function to all elements of this DStream, + * Return a new DStream by applying a function to all elements of this DStream, * and then flattening the results */ def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]): DStream[U] = { new FlatMappedDStream(this, ssc.sc.clean(flatMapFunc)) } - /** Returns a new DStream containing only the elements that satisfy a predicate. */ + /** Return a new DStream containing only the elements that satisfy a predicate. */ def filter(filterFunc: T => Boolean): DStream[T] = new FilteredDStream(this, filterFunc) /** @@ -461,20 +461,20 @@ abstract class DStream[T: ClassManifest] ( } /** - * Returns a new DStream in which each RDD has a single element generated by reducing each RDD + * Return a new DStream in which each RDD has a single element generated by reducing each RDD * of this DStream. */ def reduce(reduceFunc: (T, T) => T): DStream[T] = this.map(x => (null, x)).reduceByKey(reduceFunc, 1).map(_._2) /** - * Returns a new DStream in which each RDD has a single element generated by counting each RDD + * Return a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. */ def count(): DStream[Long] = this.map(_ => 1L).reduce(_ + _) /** - * Applies a function to each RDD in this DStream. This is an output operator, so + * Apply a function to each RDD in this DStream. This is an output operator, so * this DStream will be registered as an output stream and therefore materialized. */ def foreach(foreachFunc: RDD[T] => Unit) { @@ -482,7 +482,7 @@ abstract class DStream[T: ClassManifest] ( } /** - * Applies a function to each RDD in this DStream. This is an output operator, so + * Apply a function to each RDD in this DStream. This is an output operator, so * this DStream will be registered as an output stream and therefore materialized. */ def foreach(foreachFunc: (RDD[T], Time) => Unit) { @@ -492,7 +492,7 @@ abstract class DStream[T: ClassManifest] ( } /** - * Returns a new DStream in which each RDD is generated by applying a function + * Return a new DStream in which each RDD is generated by applying a function * on each RDD of this DStream. */ def transform[U: ClassManifest](transformFunc: RDD[T] => RDD[U]): DStream[U] = { @@ -500,7 +500,7 @@ abstract class DStream[T: ClassManifest] ( } /** - * Returns a new DStream in which each RDD is generated by applying a function + * Return a new DStream in which each RDD is generated by applying a function * on each RDD of this DStream. */ def transform[U: ClassManifest](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = { @@ -508,7 +508,7 @@ abstract class DStream[T: ClassManifest] ( } /** - * Prints the first ten elements of each RDD generated in this DStream. This is an output + * Print the first ten elements of each RDD generated in this DStream. This is an output * operator, so this DStream will be registered as an output stream and there materialized. */ def print() { @@ -545,7 +545,7 @@ abstract class DStream[T: ClassManifest] ( } /** - * Returns a new DStream which computed based on tumbling window on this DStream. + * Return a new DStream which computed based on tumbling window on this DStream. * This is equivalent to window(batchTime, batchTime). * @param batchDuration tumbling window duration; must be a multiple of this DStream's * batching interval @@ -553,7 +553,7 @@ abstract class DStream[T: ClassManifest] ( def tumble(batchDuration: Duration): DStream[T] = window(batchDuration, batchDuration) /** - * Returns a new DStream in which each RDD has a single element generated by reducing all + * Return a new DStream in which each RDD has a single element generated by reducing all * elements in a window over this DStream. windowDuration and slideDuration are as defined * in the window() operation. This is equivalent to * window(windowDuration, slideDuration).reduce(reduceFunc) @@ -578,7 +578,7 @@ abstract class DStream[T: ClassManifest] ( } /** - * Returns a new DStream in which each RDD has a single element generated by counting the number + * Return a new DStream in which each RDD has a single element generated by counting the number * of elements in a window over this DStream. windowDuration and slideDuration are as defined in the * window() operation. This is equivalent to window(windowDuration, slideDuration).count() */ @@ -587,20 +587,20 @@ abstract class DStream[T: ClassManifest] ( } /** - * Returns a new DStream by unifying data of another DStream with this DStream. + * Return a new DStream by unifying data of another DStream with this DStream. * @param that Another DStream having the same slideDuration as this DStream. */ def union(that: DStream[T]): DStream[T] = new UnionDStream[T](Array(this, that)) /** - * Returns all the RDDs defined by the Interval object (both end times included) + * Return all the RDDs defined by the Interval object (both end times included) */ protected[streaming] def slice(interval: Interval): Seq[RDD[T]] = { slice(interval.beginTime, interval.endTime) } /** - * Returns all the RDDs between 'fromTime' to 'toTime' (both included) + * Return all the RDDs between 'fromTime' to 'toTime' (both included) */ def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = { val rdds = new ArrayBuffer[RDD[T]]() @@ -616,7 +616,7 @@ abstract class DStream[T: ClassManifest] ( } /** - * Saves each RDD in this DStream as a Sequence file of serialized objects. + * Save each RDD in this DStream as a Sequence file of serialized objects. * The file name at each batch interval is generated based on `prefix` and * `suffix`: "prefix-TIME_IN_MS.suffix". */ @@ -629,7 +629,7 @@ abstract class DStream[T: ClassManifest] ( } /** - * Saves each RDD in this DStream as at text file, using string representation + * Save each RDD in this DStream as at text file, using string representation * of elements. The file name at each batch interval is generated based on * `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index 3952457339..f63279512b 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -26,29 +26,23 @@ extends Serializable { } /** - * Creates a new DStream by applying `groupByKey` on each RDD of `this` DStream. - * Therefore, the values for each key in `this` DStream's RDDs are grouped into a - * single sequence to generate the RDDs of the new DStream. Hash partitioning is - * used to generate the RDDs with Spark's default number of partitions. + * Create a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to + * generate the RDDs with Spark's default number of partitions. */ def groupByKey(): DStream[(K, Seq[V])] = { groupByKey(defaultPartitioner()) } /** - * Creates a new DStream by applying `groupByKey` on each RDD of `this` DStream. - * Therefore, the values for each key in `this` DStream's RDDs are grouped into a - * single sequence to generate the RDDs of the new DStream. Hash partitioning is - * used to generate the RDDs with `numPartitions` partitions. + * Create a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to + * generate the RDDs with `numPartitions` partitions. */ def groupByKey(numPartitions: Int): DStream[(K, Seq[V])] = { groupByKey(defaultPartitioner(numPartitions)) } /** - * Creates a new DStream by applying `groupByKey` on each RDD of `this` DStream. - * Therefore, the values for each key in `this` DStream's RDDs are grouped into a - * single sequence to generate the RDDs of the new DStream. [[spark.Partitioner]] + * Create a new DStream by applying `groupByKey` on each RDD. The supplied [[spark.Partitioner]] * is used to control the partitioning of each RDD. */ def groupByKey(partitioner: Partitioner): DStream[(K, Seq[V])] = { @@ -60,30 +54,27 @@ extends Serializable { } /** - * Creates a new DStream by applying `reduceByKey` on each RDD of `this` DStream. - * Therefore, the values for each key in `this` DStream's RDDs is merged using the - * associative reduce function to generate the RDDs of the new DStream. - * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * merged using the associative reduce function. Hash partitioning is used to generate the RDDs + * with Spark's default number of partitions. */ def reduceByKey(reduceFunc: (V, V) => V): DStream[(K, V)] = { reduceByKey(reduceFunc, defaultPartitioner()) } /** - * Creates a new DStream by applying `reduceByKey` on each RDD of `this` DStream. - * Therefore, the values for each key in `this` DStream's RDDs is merged using the - * associative reduce function to generate the RDDs of the new DStream. - * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * merged using the supplied reduce function. Hash partitioning is used to generate the RDDs + * with `numPartitions` partitions. */ def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int): DStream[(K, V)] = { reduceByKey(reduceFunc, defaultPartitioner(numPartitions)) } /** - * Creates a new DStream by applying `reduceByKey` on each RDD of `this` DStream. - * Therefore, the values for each key in `this` DStream's RDDs is merged using the - * associative reduce function to generate the RDDs of the new DStream. - * [[spark.Partitioner]] is used to control the partitioning of each RDD. + * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * merged using the supplied reduce function. [[spark.Partitioner]] is used to control the + * partitioning of each RDD. */ def reduceByKey(reduceFunc: (V, V) => V, partitioner: Partitioner): DStream[(K, V)] = { val cleanedReduceFunc = ssc.sc.clean(reduceFunc) @@ -91,9 +82,9 @@ extends Serializable { } /** - * Generic function to combine elements of each key in DStream's RDDs using custom function. - * This is similar to the combineByKey for RDDs. Please refer to combineByKey in - * [[spark.PairRDDFunctions]] for more information. + * Combine elements of each key in DStream's RDDs using custom function. This is similar to the + * combineByKey for RDDs. Please refer to combineByKey in [[spark.PairRDDFunctions]] for more + * information. */ def combineByKey[C: ClassManifest]( createCombiner: V => C, @@ -104,19 +95,18 @@ extends Serializable { } /** - * Creates a new DStream by counting the number of values of each key in each RDD - * of `this` DStream. Hash partitioning is used to generate the RDDs with Spark's - * `numPartitions` partitions. + * Create a new DStream by counting the number of values of each key in each RDD. Hash + * partitioning is used to generate the RDDs with Spark's `numPartitions` partitions. */ def countByKey(numPartitions: Int = self.ssc.sc.defaultParallelism): DStream[(K, Long)] = { self.map(x => (x._1, 1L)).reduceByKey((x: Long, y: Long) => x + y, numPartitions) } /** - * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. - * This is similar to `DStream.groupByKey()` but applies it over a sliding window. - * The new DStream generates RDDs with the same interval as this DStream. - * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * Creates a new DStream by applying `groupByKey` over a sliding window. This is similar to + * `DStream.groupByKey()` but applies it over a sliding window. The new DStream generates RDDs + * with the same interval as this DStream. Hash partitioning is used to generate the RDDs with + * Spark's default number of partitions. * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval */ @@ -125,9 +115,9 @@ extends Serializable { } /** - * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. - * This is similar to `DStream.groupByKey()` but applies it over a sliding window. - * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * Create a new DStream by applying `groupByKey` over a sliding window. Similar to + * `DStream.groupByKey()`, but applies it over a sliding window. Hash partitioning is used to + * generate the RDDs with Spark's default number of partitions. * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -139,8 +129,8 @@ extends Serializable { } /** - * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. - * This is similar to `DStream.groupByKey()` but applies it over a sliding window. + * Create a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * Similar to `DStream.groupByKey()`, but applies it over a sliding window. * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -158,8 +148,8 @@ extends Serializable { } /** - * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. - * This is similar to `DStream.groupByKey()` but applies it over a sliding window. + * Create a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * Similar to `DStream.groupByKey()`, but applies it over a sliding window. * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -176,10 +166,10 @@ extends Serializable { } /** - * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. - * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. - * The new DStream generates RDDs with the same interval as this DStream. - * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * Create a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. + * Similar to `DStream.reduceByKey()`, but applies it over a sliding window. The new DStream + * generates RDDs with the same interval as this DStream. Hash partitioning is used to generate + * the RDDs with Spark's default number of partitions. * @param reduceFunc associative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -192,9 +182,9 @@ extends Serializable { } /** - * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. - * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. - * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * Create a new DStream by applying `reduceByKey` over a sliding window. This is similar to + * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to + * generate the RDDs with Spark's default number of partitions. * @param reduceFunc associative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -211,9 +201,9 @@ extends Serializable { } /** - * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. - * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. - * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * Create a new DStream by applying `reduceByKey` over a sliding window. This is similar to + * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to + * generate the RDDs with `numPartitions` partitions. * @param reduceFunc associative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -232,8 +222,8 @@ extends Serializable { } /** - * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. - * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. + * Create a new DStream by applying `reduceByKey` over a sliding window. Similar to + * `DStream.reduceByKey()`, but applies it over a sliding window. * @param reduceFunc associative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -255,9 +245,8 @@ extends Serializable { } /** - * Creates a new DStream by reducing over a window in a smarter way. - * The reduced value of over a new window is calculated incrementally by using the - * old window's reduce value : + * Create a new DStream by reducing over a using incremental computation. + * The reduced value of over a new window is calculated using the old window's reduce value : * 1. reduce the new values that entered the window (e.g., adding new counts) * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. @@ -283,9 +272,8 @@ extends Serializable { } /** - * Creates a new DStream by reducing over a window in a smarter way. - * The reduced value of over a new window is calculated incrementally by using the - * old window's reduce value : + * Create a new DStream by reducing over a using incremental computation. + * The reduced value of over a new window is calculated using the old window's reduce value : * 1. reduce the new values that entered the window (e.g., adding new counts) * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. @@ -313,9 +301,8 @@ extends Serializable { } /** - * Creates a new DStream by reducing over a window in a smarter way. - * The reduced value of over a new window is calculated incrementally by using the - * old window's reduce value : + * Create a new DStream by reducing over a using incremental computation. + * The reduced value of over a new window is calculated using the old window's reduce value : * 1. reduce the new values that entered the window (e.g., adding new counts) * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. @@ -344,7 +331,7 @@ extends Serializable { } /** - * Creates a new DStream by counting the number of values for each key over a window. + * Create a new DStream by counting the number of values for each key over a window. * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -369,10 +356,9 @@ extends Serializable { } /** - * Creates a new "state" DStream where the state for each key is updated by applying - * the given function on the previous state of the key and the new values of the key from - * `this` DStream. Hash partitioning is used to generate the RDDs with Spark's default - * number of partitions. + * Create a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of each key. + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. * @param updateFunc State update function. If `this` function returns None, then * corresponding state key-value pair will be eliminated. * @tparam S State type @@ -384,9 +370,9 @@ extends Serializable { } /** - * Creates a new "state" DStream where the state for each key is updated by applying - * the given function on the previous state of the key and the new values of the key from - * `this` DStream. Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * Create a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of each key. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. * @param updateFunc State update function. If `this` function returns None, then * corresponding state key-value pair will be eliminated. * @param numPartitions Number of partitions of each RDD in the new DStream. @@ -400,9 +386,9 @@ extends Serializable { } /** - * Creates a new "state" DStream where the state for each key is updated by applying - * the given function on the previous state of the key and the new values of the key from - * `this` DStream. [[spark.Partitioner]] is used to control the partitioning of each RDD. + * Create a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of the key. + * [[spark.Partitioner]] is used to control the partitioning of each RDD. * @param updateFunc State update function. If `this` function returns None, then * corresponding state key-value pair will be eliminated. * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. @@ -419,9 +405,9 @@ extends Serializable { } /** - * Creates a new "state" DStream where the state for each key is updated by applying - * the given function on the previous state of the key and the new values of the key from - * `this` DStream. [[spark.Partitioner]] is used to control the partitioning of each RDD. + * Create a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of each key. + * [[spark.Paxrtitioner]] is used to control the partitioning of each RDD. * @param updateFunc State update function. If `this` function returns None, then * corresponding state key-value pair will be eliminated. Note, that * this function may generate a different a tuple with a different key @@ -451,22 +437,19 @@ extends Serializable { } /** - * Cogroups `this` DStream with `other` DStream. Each RDD of the new DStream will - * be generated by cogrouping RDDs from`this`and `other` DStreams. Therefore, for - * each key k in corresponding RDDs of `this` or `other` DStreams, the generated RDD - * will contains a tuple with the list of values for that key in both RDDs. - * HashPartitioner is used to partition each generated RDD into default number of partitions. + * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this` + * or `other` DStreams, the generated RDD will contains a tuple with the list of values for that + * key in both RDDs. HashPartitioner is used to partition each generated RDD into default number + * of partitions. */ def cogroup[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (Seq[V], Seq[W]))] = { cogroup(other, defaultPartitioner()) } /** - * Cogroups `this` DStream with `other` DStream. Each RDD of the new DStream will - * be generated by cogrouping RDDs from`this`and `other` DStreams. Therefore, for - * each key k in corresponding RDDs of `this` or `other` DStreams, the generated RDD - * will contains a tuple with the list of values for that key in both RDDs. - * Partitioner is used to partition each generated RDD. + * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this` + * or `other` DStreams, the generated RDD will contains a tuple with the list of values for that + * key in both RDDs. Partitioner is used to partition each generated RDD. */ def cogroup[W: ClassManifest]( other: DStream[(K, W)], @@ -488,8 +471,7 @@ extends Serializable { } /** - * Joins `this` DStream with `other` DStream. Each RDD of the new DStream will - * be generated by joining RDDs from `this` and `other` DStreams. HashPartitioner is used + * Join `this` DStream with `other` DStream. HashPartitioner is used * to partition each generated RDD into default number of partitions. */ def join[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (V, W))] = { @@ -497,7 +479,7 @@ extends Serializable { } /** - * Joins `this` DStream with `other` DStream, that is, each RDD of the new DStream will + * Join `this` DStream with `other` DStream, that is, each RDD of the new DStream will * be generated by joining RDDs from `this` and other DStream. Uses the given * Partitioner to partition each generated RDD. */ @@ -513,7 +495,7 @@ extends Serializable { } /** - * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated + * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix" */ def saveAsHadoopFiles[F <: OutputFormat[K, V]]( @@ -524,7 +506,7 @@ extends Serializable { } /** - * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated + * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix" */ def saveAsHadoopFiles( @@ -543,8 +525,8 @@ extends Serializable { } /** - * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated - * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". + * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is + * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]]( prefix: String, @@ -554,8 +536,8 @@ extends Serializable { } /** - * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated - * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". + * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is + * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ def saveAsNewAPIHadoopFiles( prefix: String, diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala index f85864df5d..32faef5670 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala @@ -9,20 +9,20 @@ import spark.storage.StorageLevel class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassManifest[T]) extends JavaDStreamLike[T, JavaDStream[T]] { - /** Returns a new DStream containing only the elements that satisfy a predicate. */ + /** Return a new DStream containing only the elements that satisfy a predicate. */ def filter(f: JFunction[T, java.lang.Boolean]): JavaDStream[T] = dstream.filter((x => f(x).booleanValue())) - /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ + /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ def cache(): JavaDStream[T] = dstream.cache() - /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ + /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ def persist(): JavaDStream[T] = dstream.cache() - /** Persists the RDDs of this DStream with the given storage level */ + /** Persist the RDDs of this DStream with the given storage level */ def persist(storageLevel: StorageLevel): JavaDStream[T] = dstream.persist(storageLevel) - /** Method that generates a RDD for the given duration */ + /** Generate an RDD for the given duration */ def compute(validTime: Time): JavaRDD[T] = { dstream.compute(validTime) match { case Some(rdd) => new JavaRDD(rdd) @@ -51,7 +51,7 @@ class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassM dstream.window(windowDuration, slideDuration) /** - * Returns a new DStream which computed based on tumbling window on this DStream. + * Return a new DStream which computed based on tumbling window on this DStream. * This is equivalent to window(batchDuration, batchDuration). * @param batchDuration tumbling window duration; must be a multiple of this DStream's interval */ @@ -59,7 +59,7 @@ class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassM dstream.tumble(batchDuration) /** - * Returns a new DStream by unifying data of another DStream with this DStream. + * Return a new DStream by unifying data of another DStream with this DStream. * @param that Another DStream having the same interval (i.e., slideDuration) as this DStream. */ def union(that: JavaDStream[T]): JavaDStream[T] = diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala index 4257ecd583..32df665a98 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala @@ -22,19 +22,19 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable } /** - * Prints the first ten elements of each RDD generated in this DStream. This is an output + * Print the first ten elements of each RDD generated in this DStream. This is an output * operator, so this DStream will be registered as an output stream and there materialized. */ def print() = dstream.print() /** - * Returns a new DStream in which each RDD has a single element generated by counting each RDD + * Return a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. */ def count(): JavaDStream[JLong] = dstream.count() /** - * Returns a new DStream in which each RDD has a single element generated by counting the number + * Return a new DStream in which each RDD has a single element generated by counting the number * of elements in a window over this DStream. windowDuration and slideDuration are as defined in the * window() operation. This is equivalent to window(windowDuration, slideDuration).count() */ @@ -50,15 +50,15 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable def glom(): JavaDStream[JList[T]] = new JavaDStream(dstream.glom().map(x => new java.util.ArrayList[T](x.toSeq))) - /** Returns the StreamingContext associated with this DStream */ + /** Return the StreamingContext associated with this DStream */ def context(): StreamingContext = dstream.context() - /** Returns a new DStream by applying a function to all elements of this DStream. */ + /** Return a new DStream by applying a function to all elements of this DStream. */ def map[R](f: JFunction[T, R]): JavaDStream[R] = { new JavaDStream(dstream.map(f)(f.returnType()))(f.returnType()) } - /** Returns a new DStream by applying a function to all elements of this DStream. */ + /** Return a new DStream by applying a function to all elements of this DStream. */ def map[K, V](f: PairFunction[T, K, V]): JavaPairDStream[K, V] = { def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]] new JavaPairDStream(dstream.map(f)(cm))(f.keyType(), f.valueType()) @@ -86,13 +86,13 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable } /** - * Returns a new DStream in which each RDD has a single element generated by reducing each RDD + * Return a new DStream in which each RDD has a single element generated by reducing each RDD * of this DStream. */ def reduce(f: JFunction2[T, T, T]): JavaDStream[T] = dstream.reduce(f) /** - * Returns a new DStream in which each RDD has a single element generated by reducing all + * Return a new DStream in which each RDD has a single element generated by reducing all * elements in a window over this DStream. windowDuration and slideDuration are as defined in the * window() operation. This is equivalent to window(windowDuration, slideDuration).reduce(reduceFunc) */ @@ -106,14 +106,14 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable } /** - * Returns all the RDDs between 'fromDuration' to 'toDuration' (both included) + * Return all the RDDs between 'fromDuration' to 'toDuration' (both included) */ def slice(fromDuration: Duration, toDuration: Duration): JList[JavaRDD[T]] = { new util.ArrayList(dstream.slice(fromDuration, toDuration).map(new JavaRDD(_)).toSeq) } /** - * Applies a function to each RDD in this DStream. This is an output operator, so + * Apply a function to each RDD in this DStream. This is an output operator, so * this DStream will be registered as an output stream and therefore materialized. */ def foreach(foreachFunc: JFunction[JavaRDD[T], Void]) { @@ -121,7 +121,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable } /** - * Applies a function to each RDD in this DStream. This is an output operator, so + * Apply a function to each RDD in this DStream. This is an output operator, so * this DStream will be registered as an output stream and therefore materialized. */ def foreach(foreachFunc: JFunction2[JavaRDD[T], Time, Void]) { @@ -129,7 +129,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable } /** - * Returns a new DStream in which each RDD is generated by applying a function + * Return a new DStream in which each RDD is generated by applying a function * on each RDD of this DStream. */ def transform[U](transformFunc: JFunction[JavaRDD[T], JavaRDD[U]]): JavaDStream[U] = { @@ -141,7 +141,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable } /** - * Returns a new DStream in which each RDD is generated by applying a function + * Return a new DStream in which each RDD is generated by applying a function * on each RDD of this DStream. */ def transform[U](transformFunc: JFunction2[JavaRDD[T], Time, JavaRDD[U]]): JavaDStream[U] = { diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala index c761fdd3bd..16b476ec90 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala @@ -86,19 +86,15 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( // ======================================================================= /** - * Creates a new DStream by applying `groupByKey` on each RDD of `this` DStream. - * Therefore, the values for each key in `this` DStream's RDDs are grouped into a - * single sequence to generate the RDDs of the new DStream. Hash partitioning is - * used to generate the RDDs with Spark's default number of partitions. + * Create a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to + * generate the RDDs with Spark's default number of partitions. */ def groupByKey(): JavaPairDStream[K, JList[V]] = dstream.groupByKey().mapValues(seqAsJavaList _) /** - * Creates a new DStream by applying `groupByKey` on each RDD of `this` DStream. - * Therefore, the values for each key in `this` DStream's RDDs are grouped into a - * single sequence to generate the RDDs of the new DStream. Hash partitioning is - * used to generate the RDDs with `numPartitions` partitions. + * Create a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to + * generate the RDDs with `numPartitions` partitions. */ def groupByKey(numPartitions: Int): JavaPairDStream[K, JList[V]] = dstream.groupByKey(numPartitions).mapValues(seqAsJavaList _) @@ -113,37 +109,34 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.groupByKey(partitioner).mapValues(seqAsJavaList _) /** - * Creates a new DStream by applying `reduceByKey` on each RDD of `this` DStream. - * Therefore, the values for each key in `this` DStream's RDDs is merged using the - * associative reduce function to generate the RDDs of the new DStream. - * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * merged using the associative reduce function. Hash partitioning is used to generate the RDDs + * with Spark's default number of partitions. */ def reduceByKey(func: JFunction2[V, V, V]): JavaPairDStream[K, V] = dstream.reduceByKey(func) /** - * Creates a new DStream by applying `reduceByKey` on each RDD of `this` DStream. - * Therefore, the values for each key in `this` DStream's RDDs is merged using the - * associative reduce function to generate the RDDs of the new DStream. - * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * merged using the supplied reduce function. Hash partitioning is used to generate the RDDs + * with `numPartitions` partitions. */ def reduceByKey(func: JFunction2[V, V, V], numPartitions: Int): JavaPairDStream[K, V] = dstream.reduceByKey(func, numPartitions) /** - * Creates a new DStream by applying `reduceByKey` on each RDD of `this` DStream. - * Therefore, the values for each key in `this` DStream's RDDs is merged using the - * associative reduce function to generate the RDDs of the new DStream. - * [[spark.Partitioner]] is used to control the partitioning of each RDD. + * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * merged using the supplied reduce function. [[spark.Partitioner]] is used to control the + * partitioning of each RDD. */ def reduceByKey(func: JFunction2[V, V, V], partitioner: Partitioner): JavaPairDStream[K, V] = { dstream.reduceByKey(func, partitioner) } /** - * Generic function to combine elements of each key in DStream's RDDs using custom function. - * This is similar to the combineByKey for RDDs. Please refer to combineByKey in - * [[spark.PairRDDFunctions]] for more information. + * Combine elements of each key in DStream's RDDs using custom function. This is similar to the + * combineByKey for RDDs. Please refer to combineByKey in [[spark.PairRDDFunctions]] for more + * information. */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], @@ -156,8 +149,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Creates a new DStream by counting the number of values of each key in each RDD - * of `this` DStream. Hash partitioning is used to generate the RDDs. + * Create a new DStream by counting the number of values of each key in each RDD. Hash + * partitioning is used to generate the RDDs with Spark's `numPartitions` partitions. */ def countByKey(numPartitions: Int): JavaPairDStream[K, JLong] = { JavaPairDStream.scalaToJavaLong(dstream.countByKey(numPartitions)); @@ -165,19 +158,18 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( /** - * Creates a new DStream by counting the number of values of each key in each RDD - * of `this` DStream. Hash partitioning is used to generate the RDDs with Spark's - * `numPartitions` partitions. + * Create a new DStream by counting the number of values of each key in each RDD. Hash + * partitioning is used to generate the RDDs with the default number of partitions. */ def countByKey(): JavaPairDStream[K, JLong] = { JavaPairDStream.scalaToJavaLong(dstream.countByKey()); } /** - * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. - * This is similar to `DStream.groupByKey()` but applies it over a sliding window. - * The new DStream generates RDDs with the same interval as this DStream. - * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * Creates a new DStream by applying `groupByKey` over a sliding window. This is similar to + * `DStream.groupByKey()` but applies it over a sliding window. The new DStream generates RDDs + * with the same interval as this DStream. Hash partitioning is used to generate the RDDs with + * Spark's default number of partitions. * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval */ @@ -186,9 +178,9 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. - * This is similar to `DStream.groupByKey()` but applies it over a sliding window. - * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * Create a new DStream by applying `groupByKey` over a sliding window. Similar to + * `DStream.groupByKey()`, but applies it over a sliding window. Hash partitioning is used to + * generate the RDDs with Spark's default number of partitions. * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -201,8 +193,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. - * This is similar to `DStream.groupByKey()` but applies it over a sliding window. + * Create a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * Similar to `DStream.groupByKey()`, but applies it over a sliding window. * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -218,8 +210,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Creates a new DStream by applying `groupByKey` over a sliding window on `this` DStream. - * This is similar to `DStream.groupByKey()` but applies it over a sliding window. + * Create a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * Similar to `DStream.groupByKey()`, but applies it over a sliding window. * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -237,10 +229,10 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. - * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. - * The new DStream generates RDDs with the same interval as this DStream. - * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * Create a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. + * Similar to `DStream.reduceByKey()`, but applies it over a sliding window. The new DStream + * generates RDDs with the same interval as this DStream. Hash partitioning is used to generate + * the RDDs with Spark's default number of partitions. * @param reduceFunc associative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -251,9 +243,9 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. - * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. - * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * Create a new DStream by applying `reduceByKey` over a sliding window. This is similar to + * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to + * generate the RDDs with Spark's default number of partitions. * @param reduceFunc associative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -270,9 +262,9 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. - * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. - * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * Create a new DStream by applying `reduceByKey` over a sliding window. This is similar to + * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to + * generate the RDDs with `numPartitions` partitions. * @param reduceFunc associative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -291,8 +283,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Creates a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. - * This is similar to `DStream.reduceByKey()` but applies it over a sliding window. + * Create a new DStream by applying `reduceByKey` over a sliding window. Similar to + * `DStream.reduceByKey()`, but applies it over a sliding window. * @param reduceFunc associative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -310,11 +302,9 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, partitioner) } - /** - * Creates a new DStream by reducing over a window in a smarter way. - * The reduced value of over a new window is calculated incrementally by using the - * old window's reduce value : + * Create a new DStream by reducing over a using incremental computation. + * The reduced value of over a new window is calculated using the old window's reduce value : * 1. reduce the new values that entered the window (e.g., adding new counts) * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. @@ -338,9 +328,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Creates a new DStream by reducing over a window in a smarter way. - * The reduced value of over a new window is calculated incrementally by using the - * old window's reduce value : + * Create a new DStream by reducing over a using incremental computation. + * The reduced value of over a new window is calculated using the old window's reduce value : * 1. reduce the new values that entered the window (e.g., adding new counts) * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. @@ -371,9 +360,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Creates a new DStream by reducing over a window in a smarter way. - * The reduced value of over a new window is calculated incrementally by using the - * old window's reduce value : + * Create a new DStream by reducing over a using incremental computation. + * The reduced value of over a new window is calculated using the old window's reduce value : * 1. reduce the new values that entered the window (e.g., adding new counts) * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. @@ -403,7 +391,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Creates a new DStream by counting the number of values for each key over a window. + * Create a new DStream by counting the number of values for each key over a window. * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -417,7 +405,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Creates a new DStream by counting the number of values for each key over a window. + * Create a new DStream by counting the number of values for each key over a window. * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -449,24 +437,19 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Cogroups `this` DStream with `other` DStream. Each RDD of the new DStream will - * be generated by cogrouping RDDs from`this`and `other` DStreams. Therefore, for - * each key k in corresponding RDDs of `this` or `other` DStreams, the generated RDD - * will contains a tuple with the list of values for that key in both RDDs. - * HashPartitioner is used to partition each generated RDD into default number of partitions. + * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this` + * or `other` DStreams, the generated RDD will contains a tuple with the list of values for that + * key in both RDDs. HashPartitioner is used to partition each generated RDD into default number + * of partitions. */ - def cogroup[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (JList[V], JList[W])] = { - implicit val cm: ClassManifest[W] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]] - dstream.cogroup(other.dstream).mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2)))) + def cogroup[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (Seq[V], Seq[W]))] = { + cogroup(other, defaultPartitioner()) } /** - * Cogroups `this` DStream with `other` DStream. Each RDD of the new DStream will - * be generated by cogrouping RDDs from`this`and `other` DStreams. Therefore, for - * each key k in corresponding RDDs of `this` or `other` DStreams, the generated RDD - * will contains a tuple with the list of values for that key in both RDDs. - * Partitioner is used to partition each generated RDD. + * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this` + * or `other` DStreams, the generated RDD will contains a tuple with the list of values for that + * key in both RDDs. Partitioner is used to partition each generated RDD. */ def cogroup[W](other: JavaPairDStream[K, W], partitioner: Partitioner) : JavaPairDStream[K, (JList[V], JList[W])] = { @@ -477,8 +460,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Joins `this` DStream with `other` DStream. Each RDD of the new DStream will - * be generated by joining RDDs from `this` and `other` DStreams. HashPartitioner is used + * Join `this` DStream with `other` DStream. HashPartitioner is used * to partition each generated RDD into default number of partitions. */ def join[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (V, W)] = { @@ -488,7 +470,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Joins `this` DStream with `other` DStream, that is, each RDD of the new DStream will + * Join `this` DStream with `other` DStream, that is, each RDD of the new DStream will * be generated by joining RDDs from `this` and other DStream. Uses the given * Partitioner to partition each generated RDD. */ @@ -500,16 +482,16 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated - * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix" + * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is + * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ def saveAsHadoopFiles[F <: OutputFormat[K, V]](prefix: String, suffix: String) { dstream.saveAsHadoopFiles(prefix, suffix) } /** - * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated - * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix" + * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is + * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ def saveAsHadoopFiles( prefix: String, @@ -521,8 +503,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated - * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix" + * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is + * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ def saveAsHadoopFiles( prefix: String, @@ -535,16 +517,16 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated - * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". + * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is + * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]](prefix: String, suffix: String) { dstream.saveAsNewAPIHadoopFiles(prefix, suffix) } /** - * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated - * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". + * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is + * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ def saveAsNewAPIHadoopFiles( prefix: String, @@ -556,8 +538,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Saves each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated - * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". + * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is + * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ def saveAsNewAPIHadoopFiles( prefix: String, -- cgit v1.2.3 From ae5290f4a2fbeb51f5dc6e7add38f9c012ab7311 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 14 Jan 2013 13:31:32 -0800 Subject: Bug fix --- .../src/main/scala/spark/streaming/api/java/JavaPairDStream.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala index 16b476ec90..6f4336a011 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala @@ -419,7 +419,6 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.countByKeyAndWindow(windowDuration, slideDuration, numPartitions) } - // TODO: Update State def mapValues[U](f: JFunction[V, U]): JavaPairDStream[K, U] = { @@ -443,7 +442,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * of partitions. */ def cogroup[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (Seq[V], Seq[W]))] = { - cogroup(other, defaultPartitioner()) + dstream.cogroup(other) } /** -- cgit v1.2.3 From 38d9a3a8630a38aa0cb9e6a13256816cfa9ab5a6 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 14 Jan 2013 13:31:49 -0800 Subject: Remove AnyRef constraint in updateState --- .../src/main/scala/spark/streaming/PairDStreamFunctions.scala | 8 ++++---- .../src/main/scala/spark/streaming/dstream/StateDStream.scala | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index f63279512b..fbcf061126 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -363,7 +363,7 @@ extends Serializable { * corresponding state key-value pair will be eliminated. * @tparam S State type */ - def updateStateByKey[S <: AnyRef : ClassManifest]( + def updateStateByKey[S: ClassManifest]( updateFunc: (Seq[V], Option[S]) => Option[S] ): DStream[(K, S)] = { updateStateByKey(updateFunc, defaultPartitioner()) @@ -378,7 +378,7 @@ extends Serializable { * @param numPartitions Number of partitions of each RDD in the new DStream. * @tparam S State type */ - def updateStateByKey[S <: AnyRef : ClassManifest]( + def updateStateByKey[S: ClassManifest]( updateFunc: (Seq[V], Option[S]) => Option[S], numPartitions: Int ): DStream[(K, S)] = { @@ -394,7 +394,7 @@ extends Serializable { * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. * @tparam S State type */ - def updateStateByKey[S <: AnyRef : ClassManifest]( + def updateStateByKey[S: ClassManifest]( updateFunc: (Seq[V], Option[S]) => Option[S], partitioner: Partitioner ): DStream[(K, S)] = { @@ -417,7 +417,7 @@ extends Serializable { * @param rememberPartitioner Whether to remember the paritioner object in the generated RDDs. * @tparam S State type */ - def updateStateByKey[S <: AnyRef : ClassManifest]( + def updateStateByKey[S: ClassManifest]( updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], partitioner: Partitioner, rememberPartitioner: Boolean diff --git a/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala index a1ec2f5454..b4506c74aa 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala @@ -7,7 +7,7 @@ import spark.storage.StorageLevel import spark.streaming.{Duration, Time, DStream} private[streaming] -class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( +class StateDStream[K: ClassManifest, V: ClassManifest, S: ClassManifest]( parent: DStream[(K, V)], updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], partitioner: Partitioner, -- cgit v1.2.3 From 8ad6220bd376b04084604cf49b4537c97a16257d Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 14 Jan 2013 14:54:47 -0800 Subject: Bugfix --- .../src/main/scala/spark/streaming/api/java/JavaPairDStream.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala index 6f4336a011..0cccb083c5 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala @@ -441,8 +441,10 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * key in both RDDs. HashPartitioner is used to partition each generated RDD into default number * of partitions. */ - def cogroup[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (Seq[V], Seq[W]))] = { - dstream.cogroup(other) + def cogroup[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (JList[V], JList[W])] = { + implicit val cm: ClassManifest[W] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]] + dstream.cogroup(other.dstream).mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2)))) } /** -- cgit v1.2.3 From a0013beb039c8569e9e69f96fce0c341d1a1d180 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 14 Jan 2013 15:15:02 -0800 Subject: Stash --- .../spark/streaming/api/java/JavaPairDStream.scala | 29 ++++++++++++++++++- .../test/scala/spark/streaming/JavaAPISuite.java | 33 ++++++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala index 0cccb083c5..49a0f27b5b 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala @@ -15,6 +15,7 @@ import org.apache.hadoop.conf.Configuration import spark.api.java.JavaPairRDD import spark.storage.StorageLevel import java.lang +import com.google.common.base.Optional class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( implicit val kManifiest: ClassManifest[K], @@ -419,7 +420,33 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.countByKeyAndWindow(windowDuration, slideDuration, numPartitions) } - // TODO: Update State + /** + * Create a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of each key. + * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. + * @param updateFunc State update function. If `this` function returns None, then + * corresponding state key-value pair will be eliminated. + * @tparam S State type + */ + def updateStateByKey[S](updateFunc: JFunction2[JList[V], Optional[S], Optional[S]]) + : JavaPairDStream[K, S] = { + implicit val cm: ClassManifest[S] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[S]] + + def scalaFunc(values: Seq[V], state: Option[S]): Option[S] = { + val list: JList[V] = values + val scalaState: Optional[S] = state match { + case Some(s) => Optional.of(s) + case _ => Optional.absent() + } + val result: Optional[S] = updateFunc.apply(list, scalaState) + result.isPresent match { + case true => Some(result.get()) + case _ => None + } + } + dstream.updateStateByKey(scalaFunc _) + } def mapValues[U](f: JFunction[V, U]): JavaPairDStream[K, U] = { implicit val cm: ClassManifest[U] = diff --git a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java index 7475b9536b..d95ab485f8 100644 --- a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java @@ -1,5 +1,6 @@ package spark.streaming; +import com.google.common.base.Optional; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; @@ -551,6 +552,38 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(expected, result); } + @Test + public void testUpdateStateByKey() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList(new Tuple2("california", 14), + new Tuple2("new york", 9)), + Arrays.asList(new Tuple2("california", 10), + new Tuple2("new york", 4))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream updated = pairStream.updateStateByKey( + new Function2, Optional, Optional>(){ + @Override + public Optional call(List values, Optional state) { + int out = 0; + for (Integer v: values) { + out = out + v; + } + return Optional.of(out); + } + }); + JavaTestUtils.attachTestOutputStream(updated); + List>> result = JavaTestUtils.runStreams(sc, 3, 3); + + Assert.assertEquals(expected, result); + } + @Test public void testReduceByKeyAndWindowWithInverse() { List>> inputData = stringIntKVStream; -- cgit v1.2.3 From 273fb5cc109ac0a032f84c1566ae908cd0eb27b6 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Thu, 3 Jan 2013 14:09:56 -0800 Subject: Throw FetchFailedException for cached missing locs --- core/src/main/scala/spark/MapOutputTracker.scala | 36 +++++++++++++++++------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 70eb9f702e..9f2aa76830 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -139,8 +139,8 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea case e: InterruptedException => } } - return mapStatuses.get(shuffleId).map(status => - (status.address, MapOutputTracker.decompressSize(status.compressedSizes(reduceId)))) + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, + mapStatuses.get(shuffleId)) } else { fetching += shuffleId } @@ -156,21 +156,15 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea fetchedStatuses = deserializeStatuses(fetchedBytes) logInfo("Got the output locations") mapStatuses.put(shuffleId, fetchedStatuses) - if (fetchedStatuses.contains(null)) { - throw new FetchFailedException(null, shuffleId, -1, reduceId, - new Exception("Missing an output location for shuffle " + shuffleId)) - } } finally { fetching.synchronized { fetching -= shuffleId fetching.notifyAll() } } - return fetchedStatuses.map(s => - (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId)))) + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) } else { - return statuses.map(s => - (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId)))) + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) } } @@ -258,6 +252,28 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea private[spark] object MapOutputTracker { private val LOG_BASE = 1.1 + // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If + // any of the statuses is null (indicating a missing location due to a failed mapper), + // throw a FetchFailedException. + def convertMapStatuses( + shuffleId: Int, + reduceId: Int, + statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { + if (statuses == null) { + throw new FetchFailedException(null, shuffleId, -1, reduceId, + new Exception("Missing all output locations for shuffle " + shuffleId)) + } + statuses.map { + status => + if (status == null) { + throw new FetchFailedException(null, shuffleId, -1, reduceId, + new Exception("Missing an output location for shuffle " + shuffleId)) + } else { + (status.address, decompressSize(status.compressedSizes(reduceId))) + } + } + } + /** * Compress a size in bytes to 8 bits for efficient reporting of map output sizes. * We do this by encoding the log base 1.1 of the size as an integer, which can support -- cgit v1.2.3 From 7ba34bc007ec10d12b2a871749f32232cdbc0d9c Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Mon, 14 Jan 2013 15:24:08 -0800 Subject: Additional tests for MapOutputTracker. --- .../test/scala/spark/MapOutputTrackerSuite.scala | 82 +++++++++++++++++++++- 1 file changed, 80 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index 5b4b198960..6c6f82e274 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -1,12 +1,18 @@ package spark import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter import akka.actor._ import spark.scheduler.MapStatus import spark.storage.BlockManagerId +import spark.util.AkkaUtils -class MapOutputTrackerSuite extends FunSuite { +class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { + after { + System.clearProperty("spark.master.port") + } + test("compressSize") { assert(MapOutputTracker.compressSize(0L) === 0) assert(MapOutputTracker.compressSize(1L) === 1) @@ -71,6 +77,78 @@ class MapOutputTrackerSuite extends FunSuite { // The remaining reduce task might try to grab the output dispite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the // stage already being aborted. - intercept[Exception] { tracker.getServerStatuses(10, 1) } + intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } + } + + test("remote fetch") { + val (actorSystem, boundPort) = + AkkaUtils.createActorSystem("test", "localhost", 0) + System.setProperty("spark.master.port", boundPort.toString) + val masterTracker = new MapOutputTracker(actorSystem, true) + val slaveTracker = new MapOutputTracker(actorSystem, false) + masterTracker.registerShuffle(10, 1) + masterTracker.incrementGeneration() + slaveTracker.updateGeneration(masterTracker.getGeneration) + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + masterTracker.registerMapOutput(10, 0, new MapStatus( + new BlockManagerId("hostA", 1000), Array(compressedSize1000))) + masterTracker.incrementGeneration() + slaveTracker.updateGeneration(masterTracker.getGeneration) + assert(slaveTracker.getServerStatuses(10, 0).toSeq === + Seq((new BlockManagerId("hostA", 1000), size1000))) + + masterTracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000)) + masterTracker.incrementGeneration() + slaveTracker.updateGeneration(masterTracker.getGeneration) + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + } + + test("simulatenous fetch fails") { + val dummyActorSystem = ActorSystem("testDummy") + val dummyTracker = new MapOutputTracker(dummyActorSystem, true) + dummyTracker.registerShuffle(10, 1) + // val compressedSize1000 = MapOutputTracker.compressSize(1000L) + // val size100 = MapOutputTracker.decompressSize(compressedSize1000) + // dummyTracker.registerMapOutput(10, 0, new MapStatus( + // new BlockManagerId("hostA", 1000), Array(compressedSize1000))) + val serializedMessage = dummyTracker.getSerializedLocations(10) + + val (actorSystem, boundPort) = + AkkaUtils.createActorSystem("test", "localhost", 0) + System.setProperty("spark.master.port", boundPort.toString) + val delayResponseLock = new java.lang.Object + val delayResponseActor = actorSystem.actorOf(Props(new Actor { + override def receive = { + case GetMapOutputStatuses(shuffleId: Int, requester: String) => + delayResponseLock.synchronized { + sender ! serializedMessage + } + } + }), name = "MapOutputTracker") + val slaveTracker = new MapOutputTracker(actorSystem, false) + var firstFailed = false + var secondFailed = false + val firstFetch = new Thread { + override def run() { + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + firstFailed = true + } + } + val secondFetch = new Thread { + override def run() { + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + secondFailed = true + } + } + delayResponseLock.synchronized { + firstFetch.start + secondFetch.start + } + firstFetch.join + secondFetch.join + assert(firstFailed && secondFailed) } } -- cgit v1.2.3 From b0389997972d383c3aaa87924b725dee70b18d8e Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Mon, 14 Jan 2013 17:04:44 -0800 Subject: Fix accidental spark.master.host reuse --- core/src/test/scala/spark/MapOutputTrackerSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index 6c6f82e274..aa1d8ac7e6 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -81,6 +81,7 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { } test("remote fetch") { + System.clearProperty("spark.master.host") val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0) System.setProperty("spark.master.port", boundPort.toString) @@ -107,6 +108,7 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { } test("simulatenous fetch fails") { + System.clearProperty("spark.master.host") val dummyActorSystem = ActorSystem("testDummy") val dummyTracker = new MapOutputTracker(dummyActorSystem, true) dummyTracker.registerShuffle(10, 1) -- cgit v1.2.3 From 1638fcb0dce296da22ffc90127d5148a8fab745e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 14 Jan 2013 17:18:39 -0800 Subject: Fixed updateStateByKey to work with primitive types. --- .../src/main/scala/spark/streaming/PairDStreamFunctions.scala | 8 ++++---- .../src/main/scala/spark/streaming/dstream/StateDStream.scala | 2 +- .../src/test/scala/spark/streaming/BasicOperationsSuite.scala | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index 3952457339..3dbef69868 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -377,7 +377,7 @@ extends Serializable { * corresponding state key-value pair will be eliminated. * @tparam S State type */ - def updateStateByKey[S <: AnyRef : ClassManifest]( + def updateStateByKey[S: ClassManifest]( updateFunc: (Seq[V], Option[S]) => Option[S] ): DStream[(K, S)] = { updateStateByKey(updateFunc, defaultPartitioner()) @@ -392,7 +392,7 @@ extends Serializable { * @param numPartitions Number of partitions of each RDD in the new DStream. * @tparam S State type */ - def updateStateByKey[S <: AnyRef : ClassManifest]( + def updateStateByKey[S: ClassManifest]( updateFunc: (Seq[V], Option[S]) => Option[S], numPartitions: Int ): DStream[(K, S)] = { @@ -408,7 +408,7 @@ extends Serializable { * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. * @tparam S State type */ - def updateStateByKey[S <: AnyRef : ClassManifest]( + def updateStateByKey[S: ClassManifest]( updateFunc: (Seq[V], Option[S]) => Option[S], partitioner: Partitioner ): DStream[(K, S)] = { @@ -431,7 +431,7 @@ extends Serializable { * @param rememberPartitioner Whether to remember the paritioner object in the generated RDDs. * @tparam S State type */ - def updateStateByKey[S <: AnyRef : ClassManifest]( + def updateStateByKey[S: ClassManifest]( updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], partitioner: Partitioner, rememberPartitioner: Boolean diff --git a/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala index a1ec2f5454..b4506c74aa 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala @@ -7,7 +7,7 @@ import spark.storage.StorageLevel import spark.streaming.{Duration, Time, DStream} private[streaming] -class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( +class StateDStream[K: ClassManifest, V: ClassManifest, S: ClassManifest]( parent: DStream[(K, V)], updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], partitioner: Partitioner, diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala index f9e03c607d..f73f9b1823 100644 --- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala @@ -151,10 +151,10 @@ class BasicOperationsSuite extends TestSuiteBase { ) val updateStateOperation = (s: DStream[String]) => { - val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { - Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) + val updateFunc = (values: Seq[Int], state: Option[Int]) => { + Some(values.foldLeft(0)(_ + _) + state.getOrElse(0)) } - s.map(x => (x, 1)).updateStateByKey[RichInt](updateFunc).map(t => (t._1, t._2.self)) + s.map(x => (x, 1)).updateStateByKey[Int](updateFunc) } testOperation(inputData, updateStateOperation, outputData, true) -- cgit v1.2.3 From b77f7390a5a18c2b88fbc0c276c4dbc938560127 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Tue, 15 Jan 2013 09:04:32 +0200 Subject: Python ALS example --- python/examples/als.py | 71 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100755 python/examples/als.py diff --git a/python/examples/als.py b/python/examples/als.py new file mode 100755 index 0000000000..284cf0d3a2 --- /dev/null +++ b/python/examples/als.py @@ -0,0 +1,71 @@ +""" +This example requires numpy (http://www.numpy.org/) +""" +from os.path import realpath +import sys + +import numpy as np +from numpy.random import rand +from numpy import matrix +from pyspark import SparkContext + +LAMBDA = 0.01 # regularization +np.random.seed(42) + +def rmse(R, ms, us): + diff = R - ms * us.T + return np.sqrt(np.sum(np.power(diff, 2)) / M * U) + +def update(i, vec, mat, ratings): + uu = mat.shape[0] + ff = mat.shape[1] + XtX = matrix(np.zeros((ff, ff))) + Xty = np.zeros((ff, 1)) + + for j in range(uu): + v = mat[j, :] + XtX += v.T * v + Xty += v.T * ratings[i, j] + XtX += np.eye(ff, ff) * LAMBDA * uu + return np.linalg.solve(XtX, Xty) + +if __name__ == "__main__": + if len(sys.argv) < 2: + print >> sys.stderr, \ + "Usage: PythonALS " + exit(-1) + sc = SparkContext(sys.argv[1], "PythonALS", pyFiles=[realpath(__file__)]) + M = int(sys.argv[2]) if len(sys.argv) > 2 else 100 + U = int(sys.argv[3]) if len(sys.argv) > 3 else 500 + F = int(sys.argv[4]) if len(sys.argv) > 4 else 10 + ITERATIONS = int(sys.argv[5]) if len(sys.argv) > 5 else 5 + slices = int(sys.argv[6]) if len(sys.argv) > 6 else 2 + + print "Running ALS with M=%d, U=%d, F=%d, iters=%d, slices=%d\n" % \ + (M, U, F, ITERATIONS, slices) + + R = matrix(rand(M, F)) * matrix(rand(U, F).T) + ms = matrix(rand(M ,F)) + us = matrix(rand(U, F)) + + Rb = sc.broadcast(R) + msb = sc.broadcast(ms) + usb = sc.broadcast(us) + + for i in range(ITERATIONS): + ms = sc.parallelize(range(M), slices) \ + .map(lambda x: update(x, msb.value[x, :], usb.value, Rb.value)) \ + .collect() + ms = matrix(np.array(ms)[:, :, 0]) # collect() returns a list, so array ends up being + # a 3-d array, we take the first 2 dims for the matrix + msb = sc.broadcast(ms) + + us = sc.parallelize(range(U), slices) \ + .map(lambda x: update(x, usb.value[x, :], msb.value, Rb.value.T)) \ + .collect() + us = matrix(np.array(us)[:, :, 0]) + usb = sc.broadcast(us) + + error = rmse(R, ms, us) + print "Iteration %d:" % i + print "\nRMSE: %5.4f\n" % error \ No newline at end of file -- cgit v1.2.3 From dd583b7ebf0e6620ec8e35424b59db451febe3e8 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 15 Jan 2013 10:52:06 -0600 Subject: Call executeOnCompleteCallbacks in a finally block. --- core/src/main/scala/spark/scheduler/ResultTask.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala index e492279b4e..2aad7956b4 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -15,9 +15,11 @@ private[spark] class ResultTask[T, U]( override def run(attemptId: Long): U = { val context = new TaskContext(stageId, partition, attemptId) - val result = func(context, rdd.iterator(split, context)) - context.executeOnCompleteCallbacks() - result + try { + func(context, rdd.iterator(split, context)) + } finally { + context.executeOnCompleteCallbacks() + } } override def preferredLocations: Seq[String] = locs -- cgit v1.2.3 From d228bff440395e8e6b8d67483467dde65b08ab40 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 15 Jan 2013 11:48:50 -0600 Subject: Add a test. --- .../scala/spark/scheduler/TaskContextSuite.scala | 43 ++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 core/src/test/scala/spark/scheduler/TaskContextSuite.scala diff --git a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala new file mode 100644 index 0000000000..f937877340 --- /dev/null +++ b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala @@ -0,0 +1,43 @@ +package spark.scheduler + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter +import spark.TaskContext +import spark.RDD +import spark.SparkContext +import spark.Split + +class TaskContextSuite extends FunSuite with BeforeAndAfter { + + var sc: SparkContext = _ + + after { + if (sc != null) { + sc.stop() + sc = null + } + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") + } + + test("Calls executeOnCompleteCallbacks after failure") { + var completed = false + sc = new SparkContext("local", "test") + val rdd = new RDD[String](sc) { + override val splits = Array[Split](StubSplit(0)) + override val dependencies = List() + override def compute(split: Split, context: TaskContext) = { + context.addOnCompleteCallback(() => completed = true) + sys.error("failed") + } + } + val func = (c: TaskContext, i: Iterator[String]) => i.next + val task = new ResultTask[String, String](0, rdd, func, 0, Seq(), 0) + intercept[RuntimeException] { + task.run(0) + } + assert(completed === true) + } + + case class StubSplit(val index: Int) extends Split +} \ No newline at end of file -- cgit v1.2.3 From c7143e5507f1d5292e678315158d3863c9bb4242 Mon Sep 17 00:00:00 2001 From: Andrew Psaltis Date: Tue, 15 Jan 2013 12:45:42 -0700 Subject: Changed teh scala version to 2.9.2, so that the classes can be found when the classpath is expanded. --- run2.cmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run2.cmd b/run2.cmd index 83464b1166..67f1e465e4 100644 --- a/run2.cmd +++ b/run2.cmd @@ -1,6 +1,6 @@ @echo off -set SCALA_VERSION=2.9.1 +set SCALA_VERSION=2.9.2 rem Figure out where the Spark framework is installed set FWDIR=%~dp0 -- cgit v1.2.3 From 4078623b9f2a338d4992c3dfd3af3a5550615180 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 15 Jan 2013 12:05:54 -0800 Subject: Remove broken attempt to test fetching case. --- .../test/scala/spark/MapOutputTrackerSuite.scala | 48 +--------------------- 1 file changed, 2 insertions(+), 46 deletions(-) diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index aa1d8ac7e6..d3dd3a8fa4 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -105,52 +105,8 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { masterTracker.incrementGeneration() slaveTracker.updateGeneration(masterTracker.getGeneration) intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - } - - test("simulatenous fetch fails") { - System.clearProperty("spark.master.host") - val dummyActorSystem = ActorSystem("testDummy") - val dummyTracker = new MapOutputTracker(dummyActorSystem, true) - dummyTracker.registerShuffle(10, 1) - // val compressedSize1000 = MapOutputTracker.compressSize(1000L) - // val size100 = MapOutputTracker.decompressSize(compressedSize1000) - // dummyTracker.registerMapOutput(10, 0, new MapStatus( - // new BlockManagerId("hostA", 1000), Array(compressedSize1000))) - val serializedMessage = dummyTracker.getSerializedLocations(10) - val (actorSystem, boundPort) = - AkkaUtils.createActorSystem("test", "localhost", 0) - System.setProperty("spark.master.port", boundPort.toString) - val delayResponseLock = new java.lang.Object - val delayResponseActor = actorSystem.actorOf(Props(new Actor { - override def receive = { - case GetMapOutputStatuses(shuffleId: Int, requester: String) => - delayResponseLock.synchronized { - sender ! serializedMessage - } - } - }), name = "MapOutputTracker") - val slaveTracker = new MapOutputTracker(actorSystem, false) - var firstFailed = false - var secondFailed = false - val firstFetch = new Thread { - override def run() { - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - firstFailed = true - } - } - val secondFetch = new Thread { - override def run() { - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - secondFailed = true - } - } - delayResponseLock.synchronized { - firstFetch.start - secondFetch.start - } - firstFetch.join - secondFetch.join - assert(firstFailed && secondFailed) + // failure should be cached + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } } } -- cgit v1.2.3 From a805ac4a7cdd520b6141dd885c780c526bb54ba6 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 16 Jan 2013 10:55:26 -0800 Subject: Disabled checkpoint for PairwiseRDD (pySpark). --- core/src/main/scala/spark/api/python/PythonRDD.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 276035a9ad..0138b22d38 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -138,6 +138,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends case Seq(a, b) => (a, b) case x => throw new Exception("PairwiseRDD: unexpected value: " + x) } + override def checkpoint() { } val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) } -- cgit v1.2.3 From eae698f755f41fd8bdff94c498df314ed74aa3c1 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 16 Jan 2013 12:21:37 -0800 Subject: remove unused thread pool --- core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index 915f71ba9f..a29bf974d2 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -24,9 +24,6 @@ private[spark] class StandaloneExecutorBackend( with ExecutorBackend with Logging { - val threadPool = new ThreadPoolExecutor( - 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable]) - var master: ActorRef = null override def preStart() { -- cgit v1.2.3 From 42fbef3c2a6460bcd389bb86306be3ebc14c998b Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Thu, 17 Jan 2013 15:54:59 +0200 Subject: Adding default command line args to SparkALS --- .../src/main/scala/spark/examples/SparkALS.scala | 27 ++++++++++++++-------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/examples/src/main/scala/spark/examples/SparkALS.scala b/examples/src/main/scala/spark/examples/SparkALS.scala index fb28e2c932..cbd749666d 100644 --- a/examples/src/main/scala/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/spark/examples/SparkALS.scala @@ -7,6 +7,7 @@ import cern.jet.math._ import cern.colt.matrix._ import cern.colt.matrix.linalg._ import spark._ +import scala.Option object SparkALS { // Parameters set through command line arguments @@ -97,21 +98,27 @@ object SparkALS { def main(args: Array[String]) { var host = "" var slices = 0 - args match { - case Array(m, u, f, iters, slices_, host_) => { - M = m.toInt - U = u.toInt - F = f.toInt - ITERATIONS = iters.toInt - slices = slices_.toInt - host = host_ + + (1 to 6).map(i => { + i match { + case a if a < args.length => Option(args(a)) + case _ => Option(null) + } + }).toArray match { + case Array(host_, m, u, f, iters, slices_) => { + host = host_ getOrElse "local" + M = (m getOrElse "100").toInt + U = (u getOrElse "500").toInt + F = (f getOrElse "10").toInt + ITERATIONS = (iters getOrElse "5").toInt + slices = (slices_ getOrElse "2").toInt } case _ => { - System.err.println("Usage: SparkALS ") + System.err.println("Usage: SparkALS [ ]") System.exit(1) } } - printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS); + printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS) val spark = new SparkContext(host, "SparkALS") val R = generateR() -- cgit v1.2.3 From a512df551f85086a6ec363744542e74749c6b560 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Thu, 17 Jan 2013 16:05:27 +0200 Subject: Fixed index error missing first argument --- examples/src/main/scala/spark/examples/SparkALS.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/scala/spark/examples/SparkALS.scala b/examples/src/main/scala/spark/examples/SparkALS.scala index cbd749666d..4672812565 100644 --- a/examples/src/main/scala/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/spark/examples/SparkALS.scala @@ -99,7 +99,7 @@ object SparkALS { var host = "" var slices = 0 - (1 to 6).map(i => { + (0 to 5).map(i => { i match { case a if a < args.length => Option(args(a)) case _ => Option(null) -- cgit v1.2.3 From a5ba7a9f322dce763350864bf89d94e6656d9984 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Thu, 17 Jan 2013 16:21:00 +0200 Subject: Use only one update function and pass in transpose of ratings matrix where appropriate --- .../src/main/scala/spark/examples/SparkALS.scala | 32 ++-------------------- 1 file changed, 3 insertions(+), 29 deletions(-) diff --git a/examples/src/main/scala/spark/examples/SparkALS.scala b/examples/src/main/scala/spark/examples/SparkALS.scala index 4672812565..2766ad1702 100644 --- a/examples/src/main/scala/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/spark/examples/SparkALS.scala @@ -43,7 +43,7 @@ object SparkALS { return sqrt(sumSqs / (M * U)) } - def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], + def update(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], R: DoubleMatrix2D) : DoubleMatrix1D = { val U = us.size @@ -69,32 +69,6 @@ object SparkALS { return solved2D.viewColumn(0) } - def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D], - R: DoubleMatrix2D) : DoubleMatrix1D = - { - val M = ms.size - val F = ms(0).size - val XtX = factory2D.make(F, F) - val Xty = factory1D.make(F) - // For each movie that the user rated - for (i <- 0 until M) { - val m = ms(i) - // Add m * m^t to XtX - blas.dger(1, m, m, XtX) - // Add m * rating to Xty - blas.daxpy(R.get(i, j), m, Xty) - } - // Add regularization coefs to diagonal terms - for (d <- 0 until F) { - XtX.set(d, d, XtX.get(d, d) + LAMBDA * M) - } - // Solve it with Cholesky - val ch = new CholeskyDecomposition(XtX) - val Xty2D = factory2D.make(Xty.toArray, F) - val solved2D = ch.solve(Xty2D) - return solved2D.viewColumn(0) - } - def main(args: Array[String]) { var host = "" var slices = 0 @@ -134,11 +108,11 @@ object SparkALS { for (iter <- 1 to ITERATIONS) { println("Iteration " + iter + ":") ms = spark.parallelize(0 until M, slices) - .map(i => updateMovie(i, msc.value(i), usc.value, Rc.value)) + .map(i => update(i, msc.value(i), usc.value, Rc.value)) .toArray msc = spark.broadcast(ms) // Re-broadcast ms because it was updated us = spark.parallelize(0 until U, slices) - .map(i => updateUser(i, usc.value(i), msc.value, Rc.value)) + .map(i => update(i, usc.value(i), msc.value, algebra.transpose(Rc.value))) .toArray usc = spark.broadcast(us) // Re-broadcast us because it was updated println("RMSE = " + rmse(R, ms, us)) -- cgit v1.2.3 From 892c32a14b89139b7bd89e141fc90b148a67ce68 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 17 Jan 2013 11:14:47 -0800 Subject: Warn users if they run pyspark or spark-shell without compiling Spark --- pyspark | 7 +++++++ run | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/pyspark b/pyspark index 9e89d51ba2..ab7f4f50c0 100755 --- a/pyspark +++ b/pyspark @@ -6,6 +6,13 @@ FWDIR="$(cd `dirname $0`; pwd)" # Export this as SPARK_HOME export SPARK_HOME="$FWDIR" +# Exit if the user hasn't compiled Spark +if [ ! -e "$SPARK_HOME/repl/target" ]; then + echo "Failed to find Spark classes in $SPARK_HOME/repl/target" >&2 + echo "You need to compile Spark before running this program" >&2 + exit 1 +fi + # Load environment variables from conf/spark-env.sh, if it exists if [ -e $FWDIR/conf/spark-env.sh ] ; then . $FWDIR/conf/spark-env.sh diff --git a/run b/run index ca23455386..eb93db66db 100755 --- a/run +++ b/run @@ -65,6 +65,13 @@ EXAMPLES_DIR="$FWDIR/examples" BAGEL_DIR="$FWDIR/bagel" PYSPARK_DIR="$FWDIR/python" +# Exit if the user hasn't compiled Spark +if [ ! -e "$REPL_DIR/target" ]; then + echo "Failed to find Spark classes in $REPL_DIR/target" >&2 + echo "You need to compile Spark before running this program" >&2 + exit 1 +fi + # Build up classpath CLASSPATH="$SPARK_CLASSPATH" CLASSPATH+=":$FWDIR/conf" -- cgit v1.2.3 From 742bc841adb2a57b05e7a155681a162ab9dfa2c1 Mon Sep 17 00:00:00 2001 From: Fernand Pajot Date: Thu, 17 Jan 2013 16:56:11 -0800 Subject: changed HttpBroadcast server cache to be in spark.local.dir instead of java.io.tmpdir --- core/src/main/scala/spark/broadcast/HttpBroadcast.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala index 7eb4ddb74f..96dc28f12a 100644 --- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala @@ -89,7 +89,7 @@ private object HttpBroadcast extends Logging { } private def createServer() { - broadcastDir = Utils.createTempDir() + broadcastDir = Utils.createTempDir(System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) server = new HttpServer(broadcastDir) server.start() serverUri = server.uri -- cgit v1.2.3 From 54c0f9f185576e9b844fa8f81ca410f188daa51c Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 17 Jan 2013 17:40:55 -0800 Subject: Fix code that assumed spark.local.dir is only a single directory --- core/src/main/scala/spark/Utils.scala | 11 ++++++++++- core/src/main/scala/spark/broadcast/HttpBroadcast.scala | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 0e7007459d..aeed5d2f32 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -134,7 +134,7 @@ private object Utils extends Logging { */ def fetchFile(url: String, targetDir: File) { val filename = url.split("/").last - val tempDir = System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")) + val tempDir = getLocalDir val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir)) val targetFile = new File(targetDir, filename) val uri = new URI(url) @@ -204,6 +204,15 @@ private object Utils extends Logging { FileUtil.chmod(filename, "a+x") } + /** + * Get a temporary directory using Spark's spark.local.dir property, if set. This will always + * return a single directory, even though the spark.local.dir property might be a list of + * multiple paths. + */ + def getLocalDir: String = { + System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0) + } + /** * Shuffle the elements of a collection into a random order, returning the * result in a new collection. Unlike scala.util.Random.shuffle, this method diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala index 96dc28f12a..856a4683a9 100644 --- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala @@ -89,7 +89,7 @@ private object HttpBroadcast extends Logging { } private def createServer() { - broadcastDir = Utils.createTempDir(System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) + broadcastDir = Utils.createTempDir(Utils.getLocalDir) server = new HttpServer(broadcastDir) server.start() serverUri = server.uri -- cgit v1.2.3 From 2a872335c5c7b5481c927272447e4a344ef59dda Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 16 Jan 2013 13:17:45 -0800 Subject: Bug fix and test cleanup --- .../spark/streaming/api/java/JavaPairDStream.scala | 4 ++-- .../src/test/scala/spark/streaming/JavaAPISuite.java | 17 ++++++++++------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala index 49a0f27b5b..1c5b864ff0 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala @@ -433,7 +433,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( implicit val cm: ClassManifest[S] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[S]] - def scalaFunc(values: Seq[V], state: Option[S]): Option[S] = { + val scalaFunc: (Seq[V], Option[S]) => Option[S] = (values, state) => { val list: JList[V] = values val scalaState: Optional[S] = state match { case Some(s) => Optional.of(s) @@ -445,7 +445,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( case _ => None } } - dstream.updateStateByKey(scalaFunc _) + dstream.updateStateByKey(scalaFunc) } def mapValues[U](f: JFunction[V, U]): JavaPairDStream[K, U] = { diff --git a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java index d95ab485f8..549fb5b733 100644 --- a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java @@ -52,15 +52,15 @@ public class JavaAPISuite implements Serializable { Arrays.asList(3,4,5), Arrays.asList(3)); - List> expected = Arrays.asList( - Arrays.asList(4), - Arrays.asList(3), - Arrays.asList(1)); + List> expected = Arrays.asList( + Arrays.asList(4L), + Arrays.asList(3L), + Arrays.asList(1L)); JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); JavaDStream count = stream.count(); JavaTestUtils.attachTestOutputStream(count); - List> result = JavaTestUtils.runStreams(sc, 3, 3); + List> result = JavaTestUtils.runStreams(sc, 3, 3); assertOrderInvariantEquals(expected, result); } @@ -561,8 +561,8 @@ public class JavaAPISuite implements Serializable { new Tuple2("new york", 5)), Arrays.asList(new Tuple2("california", 14), new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 10), - new Tuple2("new york", 4))); + Arrays.asList(new Tuple2("california", 14), + new Tuple2("new york", 9))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -572,6 +572,9 @@ public class JavaAPISuite implements Serializable { @Override public Optional call(List values, Optional state) { int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } for (Integer v: values) { out = out + v; } -- cgit v1.2.3 From 8e6cbbc6c7434b53c63e19a1c9c2dca1f24de654 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 16 Jan 2013 13:50:02 -0800 Subject: Adding other updateState functions --- .../spark/streaming/api/java/JavaPairDStream.scala | 62 +++++++++++++++++----- 1 file changed, 49 insertions(+), 13 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala index 1c5b864ff0..8c76d8c1d8 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala @@ -420,6 +420,23 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.countByKeyAndWindow(windowDuration, slideDuration, numPartitions) } + private def convertUpdateStateFunction[S](in: JFunction2[JList[V], Optional[S], Optional[S]]): + (Seq[V], Option[S]) => Option[S] = { + val scalaFunc: (Seq[V], Option[S]) => Option[S] = (values, state) => { + val list: JList[V] = values + val scalaState: Optional[S] = state match { + case Some(s) => Optional.of(s) + case _ => Optional.absent() + } + val result: Optional[S] = in.apply(list, scalaState) + result.isPresent match { + case true => Some(result.get()) + case _ => None + } + } + scalaFunc + } + /** * Create a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. @@ -432,20 +449,39 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( : JavaPairDStream[K, S] = { implicit val cm: ClassManifest[S] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[S]] + dstream.updateStateByKey(convertUpdateStateFunction(updateFunc)) + } - val scalaFunc: (Seq[V], Option[S]) => Option[S] = (values, state) => { - val list: JList[V] = values - val scalaState: Optional[S] = state match { - case Some(s) => Optional.of(s) - case _ => Optional.absent() - } - val result: Optional[S] = updateFunc.apply(list, scalaState) - result.isPresent match { - case true => Some(result.get()) - case _ => None - } - } - dstream.updateStateByKey(scalaFunc) + /** + * Create a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of each key. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param updateFunc State update function. If `this` function returns None, then + * corresponding state key-value pair will be eliminated. + * @param numPartitions Number of partitions of each RDD in the new DStream. + * @tparam S State type + */ + def updateStateByKey[S: ClassManifest]( + updateFunc: JFunction2[JList[V], Optional[S], Optional[S]], + numPartitions: Int) + : JavaPairDStream[K, S] = { + dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), numPartitions) + } + + /** + * Create a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of the key. + * [[spark.Partitioner]] is used to control the partitioning of each RDD. + * @param updateFunc State update function. If `this` function returns None, then + * corresponding state key-value pair will be eliminated. + * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + * @tparam S State type + */ + def updateStateByKey[S: ClassManifest]( + updateFunc: JFunction2[JList[V], Optional[S], Optional[S]], + partitioner: Partitioner + ): JavaPairDStream[K, S] = { + dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), partitioner) } def mapValues[U](f: JFunction[V, U]): JavaPairDStream[K, U] = { -- cgit v1.2.3 From d5570c7968baba1c1fe86c68dc1c388fae23907b Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 16 Jan 2013 21:07:49 -0800 Subject: Adding checkpointing to Java API --- .../main/scala/spark/api/java/JavaRDDLike.scala | 28 ++++++++++++++++++++++ .../scala/spark/api/java/JavaSparkContext.scala | 26 ++++++++++++++++++++ core/src/test/scala/spark/JavaAPISuite.java | 27 +++++++++++++++++++++ 3 files changed, 81 insertions(+) diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 81d3a94466..958f5c26a1 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -9,6 +9,7 @@ import spark.api.java.JavaPairRDD._ import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _} import spark.partial.{PartialResult, BoundedDouble} import spark.storage.StorageLevel +import com.google.common.base.Optional trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { @@ -298,4 +299,31 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Save this RDD as a SequenceFile of serialized objects. */ def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path) + + /** + * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir` + * (set using setCheckpointDir()) and all references to its parent RDDs will be removed. + * This is used to truncate very long lineages. In the current implementation, Spark will save + * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. + * Hence, it is strongly recommended to use checkpoint() on RDDs when + * (i) checkpoint() is called before the any job has been executed on this RDD. + * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will + * require recomputation. + */ + def checkpoint() = rdd.checkpoint() + + /** + * Return whether this RDD has been checkpointed or not + */ + def isCheckpointed(): Boolean = rdd.isCheckpointed() + + /** + * Gets the name of the file to which this RDD was checkpointed + */ + def getCheckpointFile(): Optional[String] = { + rdd.getCheckpointFile match { + case Some(file) => Optional.of(file) + case _ => Optional.absent() + } + } } diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index bf9ad7a200..22bfa2280d 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -342,6 +342,32 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def clearFiles() { sc.clearFiles() } + + /** + * Set the directory under which RDDs are going to be checkpointed. This method will + * create this directory and will throw an exception of the path already exists (to avoid + * overwriting existing files may be overwritten). The directory will be deleted on exit + * if indicated. + */ + def setCheckpointDir(dir: String, useExisting: Boolean) { + sc.setCheckpointDir(dir, useExisting) + } + + /** + * Set the directory under which RDDs are going to be checkpointed. This method will + * create this directory and will throw an exception of the path already exists (to avoid + * overwriting existing files may be overwritten). The directory will be deleted on exit + * if indicated. + */ + def setCheckpointDir(dir: String) { + sc.setCheckpointDir(dir) + } + + protected def checkpointFile[T](path: String): JavaRDD[T] = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + new JavaRDD(sc.checkpointFile(path)) + } } object JavaSparkContext { diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index b99e790093..0b5354774b 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -625,4 +625,31 @@ public class JavaAPISuite implements Serializable { }); Assert.assertEquals((Float) 25.0f, floatAccum.value()); } + + @Test + public void checkpointAndComputation() { + File tempDir = Files.createTempDir(); + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + sc.setCheckpointDir(tempDir.getAbsolutePath(), true); + Assert.assertEquals(false, rdd.isCheckpointed()); + rdd.checkpoint(); + rdd.count(); // Forces the DAG to cause a checkpoint + Assert.assertEquals(true, rdd.isCheckpointed()); + Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), rdd.collect()); + } + + @Test + public void checkpointAndRestore() { + File tempDir = Files.createTempDir(); + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + sc.setCheckpointDir(tempDir.getAbsolutePath(), true); + Assert.assertEquals(false, rdd.isCheckpointed()); + rdd.checkpoint(); + rdd.count(); // Forces the DAG to cause a checkpoint + Assert.assertEquals(true, rdd.isCheckpointed()); + + Assert.assertTrue(rdd.getCheckpointFile().isPresent()); + JavaRDD recovered = sc.checkpointFile(rdd.getCheckpointFile().get()); + Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect()); + } } -- cgit v1.2.3 From 61b877c688d6c9a4e9c4d8f22ca0cadae29895bb Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 17 Jan 2013 09:04:56 -0800 Subject: Adding flatMap --- .../spark/streaming/api/java/JavaDStreamLike.scala | 29 ++++++++ .../test/scala/spark/streaming/JavaAPISuite.java | 81 ++++++++++++++++++++-- 2 files changed, 106 insertions(+), 4 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala index 32df665a98..b93cb7865a 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala @@ -65,6 +65,27 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable } /** + * Return a new DStream by applying a function to all elements of this DStream, + * and then flattening the results + */ + def flatMap[U](f: FlatMapFunction[T, U]): JavaDStream[U] = { + import scala.collection.JavaConverters._ + def fn = (x: T) => f.apply(x).asScala + new JavaDStream(dstream.flatMap(fn)(f.elementType()))(f.elementType()) + } + + /** + * Return a new DStream by applying a function to all elements of this DStream, + * and then flattening the results + */ + def flatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairDStream[K, V] = { + import scala.collection.JavaConverters._ + def fn = (x: T) => f.apply(x).asScala + def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]] + new JavaPairDStream(dstream.flatMap(fn)(cm))(f.keyType(), f.valueType()) + } + + /** * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition * of the RDD. @@ -151,4 +172,12 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable transformFunc.call(new JavaRDD[T](in), time).rdd dstream.transform(scalaTransform(_, _)) } + + /** + * Enable periodic checkpointing of RDDs of this DStream + * @param interval Time interval after which generated RDD will be checkpointed + */ + def checkpoint(interval: Duration) = { + dstream.checkpoint(interval) + } } \ No newline at end of file diff --git a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java index 549fb5b733..41fd9f99ff 100644 --- a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java @@ -11,10 +11,7 @@ import org.junit.Test; import scala.Tuple2; import spark.HashPartitioner; import spark.api.java.JavaRDD; -import spark.api.java.function.FlatMapFunction; -import spark.api.java.function.Function; -import spark.api.java.function.Function2; -import spark.api.java.function.PairFunction; +import spark.api.java.function.*; import spark.storage.StorageLevel; import spark.streaming.api.java.JavaDStream; import spark.streaming.api.java.JavaPairDStream; @@ -308,6 +305,82 @@ public class JavaAPISuite implements Serializable { assertOrderInvariantEquals(expected, result); } + @Test + public void testFlatMap() { + List> inputData = Arrays.asList( + Arrays.asList("go", "giants"), + Arrays.asList("boo", "dodgers"), + Arrays.asList("athletics")); + + List> expected = Arrays.asList( + Arrays.asList("g","o","g","i","a","n","t","s"), + Arrays.asList("b", "o", "o", "d","o","d","g","e","r","s"), + Arrays.asList("a","t","h","l","e","t","i","c","s")); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream flatMapped = stream.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Lists.newArrayList(x.split("(?!^)")); + } + }); + JavaTestUtils.attachTestOutputStream(flatMapped); + List> result = JavaTestUtils.runStreams(sc, 3, 3); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testPairFlatMap() { + List> inputData = Arrays.asList( + Arrays.asList("giants"), + Arrays.asList("dodgers"), + Arrays.asList("athletics")); + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2(6, "g"), + new Tuple2(6, "i"), + new Tuple2(6, "a"), + new Tuple2(6, "n"), + new Tuple2(6, "t"), + new Tuple2(6, "s")), + Arrays.asList( + new Tuple2(7, "d"), + new Tuple2(7, "o"), + new Tuple2(7, "d"), + new Tuple2(7, "g"), + new Tuple2(7, "e"), + new Tuple2(7, "r"), + new Tuple2(7, "s")), + Arrays.asList( + new Tuple2(9, "a"), + new Tuple2(9, "t"), + new Tuple2(9, "h"), + new Tuple2(9, "l"), + new Tuple2(9, "e"), + new Tuple2(9, "t"), + new Tuple2(9, "i"), + new Tuple2(9, "c"), + new Tuple2(9, "s"))); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaPairDStream flatMapped = stream.flatMap(new PairFlatMapFunction() { + @Override + public Iterable> call(String in) throws Exception { + List> out = Lists.newArrayList(); + for (String letter: in.split("(?!^)")) { + out.add(new Tuple2(in.length(), letter)); + } + return out; + } + }); + JavaTestUtils.attachTestOutputStream(flatMapped); + List>> result = JavaTestUtils.runStreams(sc, 3, 3); + + Assert.assertEquals(expected, result); + } + @Test public void testUnion() { List> inputData1 = Arrays.asList( -- cgit v1.2.3 From 82b8707c6bbb3926e59c241b6e6d5ead5467aae7 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 17 Jan 2013 12:21:44 -0800 Subject: Checkpointing in Streaming java API --- .../streaming/api/java/JavaStreamingContext.scala | 19 +++++- streaming/src/test/scala/JavaTestUtils.scala | 9 ++- .../test/scala/spark/streaming/JavaAPISuite.java | 68 ++++++++++++++++++++++ 3 files changed, 93 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index 2833793b94..ebbc516b38 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -13,14 +13,29 @@ import java.io.InputStream import java.util.{Map => JMap} class JavaStreamingContext(val ssc: StreamingContext) { - def this(master: String, frameworkName: String, batchDuration: Duration) = - this(new StreamingContext(master, frameworkName, batchDuration)) // TODOs: // - Test StreamingContext functions // - Test to/from Hadoop functions // - Support creating and registering InputStreams + + /** + * Creates a StreamingContext. + * @param master Name of the Spark Master + * @param frameworkName Name to be used when registering with the scheduler + * @param batchDuration The time interval at which streaming data will be divided into batches + */ + def this(master: String, frameworkName: String, batchDuration: Duration) = + this(new StreamingContext(master, frameworkName, batchDuration)) + + /** + * Re-creates a StreamingContext from a checkpoint file. + * @param path Path either to the directory that was specified as the checkpoint directory, or + * to the checkpoint file 'graph' or 'graph.bk'. + */ + def this(path: String) = this (new StreamingContext(path)) + /** * Create an input stream that pulls messages form a Kafka Broker. * @param hostname Zookeper hostname. diff --git a/streaming/src/test/scala/JavaTestUtils.scala b/streaming/src/test/scala/JavaTestUtils.scala index 24ebc15e38..56349837e5 100644 --- a/streaming/src/test/scala/JavaTestUtils.scala +++ b/streaming/src/test/scala/JavaTestUtils.scala @@ -8,7 +8,7 @@ import java.util.ArrayList import collection.JavaConversions._ /** Exposes streaming test functionality in a Java-friendly way. */ -object JavaTestUtils extends TestSuiteBase { +trait JavaTestBase extends TestSuiteBase { /** * Create a [[spark.streaming.TestInputStream]] and attach it to the supplied context. @@ -56,3 +56,10 @@ object JavaTestUtils extends TestSuiteBase { } } +object JavaTestUtils extends JavaTestBase { + +} + +object JavaCheckpointTestUtils extends JavaTestBase { + override def actuallyWait = true +} \ No newline at end of file diff --git a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java index 41fd9f99ff..8a63e8cd3f 100644 --- a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java @@ -3,6 +3,7 @@ package spark.streaming; import com.google.common.base.Optional; import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import com.google.common.io.Files; import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; import org.junit.After; import org.junit.Assert; @@ -17,6 +18,7 @@ import spark.streaming.api.java.JavaDStream; import spark.streaming.api.java.JavaPairDStream; import spark.streaming.api.java.JavaStreamingContext; import spark.streaming.JavaTestUtils; +import spark.streaming.JavaCheckpointTestUtils; import spark.streaming.dstream.KafkaPartitionKey; import sun.org.mozilla.javascript.annotations.JSFunction; @@ -871,6 +873,72 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(expected, result); } + @Test + public void testCheckpointMasterRecovery() throws InterruptedException { + List> inputData = Arrays.asList( + Arrays.asList("this", "is"), + Arrays.asList("a", "test"), + Arrays.asList("counting", "letters")); + + List> expectedInitial = Arrays.asList( + Arrays.asList(4,2)); + List> expectedFinal = Arrays.asList( + Arrays.asList(1,4), + Arrays.asList(8,7)); + + + File tempDir = Files.createTempDir(); + sc.checkpoint(tempDir.getAbsolutePath(), new Duration(1000)); + + JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream letterCount = stream.map(new Function() { + @Override + public Integer call(String s) throws Exception { + return s.length(); + } + }); + JavaCheckpointTestUtils.attachTestOutputStream(letterCount); + List> initialResult = JavaTestUtils.runStreams(sc, 1, 1); + + assertOrderInvariantEquals(expectedInitial, initialResult); + Thread.sleep(1000); + + sc.stop(); + sc = new JavaStreamingContext(tempDir.getAbsolutePath()); + sc.start(); + List> finalResult = JavaCheckpointTestUtils.runStreams(sc, 2, 2); + assertOrderInvariantEquals(expectedFinal, finalResult); + } + + /** TEST DISABLED: Pending a discussion about checkpoint() semantics with TD + @Test + public void testCheckpointofIndividualStream() throws InterruptedException { + List> inputData = Arrays.asList( + Arrays.asList("this", "is"), + Arrays.asList("a", "test"), + Arrays.asList("counting", "letters")); + + List> expected = Arrays.asList( + Arrays.asList(4,2), + Arrays.asList(1,4), + Arrays.asList(8,7)); + + JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream letterCount = stream.map(new Function() { + @Override + public Integer call(String s) throws Exception { + return s.length(); + } + }); + JavaCheckpointTestUtils.attachTestOutputStream(letterCount); + + letterCount.checkpoint(new Duration(1000)); + + List> result1 = JavaCheckpointTestUtils.runStreams(sc, 3, 3); + assertOrderInvariantEquals(expected, result1); + } + */ + // Input stream tests. These mostly just test that we can instantiate a given InputStream with // Java arguments and assign it to a JavaDStream without producing type errors. Testing of the // InputStream functionality is deferred to the existing Scala tests. -- cgit v1.2.3 From 2261e62ee52495599b7a8717884e878497d343ea Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 17 Jan 2013 18:35:24 -0800 Subject: Style cleanup --- .../spark/streaming/api/java/JavaDStream.scala | 22 +++++++++++++++++++++- .../spark/streaming/api/java/JavaPairDStream.scala | 1 - .../streaming/api/java/JavaStreamingContext.scala | 6 +++++- 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala index 32faef5670..e21e54d3e5 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala @@ -6,6 +6,26 @@ import spark.api.java.JavaRDD import java.util.{List => JList} import spark.storage.StorageLevel +/** + * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous + * sequence of RDDs (of the same type) representing a continuous stream of data (see [[spark.RDD]] + * for more details on RDDs). DStreams can either be created from live data (such as, data from + * HDFS, Kafka or Flume) or it can be generated by transformation existing DStreams using operations + * such as `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming program is running, each + * DStream periodically generates a RDD, either from live data or by transforming the RDD generated + * by a parent DStream. + * + * This class contains the basic operations available on all DStreams, such as `map`, `filter` and + * `window`. In addition, [[spark.streaming.api.java.JavaPairDStream]] contains operations available + * only on DStreams of key-value pairs, such as `groupByKeyAndWindow` and `join`. These operations + * are automatically available on any DStream of the right type (e.g., DStream[(Int, Int)] through + * implicit conversions when `spark.streaming.StreamingContext._` is imported. + * + * DStreams internally is characterized by a few basic properties: + * - A list of other DStreams that the DStream depends on + * - A time interval at which the DStream generates an RDD + * - A function that is used to generate an RDD after each time interval + */ class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassManifest[T]) extends JavaDStreamLike[T, JavaDStream[T]] { @@ -69,4 +89,4 @@ class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassM object JavaDStream { implicit def fromDStream[T: ClassManifest](dstream: DStream[T]): JavaDStream[T] = new JavaDStream[T](dstream) -} +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala index 8c76d8c1d8..ef10c091ca 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala @@ -14,7 +14,6 @@ import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} import org.apache.hadoop.conf.Configuration import spark.api.java.JavaPairRDD import spark.storage.StorageLevel -import java.lang import com.google.common.base.Optional class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index ebbc516b38..7e1c2a999f 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -12,10 +12,14 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import java.io.InputStream import java.util.{Map => JMap} +/** + * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic + * information (such as, cluster URL and job name) to internally create a SparkContext, it provides + * methods used to create DStream from various input sources. + */ class JavaStreamingContext(val ssc: StreamingContext) { // TODOs: - // - Test StreamingContext functions // - Test to/from Hadoop functions // - Support creating and registering InputStreams -- cgit v1.2.3 From 70ba994d6d6f9e62269168e6a8a61ffce736a4d2 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 17 Jan 2013 18:38:30 -0800 Subject: Import fixup --- streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala | 1 - .../src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala | 1 - streaming/src/test/scala/spark/streaming/JavaAPISuite.java | 1 - 3 files changed, 3 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala index e21e54d3e5..2e7466b16c 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala @@ -3,7 +3,6 @@ package spark.streaming.api.java import spark.streaming.{Duration, Time, DStream} import spark.api.java.function.{Function => JFunction} import spark.api.java.JavaRDD -import java.util.{List => JList} import spark.storage.StorageLevel /** diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index 7e1c2a999f..accac82e09 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -1,7 +1,6 @@ package spark.streaming.api.java import scala.collection.JavaConversions._ -import java.util.{List => JList} import java.lang.{Long => JLong, Integer => JInt} import spark.streaming._ diff --git a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java index 8a63e8cd3f..374793b57e 100644 --- a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java @@ -20,7 +20,6 @@ import spark.streaming.api.java.JavaStreamingContext; import spark.streaming.JavaTestUtils; import spark.streaming.JavaCheckpointTestUtils; import spark.streaming.dstream.KafkaPartitionKey; -import sun.org.mozilla.javascript.annotations.JSFunction; import java.io.*; import java.util.*; -- cgit v1.2.3 From 6fba7683c29b64a33a6daa28cc56bb5d20574314 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 17 Jan 2013 18:46:24 -0800 Subject: Small doc fix --- .../src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java b/examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java index 151b71eb81..cddce16e39 100644 --- a/examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java +++ b/examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java @@ -12,7 +12,7 @@ import spark.streaming.dstream.SparkFlumeEvent; * an Avro server on at the request host:port address and listen for requests. * Your Flume AvroSink should be pointed to this address. * - * Usage: FlumeEventCount + * Usage: JavaFlumeEventCount * * is a Spark master URL * is the host the Flume receiver will be started on - a receiver -- cgit v1.2.3 From e0165bf7141086e28f88cd68ab7bc6249061c924 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 17 Jan 2013 21:07:09 -0800 Subject: Adding queueStream and some slight refactoring --- .../spark/streaming/examples/JavaQueueStream.java | 62 +++++++ project/SparkBuild.scala | 3 +- .../streaming/api/java/JavaStreamingContext.scala | 58 +++++++ .../test/scala/spark/streaming/JavaAPISuite.java | 186 ++++++++++++--------- 4 files changed, 226 insertions(+), 83 deletions(-) create mode 100644 examples/src/main/scala/spark/streaming/examples/JavaQueueStream.java diff --git a/examples/src/main/scala/spark/streaming/examples/JavaQueueStream.java b/examples/src/main/scala/spark/streaming/examples/JavaQueueStream.java new file mode 100644 index 0000000000..43c3cd4dfa --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/JavaQueueStream.java @@ -0,0 +1,62 @@ +package spark.streaming.examples; + +import com.google.common.collect.Lists; +import scala.Tuple2; +import spark.api.java.JavaRDD; +import spark.api.java.function.Function2; +import spark.api.java.function.PairFunction; +import spark.streaming.Duration; +import spark.streaming.api.java.JavaDStream; +import spark.streaming.api.java.JavaPairDStream; +import spark.streaming.api.java.JavaStreamingContext; + +import java.util.LinkedList; +import java.util.List; +import java.util.Queue; + +public class JavaQueueStream { + public static void main(String[] args) throws InterruptedException { + if (args.length < 1) { + System.err.println("Usage: JavaQueueStream "); + System.exit(1); + } + + // Create the context + JavaStreamingContext ssc = new JavaStreamingContext(args[0], "QueueStream", new Duration(1000)); + + // Create the queue through which RDDs can be pushed to + // a QueueInputDStream + Queue> rddQueue = new LinkedList>(); + + // Create and push some RDDs into the queue + List list = Lists.newArrayList(); + for (int i = 0; i < 1000; i++) { + list.add(i); + } + + for (int i = 0; i < 30; i++) { + rddQueue.add(ssc.sc().parallelize(list)); + } + + + // Create the QueueInputDStream and use it do some processing + JavaDStream inputStream = ssc.queueStream(rddQueue); + JavaPairDStream mappedStream = inputStream.map( + new PairFunction() { + @Override + public Tuple2 call(Integer i) throws Exception { + return new Tuple2(i % 10, 1); + } + }); + JavaPairDStream reducedStream = mappedStream.reduceByKey( + new Function2() { + @Override + public Integer call(Integer i1, Integer i2) throws Exception { + return i1 + i2; + } + }); + + reducedStream.print(); + ssc.start(); + } +} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 39db4be842..d5cda347a4 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -162,8 +162,7 @@ object SparkBuild extends Build { def streamingSettings = sharedSettings ++ Seq( name := "spark-streaming", libraryDependencies ++= Seq( - "com.github.sgroschupf" % "zkclient" % "0.1", - "junit" % "junit" % "4.8.1") + "com.github.sgroschupf" % "zkclient" % "0.1") ) ++ assemblySettings ++ extraAssemblySettings def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq( diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index accac82e09..f82e6a37cc 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -10,6 +10,7 @@ import spark.api.java.function.{Function => JFunction, Function2 => JFunction2} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import java.io.InputStream import java.util.{Map => JMap} +import spark.api.java.{JavaSparkContext, JavaRDD} /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -39,6 +40,9 @@ class JavaStreamingContext(val ssc: StreamingContext) { */ def this(path: String) = this (new StreamingContext(path)) + /** The underlying SparkContext */ + val sc: JavaSparkContext = new JavaSparkContext(ssc.sc) + /** * Create an input stream that pulls messages form a Kafka Broker. * @param hostname Zookeper hostname. @@ -254,6 +258,60 @@ class JavaStreamingContext(val ssc: StreamingContext) { ssc.registerOutputStream(outputStream.dstream) } + /** + * Creates a input stream from an queue of RDDs. In each batch, + * it will process either one or all of the RDDs returned by the queue. + * + * NOTE: changes to the queue after the stream is created will not be recognized. + * @param queue Queue of RDDs + * @tparam T Type of objects in the RDD + */ + def queueStream[T](queue: java.util.Queue[JavaRDD[T]]): JavaDStream[T] = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + val sQueue = new scala.collection.mutable.Queue[spark.RDD[T]] + sQueue.enqueue(queue.map(_.rdd).toSeq: _*) + ssc.queueStream(sQueue) + } + + /** + * Creates a input stream from an queue of RDDs. In each batch, + * it will process either one or all of the RDDs returned by the queue. + * + * NOTE: changes to the queue after the stream is created will not be recognized. + * @param queue Queue of RDDs + * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval + * @tparam T Type of objects in the RDD + */ + def queueStream[T](queue: java.util.Queue[JavaRDD[T]], oneAtATime: Boolean): JavaDStream[T] = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + val sQueue = new scala.collection.mutable.Queue[spark.RDD[T]] + sQueue.enqueue(queue.map(_.rdd).toSeq: _*) + ssc.queueStream(sQueue, oneAtATime) + } + + /** + * Creates a input stream from an queue of RDDs. In each batch, + * it will process either one or all of the RDDs returned by the queue. + * + * NOTE: changes to the queue after the stream is created will not be recognized. + * @param queue Queue of RDDs + * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval + * @param defaultRDD Default RDD is returned by the DStream when the queue is empty + * @tparam T Type of objects in the RDD + */ + def queueStream[T]( + queue: java.util.Queue[JavaRDD[T]], + oneAtATime: Boolean, + defaultRDD: JavaRDD[T]): JavaDStream[T] = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + val sQueue = new scala.collection.mutable.Queue[spark.RDD[T]] + sQueue.enqueue(queue.map(_.rdd).toSeq: _*) + ssc.queueStream(sQueue, oneAtATime, defaultRDD.rdd) + } + /** * Sets the context to periodically checkpoint the DStream operations for master * fault-tolerance. By default, the graph will be checkpointed every batch interval. diff --git a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java index 374793b57e..8c94e13e65 100644 --- a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java @@ -12,6 +12,7 @@ import org.junit.Test; import scala.Tuple2; import spark.HashPartitioner; import spark.api.java.JavaRDD; +import spark.api.java.JavaSparkContext; import spark.api.java.function.*; import spark.storage.StorageLevel; import spark.streaming.api.java.JavaDStream; @@ -28,17 +29,17 @@ import java.util.*; // serialized, as an alternative to converting these anonymous classes to static inner classes; // see http://stackoverflow.com/questions/758570/. public class JavaAPISuite implements Serializable { - private transient JavaStreamingContext sc; + private transient JavaStreamingContext ssc; @Before public void setUp() { - sc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); } @After public void tearDown() { - sc.stop(); - sc = null; + ssc.stop(); + ssc = null; // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.master.port"); } @@ -55,10 +56,10 @@ public class JavaAPISuite implements Serializable { Arrays.asList(3L), Arrays.asList(1L)); - JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream count = stream.count(); JavaTestUtils.attachTestOutputStream(count); - List> result = JavaTestUtils.runStreams(sc, 3, 3); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); assertOrderInvariantEquals(expected, result); } @@ -72,7 +73,7 @@ public class JavaAPISuite implements Serializable { Arrays.asList(5,5), Arrays.asList(9,4)); - JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override public Integer call(String s) throws Exception { @@ -80,7 +81,7 @@ public class JavaAPISuite implements Serializable { } }); JavaTestUtils.attachTestOutputStream(letterCount); - List> result = JavaTestUtils.runStreams(sc, 2, 2); + List> result = JavaTestUtils.runStreams(ssc, 2, 2); assertOrderInvariantEquals(expected, result); } @@ -98,10 +99,10 @@ public class JavaAPISuite implements Serializable { Arrays.asList(7,8,9,4,5,6), Arrays.asList(7,8,9)); - JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream windowed = stream.window(new Duration(2000)); JavaTestUtils.attachTestOutputStream(windowed); - List> result = JavaTestUtils.runStreams(sc, 4, 4); + List> result = JavaTestUtils.runStreams(ssc, 4, 4); assertOrderInvariantEquals(expected, result); } @@ -122,10 +123,10 @@ public class JavaAPISuite implements Serializable { Arrays.asList(7,8,9,10,11,12,13,14,15,16,17,18), Arrays.asList(13,14,15,16,17,18)); - JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream windowed = stream.window(new Duration(4000), new Duration(2000)); JavaTestUtils.attachTestOutputStream(windowed); - List> result = JavaTestUtils.runStreams(sc, 8, 4); + List> result = JavaTestUtils.runStreams(ssc, 8, 4); assertOrderInvariantEquals(expected, result); } @@ -145,10 +146,10 @@ public class JavaAPISuite implements Serializable { Arrays.asList(7,8,9,10,11,12), Arrays.asList(13,14,15,16,17,18)); - JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream windowed = stream.tumble(new Duration(2000)); JavaTestUtils.attachTestOutputStream(windowed); - List> result = JavaTestUtils.runStreams(sc, 6, 3); + List> result = JavaTestUtils.runStreams(ssc, 6, 3); assertOrderInvariantEquals(expected, result); } @@ -163,7 +164,7 @@ public class JavaAPISuite implements Serializable { Arrays.asList("giants"), Arrays.asList("yankees")); - JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream filtered = stream.filter(new Function() { @Override public Boolean call(String s) throws Exception { @@ -171,7 +172,7 @@ public class JavaAPISuite implements Serializable { } }); JavaTestUtils.attachTestOutputStream(filtered); - List> result = JavaTestUtils.runStreams(sc, 2, 2); + List> result = JavaTestUtils.runStreams(ssc, 2, 2); assertOrderInvariantEquals(expected, result); } @@ -186,10 +187,10 @@ public class JavaAPISuite implements Serializable { Arrays.asList(Arrays.asList("giants", "dodgers")), Arrays.asList(Arrays.asList("yankees", "red socks"))); - JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream glommed = stream.glom(); JavaTestUtils.attachTestOutputStream(glommed); - List>> result = JavaTestUtils.runStreams(sc, 2, 2); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); Assert.assertEquals(expected, result); } @@ -204,7 +205,7 @@ public class JavaAPISuite implements Serializable { Arrays.asList("GIANTSDODGERS"), Arrays.asList("YANKEESRED SOCKS")); - JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream mapped = stream.mapPartitions(new FlatMapFunction, String>() { @Override public Iterable call(Iterator in) { @@ -216,7 +217,7 @@ public class JavaAPISuite implements Serializable { } }); JavaTestUtils.attachTestOutputStream(mapped); - List>> result = JavaTestUtils.runStreams(sc, 2, 2); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); Assert.assertEquals(expected, result); } @@ -247,10 +248,10 @@ public class JavaAPISuite implements Serializable { Arrays.asList(15), Arrays.asList(24)); - JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream reduced = stream.reduce(new IntegerSum()); JavaTestUtils.attachTestOutputStream(reduced); - List> result = JavaTestUtils.runStreams(sc, 3, 3); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); Assert.assertEquals(expected, result); } @@ -268,15 +269,38 @@ public class JavaAPISuite implements Serializable { Arrays.asList(39), Arrays.asList(24)); - JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream reducedWindowed = stream.reduceByWindow(new IntegerSum(), new IntegerDifference(), new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reducedWindowed); - List> result = JavaTestUtils.runStreams(sc, 4, 4); + List> result = JavaTestUtils.runStreams(ssc, 4, 4); Assert.assertEquals(expected, result); } + @Test + public void testQueueStream() { + List> expected = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9)); + + JavaSparkContext jsc = new JavaSparkContext(ssc.ssc().sc()); + JavaRDD rdd1 = ssc.sc().parallelize(Arrays.asList(1,2,3)); + JavaRDD rdd2 = ssc.sc().parallelize(Arrays.asList(4,5,6)); + JavaRDD rdd3 = ssc.sc().parallelize(Arrays.asList(7,8,9)); + + LinkedList> rdds = Lists.newLinkedList(); + rdds.add(rdd1); + rdds.add(rdd2); + rdds.add(rdd3); + + JavaDStream stream = ssc.queueStream(rdds); + JavaTestUtils.attachTestOutputStream(stream); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + Assert.assertEquals(expected, result); + } + @Test public void testTransform() { List> inputData = Arrays.asList( @@ -289,7 +313,7 @@ public class JavaAPISuite implements Serializable { Arrays.asList(6,7,8), Arrays.asList(9,10,11)); - JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream transformed = stream.transform(new Function, JavaRDD>() { @Override public JavaRDD call(JavaRDD in) throws Exception { @@ -301,7 +325,7 @@ public class JavaAPISuite implements Serializable { }); }}); JavaTestUtils.attachTestOutputStream(transformed); - List> result = JavaTestUtils.runStreams(sc, 3, 3); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); assertOrderInvariantEquals(expected, result); } @@ -318,7 +342,7 @@ public class JavaAPISuite implements Serializable { Arrays.asList("b", "o", "o", "d","o","d","g","e","r","s"), Arrays.asList("a","t","h","l","e","t","i","c","s")); - JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream flatMapped = stream.flatMap(new FlatMapFunction() { @Override public Iterable call(String x) { @@ -326,7 +350,7 @@ public class JavaAPISuite implements Serializable { } }); JavaTestUtils.attachTestOutputStream(flatMapped); - List> result = JavaTestUtils.runStreams(sc, 3, 3); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); assertOrderInvariantEquals(expected, result); } @@ -365,7 +389,7 @@ public class JavaAPISuite implements Serializable { new Tuple2(9, "c"), new Tuple2(9, "s"))); - JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream flatMapped = stream.flatMap(new PairFlatMapFunction() { @Override public Iterable> call(String in) throws Exception { @@ -377,7 +401,7 @@ public class JavaAPISuite implements Serializable { } }); JavaTestUtils.attachTestOutputStream(flatMapped); - List>> result = JavaTestUtils.runStreams(sc, 3, 3); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); Assert.assertEquals(expected, result); } @@ -399,12 +423,12 @@ public class JavaAPISuite implements Serializable { Arrays.asList(2,2,5,5), Arrays.asList(3,3,6,6)); - JavaDStream stream1 = JavaTestUtils.attachTestInputStream(sc, inputData1, 2); - JavaDStream stream2 = JavaTestUtils.attachTestInputStream(sc, inputData2, 2); + JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, inputData1, 2); + JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 2); JavaDStream unioned = stream1.union(stream2); JavaTestUtils.attachTestOutputStream(unioned); - List> result = JavaTestUtils.runStreams(sc, 3, 3); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); assertOrderInvariantEquals(expected, result); } @@ -436,7 +460,7 @@ public class JavaAPISuite implements Serializable { Arrays.asList(new Tuple2("giants", 6)), Arrays.asList(new Tuple2("yankees", 7))); - JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = stream.map( new PairFunction() { @Override @@ -453,7 +477,7 @@ public class JavaAPISuite implements Serializable { } }); JavaTestUtils.attachTestOutputStream(filtered); - List>> result = JavaTestUtils.runStreams(sc, 2, 2); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); Assert.assertEquals(expected, result); } @@ -492,12 +516,12 @@ public class JavaAPISuite implements Serializable { new Tuple2>("california", Arrays.asList("sharks", "ducks")), new Tuple2>("new york", Arrays.asList("rangers", "islanders")))); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream> grouped = pairStream.groupByKey(); JavaTestUtils.attachTestOutputStream(grouped); - List>>> result = JavaTestUtils.runStreams(sc, 2, 2); + List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); Assert.assertEquals(expected, result); } @@ -515,13 +539,13 @@ public class JavaAPISuite implements Serializable { new Tuple2("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( - sc, inputData, 1); + ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reduced = pairStream.reduceByKey(new IntegerSum()); JavaTestUtils.attachTestOutputStream(reduced); - List>> result = JavaTestUtils.runStreams(sc, 2, 2); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); Assert.assertEquals(expected, result); } @@ -539,7 +563,7 @@ public class JavaAPISuite implements Serializable { new Tuple2("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( - sc, inputData, 1); + ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream combined = pairStream.combineByKey( @@ -551,7 +575,7 @@ public class JavaAPISuite implements Serializable { }, new IntegerSum(), new IntegerSum(), new HashPartitioner(2)); JavaTestUtils.attachTestOutputStream(combined); - List>> result = JavaTestUtils.runStreams(sc, 2, 2); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); Assert.assertEquals(expected, result); } @@ -569,12 +593,12 @@ public class JavaAPISuite implements Serializable { new Tuple2("new york", 2L))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( - sc, inputData, 1); + ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream counted = pairStream.countByKey(); JavaTestUtils.attachTestOutputStream(counted); - List>> result = JavaTestUtils.runStreams(sc, 2, 2); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); Assert.assertEquals(expected, result); } @@ -592,13 +616,13 @@ public class JavaAPISuite implements Serializable { Arrays.asList(new Tuple2>("california", Arrays.asList("sharks", "ducks")), new Tuple2>("new york", Arrays.asList("rangers", "islanders")))); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream> groupWindowed = pairStream.groupByKeyAndWindow(new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(groupWindowed); - List>>> result = JavaTestUtils.runStreams(sc, 3, 3); + List>>> result = JavaTestUtils.runStreams(ssc, 3, 3); Assert.assertEquals(expected, result); } @@ -615,13 +639,13 @@ public class JavaAPISuite implements Serializable { Arrays.asList(new Tuple2("california", 10), new Tuple2("new york", 4))); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reduceWindowed = pairStream.reduceByKeyAndWindow(new IntegerSum(), new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); - List>> result = JavaTestUtils.runStreams(sc, 3, 3); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); Assert.assertEquals(expected, result); } @@ -638,7 +662,7 @@ public class JavaAPISuite implements Serializable { Arrays.asList(new Tuple2("california", 14), new Tuple2("new york", 9))); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream updated = pairStream.updateStateByKey( @@ -656,7 +680,7 @@ public class JavaAPISuite implements Serializable { } }); JavaTestUtils.attachTestOutputStream(updated); - List>> result = JavaTestUtils.runStreams(sc, 3, 3); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); Assert.assertEquals(expected, result); } @@ -673,13 +697,13 @@ public class JavaAPISuite implements Serializable { Arrays.asList(new Tuple2("california", 10), new Tuple2("new york", 4))); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reduceWindowed = pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); - List>> result = JavaTestUtils.runStreams(sc, 3, 3); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); Assert.assertEquals(expected, result); } @@ -700,13 +724,13 @@ public class JavaAPISuite implements Serializable { new Tuple2("new york", 2L))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( - sc, inputData, 1); + ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream counted = pairStream.countByKeyAndWindow(new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(counted); - List>> result = JavaTestUtils.runStreams(sc, 3, 3); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); Assert.assertEquals(expected, result); } @@ -726,7 +750,7 @@ public class JavaAPISuite implements Serializable { new Tuple2("new york", "ISLANDERS"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( - sc, inputData, 1); + ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream mapped = pairStream.mapValues(new Function() { @@ -737,7 +761,7 @@ public class JavaAPISuite implements Serializable { }); JavaTestUtils.attachTestOutputStream(mapped); - List>> result = JavaTestUtils.runStreams(sc, 2, 2); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); Assert.assertEquals(expected, result); } @@ -765,7 +789,7 @@ public class JavaAPISuite implements Serializable { new Tuple2("new york", "islanders2"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( - sc, inputData, 1); + ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -781,7 +805,7 @@ public class JavaAPISuite implements Serializable { }); JavaTestUtils.attachTestOutputStream(flatMapped); - List>> result = JavaTestUtils.runStreams(sc, 2, 2); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); Assert.assertEquals(expected, result); } @@ -815,16 +839,16 @@ public class JavaAPISuite implements Serializable { JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( - sc, stringStringKVStream1, 1); + ssc, stringStringKVStream1, 1); JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream(stream1); JavaDStream> stream2 = JavaTestUtils.attachTestInputStream( - sc, stringStringKVStream2, 1); + ssc, stringStringKVStream2, 1); JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); JavaPairDStream, List>> grouped = pairStream1.cogroup(pairStream2); JavaTestUtils.attachTestOutputStream(grouped); - List>> result = JavaTestUtils.runStreams(sc, 2, 2); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); Assert.assertEquals(expected, result); } @@ -858,16 +882,16 @@ public class JavaAPISuite implements Serializable { JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( - sc, stringStringKVStream1, 1); + ssc, stringStringKVStream1, 1); JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream(stream1); JavaDStream> stream2 = JavaTestUtils.attachTestInputStream( - sc, stringStringKVStream2, 1); + ssc, stringStringKVStream2, 1); JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); JavaPairDStream> joined = pairStream1.join(pairStream2); JavaTestUtils.attachTestOutputStream(joined); - List>> result = JavaTestUtils.runStreams(sc, 2, 2); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); Assert.assertEquals(expected, result); } @@ -887,9 +911,9 @@ public class JavaAPISuite implements Serializable { File tempDir = Files.createTempDir(); - sc.checkpoint(tempDir.getAbsolutePath(), new Duration(1000)); + ssc.checkpoint(tempDir.getAbsolutePath(), new Duration(1000)); - JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override public Integer call(String s) throws Exception { @@ -897,15 +921,15 @@ public class JavaAPISuite implements Serializable { } }); JavaCheckpointTestUtils.attachTestOutputStream(letterCount); - List> initialResult = JavaTestUtils.runStreams(sc, 1, 1); + List> initialResult = JavaTestUtils.runStreams(ssc, 1, 1); assertOrderInvariantEquals(expectedInitial, initialResult); Thread.sleep(1000); - sc.stop(); - sc = new JavaStreamingContext(tempDir.getAbsolutePath()); - sc.start(); - List> finalResult = JavaCheckpointTestUtils.runStreams(sc, 2, 2); + ssc.stop(); + ssc = new JavaStreamingContext(tempDir.getAbsolutePath()); + ssc.start(); + List> finalResult = JavaCheckpointTestUtils.runStreams(ssc, 2, 2); assertOrderInvariantEquals(expectedFinal, finalResult); } @@ -922,7 +946,7 @@ public class JavaAPISuite implements Serializable { Arrays.asList(1,4), Arrays.asList(8,7)); - JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(sc, inputData, 1); + JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override public Integer call(String s) throws Exception { @@ -933,7 +957,7 @@ public class JavaAPISuite implements Serializable { letterCount.checkpoint(new Duration(1000)); - List> result1 = JavaCheckpointTestUtils.runStreams(sc, 3, 3); + List> result1 = JavaCheckpointTestUtils.runStreams(ssc, 3, 3); assertOrderInvariantEquals(expected, result1); } */ @@ -945,15 +969,15 @@ public class JavaAPISuite implements Serializable { public void testKafkaStream() { HashMap topics = Maps.newHashMap(); HashMap offsets = Maps.newHashMap(); - JavaDStream test1 = sc.kafkaStream("localhost", 12345, "group", topics); - JavaDStream test2 = sc.kafkaStream("localhost", 12345, "group", topics, offsets); - JavaDStream test3 = sc.kafkaStream("localhost", 12345, "group", topics, offsets, + JavaDStream test1 = ssc.kafkaStream("localhost", 12345, "group", topics); + JavaDStream test2 = ssc.kafkaStream("localhost", 12345, "group", topics, offsets); + JavaDStream test3 = ssc.kafkaStream("localhost", 12345, "group", topics, offsets, StorageLevel.MEMORY_AND_DISK()); } @Test public void testNetworkTextStream() { - JavaDStream test = sc.networkTextStream("localhost", 12345); + JavaDStream test = ssc.networkTextStream("localhost", 12345); } @Test @@ -973,7 +997,7 @@ public class JavaAPISuite implements Serializable { } } - JavaDStream test = sc.networkStream( + JavaDStream test = ssc.networkStream( "localhost", 12345, new Converter(), @@ -982,22 +1006,22 @@ public class JavaAPISuite implements Serializable { @Test public void testTextFileStream() { - JavaDStream test = sc.textFileStream("/tmp/foo"); + JavaDStream test = ssc.textFileStream("/tmp/foo"); } @Test public void testRawNetworkStream() { - JavaDStream test = sc.rawNetworkStream("localhost", 12345); + JavaDStream test = ssc.rawNetworkStream("localhost", 12345); } @Test public void testFlumeStream() { - JavaDStream test = sc.flumeStream("localhost", 12345); + JavaDStream test = ssc.flumeStream("localhost", 12345); } @Test public void testFileStream() { JavaPairDStream foo = - sc.fileStream("/tmp/foo"); + ssc.fileStream("/tmp/foo"); } } -- cgit v1.2.3 From c46dd2de78ae0c13060d0a9d2dea110c655659f0 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 17 Jan 2013 21:43:17 -0800 Subject: Moving tests to appropriate directory --- streaming/src/test/java/JavaAPISuite.java | 1027 ++++++++++++++++++++ streaming/src/test/java/JavaTestUtils.scala | 65 ++ streaming/src/test/scala/JavaTestUtils.scala | 65 -- .../test/scala/spark/streaming/JavaAPISuite.java | 1027 -------------------- 4 files changed, 1092 insertions(+), 1092 deletions(-) create mode 100644 streaming/src/test/java/JavaAPISuite.java create mode 100644 streaming/src/test/java/JavaTestUtils.scala delete mode 100644 streaming/src/test/scala/JavaTestUtils.scala delete mode 100644 streaming/src/test/scala/spark/streaming/JavaAPISuite.java diff --git a/streaming/src/test/java/JavaAPISuite.java b/streaming/src/test/java/JavaAPISuite.java new file mode 100644 index 0000000000..8c94e13e65 --- /dev/null +++ b/streaming/src/test/java/JavaAPISuite.java @@ -0,0 +1,1027 @@ +package spark.streaming; + +import com.google.common.base.Optional; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.io.Files; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import scala.Tuple2; +import spark.HashPartitioner; +import spark.api.java.JavaRDD; +import spark.api.java.JavaSparkContext; +import spark.api.java.function.*; +import spark.storage.StorageLevel; +import spark.streaming.api.java.JavaDStream; +import spark.streaming.api.java.JavaPairDStream; +import spark.streaming.api.java.JavaStreamingContext; +import spark.streaming.JavaTestUtils; +import spark.streaming.JavaCheckpointTestUtils; +import spark.streaming.dstream.KafkaPartitionKey; + +import java.io.*; +import java.util.*; + +// The test suite itself is Serializable so that anonymous Function implementations can be +// serialized, as an alternative to converting these anonymous classes to static inner classes; +// see http://stackoverflow.com/questions/758570/. +public class JavaAPISuite implements Serializable { + private transient JavaStreamingContext ssc; + + @Before + public void setUp() { + ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port"); + } + + @Test + public void testCount() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3,4), + Arrays.asList(3,4,5), + Arrays.asList(3)); + + List> expected = Arrays.asList( + Arrays.asList(4L), + Arrays.asList(3L), + Arrays.asList(1L)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream count = stream.count(); + JavaTestUtils.attachTestOutputStream(count); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testMap() { + List> inputData = Arrays.asList( + Arrays.asList("hello", "world"), + Arrays.asList("goodnight", "moon")); + + List> expected = Arrays.asList( + Arrays.asList(5,5), + Arrays.asList(9,4)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream letterCount = stream.map(new Function() { + @Override + public Integer call(String s) throws Exception { + return s.length(); + } + }); + JavaTestUtils.attachTestOutputStream(letterCount); + List> result = JavaTestUtils.runStreams(ssc, 2, 2); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testWindow() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9)); + + List> expected = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6,1,2,3), + Arrays.asList(7,8,9,4,5,6), + Arrays.asList(7,8,9)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream windowed = stream.window(new Duration(2000)); + JavaTestUtils.attachTestOutputStream(windowed); + List> result = JavaTestUtils.runStreams(ssc, 4, 4); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testWindowWithSlideDuration() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9), + Arrays.asList(10,11,12), + Arrays.asList(13,14,15), + Arrays.asList(16,17,18)); + + List> expected = Arrays.asList( + Arrays.asList(1,2,3,4,5,6), + Arrays.asList(1,2,3,4,5,6,7,8,9,10,11,12), + Arrays.asList(7,8,9,10,11,12,13,14,15,16,17,18), + Arrays.asList(13,14,15,16,17,18)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream windowed = stream.window(new Duration(4000), new Duration(2000)); + JavaTestUtils.attachTestOutputStream(windowed); + List> result = JavaTestUtils.runStreams(ssc, 8, 4); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testTumble() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9), + Arrays.asList(10,11,12), + Arrays.asList(13,14,15), + Arrays.asList(16,17,18)); + + List> expected = Arrays.asList( + Arrays.asList(1,2,3,4,5,6), + Arrays.asList(7,8,9,10,11,12), + Arrays.asList(13,14,15,16,17,18)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream windowed = stream.tumble(new Duration(2000)); + JavaTestUtils.attachTestOutputStream(windowed); + List> result = JavaTestUtils.runStreams(ssc, 6, 3); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testFilter() { + List> inputData = Arrays.asList( + Arrays.asList("giants", "dodgers"), + Arrays.asList("yankees", "red socks")); + + List> expected = Arrays.asList( + Arrays.asList("giants"), + Arrays.asList("yankees")); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream filtered = stream.filter(new Function() { + @Override + public Boolean call(String s) throws Exception { + return s.contains("a"); + } + }); + JavaTestUtils.attachTestOutputStream(filtered); + List> result = JavaTestUtils.runStreams(ssc, 2, 2); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testGlom() { + List> inputData = Arrays.asList( + Arrays.asList("giants", "dodgers"), + Arrays.asList("yankees", "red socks")); + + List>> expected = Arrays.asList( + Arrays.asList(Arrays.asList("giants", "dodgers")), + Arrays.asList(Arrays.asList("yankees", "red socks"))); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream glommed = stream.glom(); + JavaTestUtils.attachTestOutputStream(glommed); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testMapPartitions() { + List> inputData = Arrays.asList( + Arrays.asList("giants", "dodgers"), + Arrays.asList("yankees", "red socks")); + + List> expected = Arrays.asList( + Arrays.asList("GIANTSDODGERS"), + Arrays.asList("YANKEESRED SOCKS")); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream mapped = stream.mapPartitions(new FlatMapFunction, String>() { + @Override + public Iterable call(Iterator in) { + String out = ""; + while (in.hasNext()) { + out = out + in.next().toUpperCase(); + } + return Lists.newArrayList(out); + } + }); + JavaTestUtils.attachTestOutputStream(mapped); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + private class IntegerSum extends Function2 { + @Override + public Integer call(Integer i1, Integer i2) throws Exception { + return i1 + i2; + } + } + + private class IntegerDifference extends Function2 { + @Override + public Integer call(Integer i1, Integer i2) throws Exception { + return i1 - i2; + } + } + + @Test + public void testReduce() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9)); + + List> expected = Arrays.asList( + Arrays.asList(6), + Arrays.asList(15), + Arrays.asList(24)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream reduced = stream.reduce(new IntegerSum()); + JavaTestUtils.attachTestOutputStream(reduced); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testReduceByWindow() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9)); + + List> expected = Arrays.asList( + Arrays.asList(6), + Arrays.asList(21), + Arrays.asList(39), + Arrays.asList(24)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream reducedWindowed = stream.reduceByWindow(new IntegerSum(), + new IntegerDifference(), new Duration(2000), new Duration(1000)); + JavaTestUtils.attachTestOutputStream(reducedWindowed); + List> result = JavaTestUtils.runStreams(ssc, 4, 4); + + Assert.assertEquals(expected, result); + } + + @Test + public void testQueueStream() { + List> expected = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9)); + + JavaSparkContext jsc = new JavaSparkContext(ssc.ssc().sc()); + JavaRDD rdd1 = ssc.sc().parallelize(Arrays.asList(1,2,3)); + JavaRDD rdd2 = ssc.sc().parallelize(Arrays.asList(4,5,6)); + JavaRDD rdd3 = ssc.sc().parallelize(Arrays.asList(7,8,9)); + + LinkedList> rdds = Lists.newLinkedList(); + rdds.add(rdd1); + rdds.add(rdd2); + rdds.add(rdd3); + + JavaDStream stream = ssc.queueStream(rdds); + JavaTestUtils.attachTestOutputStream(stream); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + Assert.assertEquals(expected, result); + } + + @Test + public void testTransform() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9)); + + List> expected = Arrays.asList( + Arrays.asList(3,4,5), + Arrays.asList(6,7,8), + Arrays.asList(9,10,11)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream transformed = stream.transform(new Function, JavaRDD>() { + @Override + public JavaRDD call(JavaRDD in) throws Exception { + return in.map(new Function() { + @Override + public Integer call(Integer i) throws Exception { + return i + 2; + } + }); + }}); + JavaTestUtils.attachTestOutputStream(transformed); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testFlatMap() { + List> inputData = Arrays.asList( + Arrays.asList("go", "giants"), + Arrays.asList("boo", "dodgers"), + Arrays.asList("athletics")); + + List> expected = Arrays.asList( + Arrays.asList("g","o","g","i","a","n","t","s"), + Arrays.asList("b", "o", "o", "d","o","d","g","e","r","s"), + Arrays.asList("a","t","h","l","e","t","i","c","s")); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream flatMapped = stream.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Lists.newArrayList(x.split("(?!^)")); + } + }); + JavaTestUtils.attachTestOutputStream(flatMapped); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testPairFlatMap() { + List> inputData = Arrays.asList( + Arrays.asList("giants"), + Arrays.asList("dodgers"), + Arrays.asList("athletics")); + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2(6, "g"), + new Tuple2(6, "i"), + new Tuple2(6, "a"), + new Tuple2(6, "n"), + new Tuple2(6, "t"), + new Tuple2(6, "s")), + Arrays.asList( + new Tuple2(7, "d"), + new Tuple2(7, "o"), + new Tuple2(7, "d"), + new Tuple2(7, "g"), + new Tuple2(7, "e"), + new Tuple2(7, "r"), + new Tuple2(7, "s")), + Arrays.asList( + new Tuple2(9, "a"), + new Tuple2(9, "t"), + new Tuple2(9, "h"), + new Tuple2(9, "l"), + new Tuple2(9, "e"), + new Tuple2(9, "t"), + new Tuple2(9, "i"), + new Tuple2(9, "c"), + new Tuple2(9, "s"))); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream flatMapped = stream.flatMap(new PairFlatMapFunction() { + @Override + public Iterable> call(String in) throws Exception { + List> out = Lists.newArrayList(); + for (String letter: in.split("(?!^)")) { + out.add(new Tuple2(in.length(), letter)); + } + return out; + } + }); + JavaTestUtils.attachTestOutputStream(flatMapped); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testUnion() { + List> inputData1 = Arrays.asList( + Arrays.asList(1,1), + Arrays.asList(2,2), + Arrays.asList(3,3)); + + List> inputData2 = Arrays.asList( + Arrays.asList(4,4), + Arrays.asList(5,5), + Arrays.asList(6,6)); + + List> expected = Arrays.asList( + Arrays.asList(1,1,4,4), + Arrays.asList(2,2,5,5), + Arrays.asList(3,3,6,6)); + + JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, inputData1, 2); + JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 2); + + JavaDStream unioned = stream1.union(stream2); + JavaTestUtils.attachTestOutputStream(unioned); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + + assertOrderInvariantEquals(expected, result); + } + + /* + * Performs an order-invariant comparison of lists representing two RDD streams. This allows + * us to account for ordering variation within individual RDD's which occurs during windowing. + */ + public static void assertOrderInvariantEquals( + List> expected, List> actual) { + for (List list: expected) { + Collections.sort(list); + } + for (List list: actual) { + Collections.sort(list); + } + Assert.assertEquals(expected, actual); + } + + + // PairDStream Functions + @Test + public void testPairFilter() { + List> inputData = Arrays.asList( + Arrays.asList("giants", "dodgers"), + Arrays.asList("yankees", "red socks")); + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("giants", 6)), + Arrays.asList(new Tuple2("yankees", 7))); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = stream.map( + new PairFunction() { + @Override + public Tuple2 call(String in) throws Exception { + return new Tuple2(in, in.length()); + } + }); + + JavaPairDStream filtered = pairStream.filter( + new Function, Boolean>() { + @Override + public Boolean call(Tuple2 in) throws Exception { + return in._1().contains("a"); + } + }); + JavaTestUtils.attachTestOutputStream(filtered); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + List>> stringStringKVStream = Arrays.asList( + Arrays.asList(new Tuple2("california", "dodgers"), + new Tuple2("california", "giants"), + new Tuple2("new york", "yankees"), + new Tuple2("new york", "mets")), + Arrays.asList(new Tuple2("california", "sharks"), + new Tuple2("california", "ducks"), + new Tuple2("new york", "rangers"), + new Tuple2("new york", "islanders"))); + + List>> stringIntKVStream = Arrays.asList( + Arrays.asList( + new Tuple2("california", 1), + new Tuple2("california", 3), + new Tuple2("new york", 4), + new Tuple2("new york", 1)), + Arrays.asList( + new Tuple2("california", 5), + new Tuple2("california", 5), + new Tuple2("new york", 3), + new Tuple2("new york", 1))); + + @Test + public void testPairGroupByKey() { + List>> inputData = stringStringKVStream; + + List>>> expected = Arrays.asList( + Arrays.asList( + new Tuple2>("california", Arrays.asList("dodgers", "giants")), + new Tuple2>("new york", Arrays.asList("yankees", "mets"))), + Arrays.asList( + new Tuple2>("california", Arrays.asList("sharks", "ducks")), + new Tuple2>("new york", Arrays.asList("rangers", "islanders")))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream> grouped = pairStream.groupByKey(); + JavaTestUtils.attachTestOutputStream(grouped); + List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairReduceByKey() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList( + new Tuple2("california", 10), + new Tuple2("new york", 4))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream reduced = pairStream.reduceByKey(new IntegerSum()); + + JavaTestUtils.attachTestOutputStream(reduced); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testCombineByKey() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList( + new Tuple2("california", 10), + new Tuple2("new york", 4))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream combined = pairStream.combineByKey( + new Function() { + @Override + public Integer call(Integer i) throws Exception { + return i; + } + }, new IntegerSum(), new IntegerSum(), new HashPartitioner(2)); + + JavaTestUtils.attachTestOutputStream(combined); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testCountByKey() { + List>> inputData = stringStringKVStream; + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2("california", 2L), + new Tuple2("new york", 2L)), + Arrays.asList( + new Tuple2("california", 2L), + new Tuple2("new york", 2L))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream counted = pairStream.countByKey(); + JavaTestUtils.attachTestOutputStream(counted); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testGroupByKeyAndWindow() { + List>> inputData = stringStringKVStream; + + List>>> expected = Arrays.asList( + Arrays.asList(new Tuple2>("california", Arrays.asList("dodgers", "giants")), + new Tuple2>("new york", Arrays.asList("yankees", "mets"))), + Arrays.asList(new Tuple2>("california", + Arrays.asList("sharks", "ducks", "dodgers", "giants")), + new Tuple2>("new york", Arrays.asList("rangers", "islanders", "yankees", "mets"))), + Arrays.asList(new Tuple2>("california", Arrays.asList("sharks", "ducks")), + new Tuple2>("new york", Arrays.asList("rangers", "islanders")))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream> groupWindowed = + pairStream.groupByKeyAndWindow(new Duration(2000), new Duration(1000)); + JavaTestUtils.attachTestOutputStream(groupWindowed); + List>>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testReduceByKeyAndWindow() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList(new Tuple2("california", 14), + new Tuple2("new york", 9)), + Arrays.asList(new Tuple2("california", 10), + new Tuple2("new york", 4))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream reduceWindowed = + pairStream.reduceByKeyAndWindow(new IntegerSum(), new Duration(2000), new Duration(1000)); + JavaTestUtils.attachTestOutputStream(reduceWindowed); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testUpdateStateByKey() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList(new Tuple2("california", 14), + new Tuple2("new york", 9)), + Arrays.asList(new Tuple2("california", 14), + new Tuple2("new york", 9))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream updated = pairStream.updateStateByKey( + new Function2, Optional, Optional>(){ + @Override + public Optional call(List values, Optional state) { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v: values) { + out = out + v; + } + return Optional.of(out); + } + }); + JavaTestUtils.attachTestOutputStream(updated); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testReduceByKeyAndWindowWithInverse() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList(new Tuple2("california", 14), + new Tuple2("new york", 9)), + Arrays.asList(new Tuple2("california", 10), + new Tuple2("new york", 4))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream reduceWindowed = + pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), new Duration(2000), new Duration(1000)); + JavaTestUtils.attachTestOutputStream(reduceWindowed); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testCountByKeyAndWindow() { + List>> inputData = stringStringKVStream; + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2("california", 2L), + new Tuple2("new york", 2L)), + Arrays.asList( + new Tuple2("california", 4L), + new Tuple2("new york", 4L)), + Arrays.asList( + new Tuple2("california", 2L), + new Tuple2("new york", 2L))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream counted = + pairStream.countByKeyAndWindow(new Duration(2000), new Duration(1000)); + JavaTestUtils.attachTestOutputStream(counted); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testMapValues() { + List>> inputData = stringStringKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", "DODGERS"), + new Tuple2("california", "GIANTS"), + new Tuple2("new york", "YANKEES"), + new Tuple2("new york", "METS")), + Arrays.asList(new Tuple2("california", "SHARKS"), + new Tuple2("california", "DUCKS"), + new Tuple2("new york", "RANGERS"), + new Tuple2("new york", "ISLANDERS"))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream mapped = pairStream.mapValues(new Function() { + @Override + public String call(String s) throws Exception { + return s.toUpperCase(); + } + }); + + JavaTestUtils.attachTestOutputStream(mapped); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testFlatMapValues() { + List>> inputData = stringStringKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", "dodgers1"), + new Tuple2("california", "dodgers2"), + new Tuple2("california", "giants1"), + new Tuple2("california", "giants2"), + new Tuple2("new york", "yankees1"), + new Tuple2("new york", "yankees2"), + new Tuple2("new york", "mets1"), + new Tuple2("new york", "mets2")), + Arrays.asList(new Tuple2("california", "sharks1"), + new Tuple2("california", "sharks2"), + new Tuple2("california", "ducks1"), + new Tuple2("california", "ducks2"), + new Tuple2("new york", "rangers1"), + new Tuple2("new york", "rangers2"), + new Tuple2("new york", "islanders1"), + new Tuple2("new york", "islanders2"))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + + JavaPairDStream flatMapped = pairStream.flatMapValues( + new Function>() { + @Override + public Iterable call(String in) { + List out = new ArrayList(); + out.add(in + "1"); + out.add(in + "2"); + return out; + } + }); + + JavaTestUtils.attachTestOutputStream(flatMapped); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testCoGroup() { + List>> stringStringKVStream1 = Arrays.asList( + Arrays.asList(new Tuple2("california", "dodgers"), + new Tuple2("new york", "yankees")), + Arrays.asList(new Tuple2("california", "sharks"), + new Tuple2("new york", "rangers"))); + + List>> stringStringKVStream2 = Arrays.asList( + Arrays.asList(new Tuple2("california", "giants"), + new Tuple2("new york", "mets")), + Arrays.asList(new Tuple2("california", "ducks"), + new Tuple2("new york", "islanders"))); + + + List, List>>>> expected = Arrays.asList( + Arrays.asList( + new Tuple2, List>>("california", + new Tuple2, List>(Arrays.asList("dodgers"), Arrays.asList("giants"))), + new Tuple2, List>>("new york", + new Tuple2, List>(Arrays.asList("yankees"), Arrays.asList("mets")))), + Arrays.asList( + new Tuple2, List>>("california", + new Tuple2, List>(Arrays.asList("sharks"), Arrays.asList("ducks"))), + new Tuple2, List>>("new york", + new Tuple2, List>(Arrays.asList("rangers"), Arrays.asList("islanders"))))); + + + JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( + ssc, stringStringKVStream1, 1); + JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream(stream1); + + JavaDStream> stream2 = JavaTestUtils.attachTestInputStream( + ssc, stringStringKVStream2, 1); + JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); + + JavaPairDStream, List>> grouped = pairStream1.cogroup(pairStream2); + JavaTestUtils.attachTestOutputStream(grouped); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testJoin() { + List>> stringStringKVStream1 = Arrays.asList( + Arrays.asList(new Tuple2("california", "dodgers"), + new Tuple2("new york", "yankees")), + Arrays.asList(new Tuple2("california", "sharks"), + new Tuple2("new york", "rangers"))); + + List>> stringStringKVStream2 = Arrays.asList( + Arrays.asList(new Tuple2("california", "giants"), + new Tuple2("new york", "mets")), + Arrays.asList(new Tuple2("california", "ducks"), + new Tuple2("new york", "islanders"))); + + + List>>> expected = Arrays.asList( + Arrays.asList( + new Tuple2>("california", + new Tuple2("dodgers", "giants")), + new Tuple2>("new york", + new Tuple2("yankees", "mets"))), + Arrays.asList( + new Tuple2>("california", + new Tuple2("sharks", "ducks")), + new Tuple2>("new york", + new Tuple2("rangers", "islanders")))); + + + JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( + ssc, stringStringKVStream1, 1); + JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream(stream1); + + JavaDStream> stream2 = JavaTestUtils.attachTestInputStream( + ssc, stringStringKVStream2, 1); + JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); + + JavaPairDStream> joined = pairStream1.join(pairStream2); + JavaTestUtils.attachTestOutputStream(joined); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testCheckpointMasterRecovery() throws InterruptedException { + List> inputData = Arrays.asList( + Arrays.asList("this", "is"), + Arrays.asList("a", "test"), + Arrays.asList("counting", "letters")); + + List> expectedInitial = Arrays.asList( + Arrays.asList(4,2)); + List> expectedFinal = Arrays.asList( + Arrays.asList(1,4), + Arrays.asList(8,7)); + + + File tempDir = Files.createTempDir(); + ssc.checkpoint(tempDir.getAbsolutePath(), new Duration(1000)); + + JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream letterCount = stream.map(new Function() { + @Override + public Integer call(String s) throws Exception { + return s.length(); + } + }); + JavaCheckpointTestUtils.attachTestOutputStream(letterCount); + List> initialResult = JavaTestUtils.runStreams(ssc, 1, 1); + + assertOrderInvariantEquals(expectedInitial, initialResult); + Thread.sleep(1000); + + ssc.stop(); + ssc = new JavaStreamingContext(tempDir.getAbsolutePath()); + ssc.start(); + List> finalResult = JavaCheckpointTestUtils.runStreams(ssc, 2, 2); + assertOrderInvariantEquals(expectedFinal, finalResult); + } + + /** TEST DISABLED: Pending a discussion about checkpoint() semantics with TD + @Test + public void testCheckpointofIndividualStream() throws InterruptedException { + List> inputData = Arrays.asList( + Arrays.asList("this", "is"), + Arrays.asList("a", "test"), + Arrays.asList("counting", "letters")); + + List> expected = Arrays.asList( + Arrays.asList(4,2), + Arrays.asList(1,4), + Arrays.asList(8,7)); + + JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream letterCount = stream.map(new Function() { + @Override + public Integer call(String s) throws Exception { + return s.length(); + } + }); + JavaCheckpointTestUtils.attachTestOutputStream(letterCount); + + letterCount.checkpoint(new Duration(1000)); + + List> result1 = JavaCheckpointTestUtils.runStreams(ssc, 3, 3); + assertOrderInvariantEquals(expected, result1); + } + */ + + // Input stream tests. These mostly just test that we can instantiate a given InputStream with + // Java arguments and assign it to a JavaDStream without producing type errors. Testing of the + // InputStream functionality is deferred to the existing Scala tests. + @Test + public void testKafkaStream() { + HashMap topics = Maps.newHashMap(); + HashMap offsets = Maps.newHashMap(); + JavaDStream test1 = ssc.kafkaStream("localhost", 12345, "group", topics); + JavaDStream test2 = ssc.kafkaStream("localhost", 12345, "group", topics, offsets); + JavaDStream test3 = ssc.kafkaStream("localhost", 12345, "group", topics, offsets, + StorageLevel.MEMORY_AND_DISK()); + } + + @Test + public void testNetworkTextStream() { + JavaDStream test = ssc.networkTextStream("localhost", 12345); + } + + @Test + public void testNetworkString() { + class Converter extends Function> { + public Iterable call(InputStream in) { + BufferedReader reader = new BufferedReader(new InputStreamReader(in)); + List out = new ArrayList(); + try { + while (true) { + String line = reader.readLine(); + if (line == null) { break; } + out.add(line); + } + } catch (IOException e) { } + return out; + } + } + + JavaDStream test = ssc.networkStream( + "localhost", + 12345, + new Converter(), + StorageLevel.MEMORY_ONLY()); + } + + @Test + public void testTextFileStream() { + JavaDStream test = ssc.textFileStream("/tmp/foo"); + } + + @Test + public void testRawNetworkStream() { + JavaDStream test = ssc.rawNetworkStream("localhost", 12345); + } + + @Test + public void testFlumeStream() { + JavaDStream test = ssc.flumeStream("localhost", 12345); + } + + @Test + public void testFileStream() { + JavaPairDStream foo = + ssc.fileStream("/tmp/foo"); + } +} diff --git a/streaming/src/test/java/JavaTestUtils.scala b/streaming/src/test/java/JavaTestUtils.scala new file mode 100644 index 0000000000..56349837e5 --- /dev/null +++ b/streaming/src/test/java/JavaTestUtils.scala @@ -0,0 +1,65 @@ +package spark.streaming + +import collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import java.util.{List => JList} +import spark.streaming.api.java.{JavaPairDStream, JavaDStreamLike, JavaDStream, JavaStreamingContext} +import spark.streaming._ +import java.util.ArrayList +import collection.JavaConversions._ + +/** Exposes streaming test functionality in a Java-friendly way. */ +trait JavaTestBase extends TestSuiteBase { + + /** + * Create a [[spark.streaming.TestInputStream]] and attach it to the supplied context. + * The stream will be derived from the supplied lists of Java objects. + **/ + def attachTestInputStream[T]( + ssc: JavaStreamingContext, + data: JList[JList[T]], + numPartitions: Int) = { + val seqData = data.map(Seq(_:_*)) + + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + val dstream = new TestInputStream[T](ssc.ssc, seqData, numPartitions) + ssc.ssc.registerInputStream(dstream) + new JavaDStream[T](dstream) + } + + /** + * Attach a provided stream to it's associated StreamingContext as a + * [[spark.streaming.TestOutputStream]]. + **/ + def attachTestOutputStream[T, This <: spark.streaming.api.java.JavaDStreamLike[T,This]]( + dstream: JavaDStreamLike[T, This]) = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + val ostream = new TestOutputStream(dstream.dstream, + new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]]) + dstream.dstream.ssc.registerOutputStream(ostream) + } + + /** + * Process all registered streams for a numBatches batches, failing if + * numExpectedOutput RDD's are not generated. Generated RDD's are collected + * and returned, represented as a list for each batch interval. + */ + def runStreams[V]( + ssc: JavaStreamingContext, numBatches: Int, numExpectedOutput: Int): JList[JList[V]] = { + implicit val cm: ClassManifest[V] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]] + val res = runStreams[V](ssc.ssc, numBatches, numExpectedOutput) + val out = new ArrayList[JList[V]]() + res.map(entry => out.append(new ArrayList[V](entry))) + out + } +} + +object JavaTestUtils extends JavaTestBase { + +} + +object JavaCheckpointTestUtils extends JavaTestBase { + override def actuallyWait = true +} \ No newline at end of file diff --git a/streaming/src/test/scala/JavaTestUtils.scala b/streaming/src/test/scala/JavaTestUtils.scala deleted file mode 100644 index 56349837e5..0000000000 --- a/streaming/src/test/scala/JavaTestUtils.scala +++ /dev/null @@ -1,65 +0,0 @@ -package spark.streaming - -import collection.mutable.{SynchronizedBuffer, ArrayBuffer} -import java.util.{List => JList} -import spark.streaming.api.java.{JavaPairDStream, JavaDStreamLike, JavaDStream, JavaStreamingContext} -import spark.streaming._ -import java.util.ArrayList -import collection.JavaConversions._ - -/** Exposes streaming test functionality in a Java-friendly way. */ -trait JavaTestBase extends TestSuiteBase { - - /** - * Create a [[spark.streaming.TestInputStream]] and attach it to the supplied context. - * The stream will be derived from the supplied lists of Java objects. - **/ - def attachTestInputStream[T]( - ssc: JavaStreamingContext, - data: JList[JList[T]], - numPartitions: Int) = { - val seqData = data.map(Seq(_:_*)) - - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - val dstream = new TestInputStream[T](ssc.ssc, seqData, numPartitions) - ssc.ssc.registerInputStream(dstream) - new JavaDStream[T](dstream) - } - - /** - * Attach a provided stream to it's associated StreamingContext as a - * [[spark.streaming.TestOutputStream]]. - **/ - def attachTestOutputStream[T, This <: spark.streaming.api.java.JavaDStreamLike[T,This]]( - dstream: JavaDStreamLike[T, This]) = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - val ostream = new TestOutputStream(dstream.dstream, - new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]]) - dstream.dstream.ssc.registerOutputStream(ostream) - } - - /** - * Process all registered streams for a numBatches batches, failing if - * numExpectedOutput RDD's are not generated. Generated RDD's are collected - * and returned, represented as a list for each batch interval. - */ - def runStreams[V]( - ssc: JavaStreamingContext, numBatches: Int, numExpectedOutput: Int): JList[JList[V]] = { - implicit val cm: ClassManifest[V] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]] - val res = runStreams[V](ssc.ssc, numBatches, numExpectedOutput) - val out = new ArrayList[JList[V]]() - res.map(entry => out.append(new ArrayList[V](entry))) - out - } -} - -object JavaTestUtils extends JavaTestBase { - -} - -object JavaCheckpointTestUtils extends JavaTestBase { - override def actuallyWait = true -} \ No newline at end of file diff --git a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java deleted file mode 100644 index 8c94e13e65..0000000000 --- a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java +++ /dev/null @@ -1,1027 +0,0 @@ -package spark.streaming; - -import com.google.common.base.Optional; -import com.google.common.collect.Lists; -import com.google.common.collect.Maps; -import com.google.common.io.Files; -import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import scala.Tuple2; -import spark.HashPartitioner; -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.*; -import spark.storage.StorageLevel; -import spark.streaming.api.java.JavaDStream; -import spark.streaming.api.java.JavaPairDStream; -import spark.streaming.api.java.JavaStreamingContext; -import spark.streaming.JavaTestUtils; -import spark.streaming.JavaCheckpointTestUtils; -import spark.streaming.dstream.KafkaPartitionKey; - -import java.io.*; -import java.util.*; - -// The test suite itself is Serializable so that anonymous Function implementations can be -// serialized, as an alternative to converting these anonymous classes to static inner classes; -// see http://stackoverflow.com/questions/758570/. -public class JavaAPISuite implements Serializable { - private transient JavaStreamingContext ssc; - - @Before - public void setUp() { - ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); - } - - @After - public void tearDown() { - ssc.stop(); - ssc = null; - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port"); - } - - @Test - public void testCount() { - List> inputData = Arrays.asList( - Arrays.asList(1,2,3,4), - Arrays.asList(3,4,5), - Arrays.asList(3)); - - List> expected = Arrays.asList( - Arrays.asList(4L), - Arrays.asList(3L), - Arrays.asList(1L)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream count = stream.count(); - JavaTestUtils.attachTestOutputStream(count); - List> result = JavaTestUtils.runStreams(ssc, 3, 3); - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testMap() { - List> inputData = Arrays.asList( - Arrays.asList("hello", "world"), - Arrays.asList("goodnight", "moon")); - - List> expected = Arrays.asList( - Arrays.asList(5,5), - Arrays.asList(9,4)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream letterCount = stream.map(new Function() { - @Override - public Integer call(String s) throws Exception { - return s.length(); - } - }); - JavaTestUtils.attachTestOutputStream(letterCount); - List> result = JavaTestUtils.runStreams(ssc, 2, 2); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testWindow() { - List> inputData = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6), - Arrays.asList(7,8,9)); - - List> expected = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6,1,2,3), - Arrays.asList(7,8,9,4,5,6), - Arrays.asList(7,8,9)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream windowed = stream.window(new Duration(2000)); - JavaTestUtils.attachTestOutputStream(windowed); - List> result = JavaTestUtils.runStreams(ssc, 4, 4); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testWindowWithSlideDuration() { - List> inputData = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6), - Arrays.asList(7,8,9), - Arrays.asList(10,11,12), - Arrays.asList(13,14,15), - Arrays.asList(16,17,18)); - - List> expected = Arrays.asList( - Arrays.asList(1,2,3,4,5,6), - Arrays.asList(1,2,3,4,5,6,7,8,9,10,11,12), - Arrays.asList(7,8,9,10,11,12,13,14,15,16,17,18), - Arrays.asList(13,14,15,16,17,18)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream windowed = stream.window(new Duration(4000), new Duration(2000)); - JavaTestUtils.attachTestOutputStream(windowed); - List> result = JavaTestUtils.runStreams(ssc, 8, 4); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testTumble() { - List> inputData = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6), - Arrays.asList(7,8,9), - Arrays.asList(10,11,12), - Arrays.asList(13,14,15), - Arrays.asList(16,17,18)); - - List> expected = Arrays.asList( - Arrays.asList(1,2,3,4,5,6), - Arrays.asList(7,8,9,10,11,12), - Arrays.asList(13,14,15,16,17,18)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream windowed = stream.tumble(new Duration(2000)); - JavaTestUtils.attachTestOutputStream(windowed); - List> result = JavaTestUtils.runStreams(ssc, 6, 3); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testFilter() { - List> inputData = Arrays.asList( - Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); - - List> expected = Arrays.asList( - Arrays.asList("giants"), - Arrays.asList("yankees")); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream filtered = stream.filter(new Function() { - @Override - public Boolean call(String s) throws Exception { - return s.contains("a"); - } - }); - JavaTestUtils.attachTestOutputStream(filtered); - List> result = JavaTestUtils.runStreams(ssc, 2, 2); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testGlom() { - List> inputData = Arrays.asList( - Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); - - List>> expected = Arrays.asList( - Arrays.asList(Arrays.asList("giants", "dodgers")), - Arrays.asList(Arrays.asList("yankees", "red socks"))); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream glommed = stream.glom(); - JavaTestUtils.attachTestOutputStream(glommed); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testMapPartitions() { - List> inputData = Arrays.asList( - Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); - - List> expected = Arrays.asList( - Arrays.asList("GIANTSDODGERS"), - Arrays.asList("YANKEESRED SOCKS")); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream mapped = stream.mapPartitions(new FlatMapFunction, String>() { - @Override - public Iterable call(Iterator in) { - String out = ""; - while (in.hasNext()) { - out = out + in.next().toUpperCase(); - } - return Lists.newArrayList(out); - } - }); - JavaTestUtils.attachTestOutputStream(mapped); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - private class IntegerSum extends Function2 { - @Override - public Integer call(Integer i1, Integer i2) throws Exception { - return i1 + i2; - } - } - - private class IntegerDifference extends Function2 { - @Override - public Integer call(Integer i1, Integer i2) throws Exception { - return i1 - i2; - } - } - - @Test - public void testReduce() { - List> inputData = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6), - Arrays.asList(7,8,9)); - - List> expected = Arrays.asList( - Arrays.asList(6), - Arrays.asList(15), - Arrays.asList(24)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream reduced = stream.reduce(new IntegerSum()); - JavaTestUtils.attachTestOutputStream(reduced); - List> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testReduceByWindow() { - List> inputData = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6), - Arrays.asList(7,8,9)); - - List> expected = Arrays.asList( - Arrays.asList(6), - Arrays.asList(21), - Arrays.asList(39), - Arrays.asList(24)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream reducedWindowed = stream.reduceByWindow(new IntegerSum(), - new IntegerDifference(), new Duration(2000), new Duration(1000)); - JavaTestUtils.attachTestOutputStream(reducedWindowed); - List> result = JavaTestUtils.runStreams(ssc, 4, 4); - - Assert.assertEquals(expected, result); - } - - @Test - public void testQueueStream() { - List> expected = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6), - Arrays.asList(7,8,9)); - - JavaSparkContext jsc = new JavaSparkContext(ssc.ssc().sc()); - JavaRDD rdd1 = ssc.sc().parallelize(Arrays.asList(1,2,3)); - JavaRDD rdd2 = ssc.sc().parallelize(Arrays.asList(4,5,6)); - JavaRDD rdd3 = ssc.sc().parallelize(Arrays.asList(7,8,9)); - - LinkedList> rdds = Lists.newLinkedList(); - rdds.add(rdd1); - rdds.add(rdd2); - rdds.add(rdd3); - - JavaDStream stream = ssc.queueStream(rdds); - JavaTestUtils.attachTestOutputStream(stream); - List> result = JavaTestUtils.runStreams(ssc, 3, 3); - Assert.assertEquals(expected, result); - } - - @Test - public void testTransform() { - List> inputData = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6), - Arrays.asList(7,8,9)); - - List> expected = Arrays.asList( - Arrays.asList(3,4,5), - Arrays.asList(6,7,8), - Arrays.asList(9,10,11)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream transformed = stream.transform(new Function, JavaRDD>() { - @Override - public JavaRDD call(JavaRDD in) throws Exception { - return in.map(new Function() { - @Override - public Integer call(Integer i) throws Exception { - return i + 2; - } - }); - }}); - JavaTestUtils.attachTestOutputStream(transformed); - List> result = JavaTestUtils.runStreams(ssc, 3, 3); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testFlatMap() { - List> inputData = Arrays.asList( - Arrays.asList("go", "giants"), - Arrays.asList("boo", "dodgers"), - Arrays.asList("athletics")); - - List> expected = Arrays.asList( - Arrays.asList("g","o","g","i","a","n","t","s"), - Arrays.asList("b", "o", "o", "d","o","d","g","e","r","s"), - Arrays.asList("a","t","h","l","e","t","i","c","s")); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream flatMapped = stream.flatMap(new FlatMapFunction() { - @Override - public Iterable call(String x) { - return Lists.newArrayList(x.split("(?!^)")); - } - }); - JavaTestUtils.attachTestOutputStream(flatMapped); - List> result = JavaTestUtils.runStreams(ssc, 3, 3); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testPairFlatMap() { - List> inputData = Arrays.asList( - Arrays.asList("giants"), - Arrays.asList("dodgers"), - Arrays.asList("athletics")); - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2(6, "g"), - new Tuple2(6, "i"), - new Tuple2(6, "a"), - new Tuple2(6, "n"), - new Tuple2(6, "t"), - new Tuple2(6, "s")), - Arrays.asList( - new Tuple2(7, "d"), - new Tuple2(7, "o"), - new Tuple2(7, "d"), - new Tuple2(7, "g"), - new Tuple2(7, "e"), - new Tuple2(7, "r"), - new Tuple2(7, "s")), - Arrays.asList( - new Tuple2(9, "a"), - new Tuple2(9, "t"), - new Tuple2(9, "h"), - new Tuple2(9, "l"), - new Tuple2(9, "e"), - new Tuple2(9, "t"), - new Tuple2(9, "i"), - new Tuple2(9, "c"), - new Tuple2(9, "s"))); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream flatMapped = stream.flatMap(new PairFlatMapFunction() { - @Override - public Iterable> call(String in) throws Exception { - List> out = Lists.newArrayList(); - for (String letter: in.split("(?!^)")) { - out.add(new Tuple2(in.length(), letter)); - } - return out; - } - }); - JavaTestUtils.attachTestOutputStream(flatMapped); - List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testUnion() { - List> inputData1 = Arrays.asList( - Arrays.asList(1,1), - Arrays.asList(2,2), - Arrays.asList(3,3)); - - List> inputData2 = Arrays.asList( - Arrays.asList(4,4), - Arrays.asList(5,5), - Arrays.asList(6,6)); - - List> expected = Arrays.asList( - Arrays.asList(1,1,4,4), - Arrays.asList(2,2,5,5), - Arrays.asList(3,3,6,6)); - - JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, inputData1, 2); - JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 2); - - JavaDStream unioned = stream1.union(stream2); - JavaTestUtils.attachTestOutputStream(unioned); - List> result = JavaTestUtils.runStreams(ssc, 3, 3); - - assertOrderInvariantEquals(expected, result); - } - - /* - * Performs an order-invariant comparison of lists representing two RDD streams. This allows - * us to account for ordering variation within individual RDD's which occurs during windowing. - */ - public static void assertOrderInvariantEquals( - List> expected, List> actual) { - for (List list: expected) { - Collections.sort(list); - } - for (List list: actual) { - Collections.sort(list); - } - Assert.assertEquals(expected, actual); - } - - - // PairDStream Functions - @Test - public void testPairFilter() { - List> inputData = Arrays.asList( - Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("giants", 6)), - Arrays.asList(new Tuple2("yankees", 7))); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = stream.map( - new PairFunction() { - @Override - public Tuple2 call(String in) throws Exception { - return new Tuple2(in, in.length()); - } - }); - - JavaPairDStream filtered = pairStream.filter( - new Function, Boolean>() { - @Override - public Boolean call(Tuple2 in) throws Exception { - return in._1().contains("a"); - } - }); - JavaTestUtils.attachTestOutputStream(filtered); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - List>> stringStringKVStream = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("california", "giants"), - new Tuple2("new york", "yankees"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("california", "ducks"), - new Tuple2("new york", "rangers"), - new Tuple2("new york", "islanders"))); - - List>> stringIntKVStream = Arrays.asList( - Arrays.asList( - new Tuple2("california", 1), - new Tuple2("california", 3), - new Tuple2("new york", 4), - new Tuple2("new york", 1)), - Arrays.asList( - new Tuple2("california", 5), - new Tuple2("california", 5), - new Tuple2("new york", 3), - new Tuple2("new york", 1))); - - @Test - public void testPairGroupByKey() { - List>> inputData = stringStringKVStream; - - List>>> expected = Arrays.asList( - Arrays.asList( - new Tuple2>("california", Arrays.asList("dodgers", "giants")), - new Tuple2>("new york", Arrays.asList("yankees", "mets"))), - Arrays.asList( - new Tuple2>("california", Arrays.asList("sharks", "ducks")), - new Tuple2>("new york", Arrays.asList("rangers", "islanders")))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream> grouped = pairStream.groupByKey(); - JavaTestUtils.attachTestOutputStream(grouped); - List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testPairReduceByKey() { - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList( - new Tuple2("california", 10), - new Tuple2("new york", 4))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream reduced = pairStream.reduceByKey(new IntegerSum()); - - JavaTestUtils.attachTestOutputStream(reduced); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testCombineByKey() { - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList( - new Tuple2("california", 10), - new Tuple2("new york", 4))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream combined = pairStream.combineByKey( - new Function() { - @Override - public Integer call(Integer i) throws Exception { - return i; - } - }, new IntegerSum(), new IntegerSum(), new HashPartitioner(2)); - - JavaTestUtils.attachTestOutputStream(combined); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testCountByKey() { - List>> inputData = stringStringKVStream; - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2("california", 2L), - new Tuple2("new york", 2L)), - Arrays.asList( - new Tuple2("california", 2L), - new Tuple2("new york", 2L))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream counted = pairStream.countByKey(); - JavaTestUtils.attachTestOutputStream(counted); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testGroupByKeyAndWindow() { - List>> inputData = stringStringKVStream; - - List>>> expected = Arrays.asList( - Arrays.asList(new Tuple2>("california", Arrays.asList("dodgers", "giants")), - new Tuple2>("new york", Arrays.asList("yankees", "mets"))), - Arrays.asList(new Tuple2>("california", - Arrays.asList("sharks", "ducks", "dodgers", "giants")), - new Tuple2>("new york", Arrays.asList("rangers", "islanders", "yankees", "mets"))), - Arrays.asList(new Tuple2>("california", Arrays.asList("sharks", "ducks")), - new Tuple2>("new york", Arrays.asList("rangers", "islanders")))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream> groupWindowed = - pairStream.groupByKeyAndWindow(new Duration(2000), new Duration(1000)); - JavaTestUtils.attachTestOutputStream(groupWindowed); - List>>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testReduceByKeyAndWindow() { - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 10), - new Tuple2("new york", 4))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream reduceWindowed = - pairStream.reduceByKeyAndWindow(new IntegerSum(), new Duration(2000), new Duration(1000)); - JavaTestUtils.attachTestOutputStream(reduceWindowed); - List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testUpdateStateByKey() { - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream updated = pairStream.updateStateByKey( - new Function2, Optional, Optional>(){ - @Override - public Optional call(List values, Optional state) { - int out = 0; - if (state.isPresent()) { - out = out + state.get(); - } - for (Integer v: values) { - out = out + v; - } - return Optional.of(out); - } - }); - JavaTestUtils.attachTestOutputStream(updated); - List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testReduceByKeyAndWindowWithInverse() { - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 10), - new Tuple2("new york", 4))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream reduceWindowed = - pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), new Duration(2000), new Duration(1000)); - JavaTestUtils.attachTestOutputStream(reduceWindowed); - List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testCountByKeyAndWindow() { - List>> inputData = stringStringKVStream; - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2("california", 2L), - new Tuple2("new york", 2L)), - Arrays.asList( - new Tuple2("california", 4L), - new Tuple2("new york", 4L)), - Arrays.asList( - new Tuple2("california", 2L), - new Tuple2("new york", 2L))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream counted = - pairStream.countByKeyAndWindow(new Duration(2000), new Duration(1000)); - JavaTestUtils.attachTestOutputStream(counted); - List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testMapValues() { - List>> inputData = stringStringKVStream; - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", "DODGERS"), - new Tuple2("california", "GIANTS"), - new Tuple2("new york", "YANKEES"), - new Tuple2("new york", "METS")), - Arrays.asList(new Tuple2("california", "SHARKS"), - new Tuple2("california", "DUCKS"), - new Tuple2("new york", "RANGERS"), - new Tuple2("new york", "ISLANDERS"))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream mapped = pairStream.mapValues(new Function() { - @Override - public String call(String s) throws Exception { - return s.toUpperCase(); - } - }); - - JavaTestUtils.attachTestOutputStream(mapped); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testFlatMapValues() { - List>> inputData = stringStringKVStream; - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers1"), - new Tuple2("california", "dodgers2"), - new Tuple2("california", "giants1"), - new Tuple2("california", "giants2"), - new Tuple2("new york", "yankees1"), - new Tuple2("new york", "yankees2"), - new Tuple2("new york", "mets1"), - new Tuple2("new york", "mets2")), - Arrays.asList(new Tuple2("california", "sharks1"), - new Tuple2("california", "sharks2"), - new Tuple2("california", "ducks1"), - new Tuple2("california", "ducks2"), - new Tuple2("new york", "rangers1"), - new Tuple2("new york", "rangers2"), - new Tuple2("new york", "islanders1"), - new Tuple2("new york", "islanders2"))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - - JavaPairDStream flatMapped = pairStream.flatMapValues( - new Function>() { - @Override - public Iterable call(String in) { - List out = new ArrayList(); - out.add(in + "1"); - out.add(in + "2"); - return out; - } - }); - - JavaTestUtils.attachTestOutputStream(flatMapped); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testCoGroup() { - List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); - - List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); - - - List, List>>>> expected = Arrays.asList( - Arrays.asList( - new Tuple2, List>>("california", - new Tuple2, List>(Arrays.asList("dodgers"), Arrays.asList("giants"))), - new Tuple2, List>>("new york", - new Tuple2, List>(Arrays.asList("yankees"), Arrays.asList("mets")))), - Arrays.asList( - new Tuple2, List>>("california", - new Tuple2, List>(Arrays.asList("sharks"), Arrays.asList("ducks"))), - new Tuple2, List>>("new york", - new Tuple2, List>(Arrays.asList("rangers"), Arrays.asList("islanders"))))); - - - JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( - ssc, stringStringKVStream1, 1); - JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream(stream1); - - JavaDStream> stream2 = JavaTestUtils.attachTestInputStream( - ssc, stringStringKVStream2, 1); - JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); - - JavaPairDStream, List>> grouped = pairStream1.cogroup(pairStream2); - JavaTestUtils.attachTestOutputStream(grouped); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testJoin() { - List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); - - List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); - - - List>>> expected = Arrays.asList( - Arrays.asList( - new Tuple2>("california", - new Tuple2("dodgers", "giants")), - new Tuple2>("new york", - new Tuple2("yankees", "mets"))), - Arrays.asList( - new Tuple2>("california", - new Tuple2("sharks", "ducks")), - new Tuple2>("new york", - new Tuple2("rangers", "islanders")))); - - - JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( - ssc, stringStringKVStream1, 1); - JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream(stream1); - - JavaDStream> stream2 = JavaTestUtils.attachTestInputStream( - ssc, stringStringKVStream2, 1); - JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); - - JavaPairDStream> joined = pairStream1.join(pairStream2); - JavaTestUtils.attachTestOutputStream(joined); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testCheckpointMasterRecovery() throws InterruptedException { - List> inputData = Arrays.asList( - Arrays.asList("this", "is"), - Arrays.asList("a", "test"), - Arrays.asList("counting", "letters")); - - List> expectedInitial = Arrays.asList( - Arrays.asList(4,2)); - List> expectedFinal = Arrays.asList( - Arrays.asList(1,4), - Arrays.asList(8,7)); - - - File tempDir = Files.createTempDir(); - ssc.checkpoint(tempDir.getAbsolutePath(), new Duration(1000)); - - JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream letterCount = stream.map(new Function() { - @Override - public Integer call(String s) throws Exception { - return s.length(); - } - }); - JavaCheckpointTestUtils.attachTestOutputStream(letterCount); - List> initialResult = JavaTestUtils.runStreams(ssc, 1, 1); - - assertOrderInvariantEquals(expectedInitial, initialResult); - Thread.sleep(1000); - - ssc.stop(); - ssc = new JavaStreamingContext(tempDir.getAbsolutePath()); - ssc.start(); - List> finalResult = JavaCheckpointTestUtils.runStreams(ssc, 2, 2); - assertOrderInvariantEquals(expectedFinal, finalResult); - } - - /** TEST DISABLED: Pending a discussion about checkpoint() semantics with TD - @Test - public void testCheckpointofIndividualStream() throws InterruptedException { - List> inputData = Arrays.asList( - Arrays.asList("this", "is"), - Arrays.asList("a", "test"), - Arrays.asList("counting", "letters")); - - List> expected = Arrays.asList( - Arrays.asList(4,2), - Arrays.asList(1,4), - Arrays.asList(8,7)); - - JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream letterCount = stream.map(new Function() { - @Override - public Integer call(String s) throws Exception { - return s.length(); - } - }); - JavaCheckpointTestUtils.attachTestOutputStream(letterCount); - - letterCount.checkpoint(new Duration(1000)); - - List> result1 = JavaCheckpointTestUtils.runStreams(ssc, 3, 3); - assertOrderInvariantEquals(expected, result1); - } - */ - - // Input stream tests. These mostly just test that we can instantiate a given InputStream with - // Java arguments and assign it to a JavaDStream without producing type errors. Testing of the - // InputStream functionality is deferred to the existing Scala tests. - @Test - public void testKafkaStream() { - HashMap topics = Maps.newHashMap(); - HashMap offsets = Maps.newHashMap(); - JavaDStream test1 = ssc.kafkaStream("localhost", 12345, "group", topics); - JavaDStream test2 = ssc.kafkaStream("localhost", 12345, "group", topics, offsets); - JavaDStream test3 = ssc.kafkaStream("localhost", 12345, "group", topics, offsets, - StorageLevel.MEMORY_AND_DISK()); - } - - @Test - public void testNetworkTextStream() { - JavaDStream test = ssc.networkTextStream("localhost", 12345); - } - - @Test - public void testNetworkString() { - class Converter extends Function> { - public Iterable call(InputStream in) { - BufferedReader reader = new BufferedReader(new InputStreamReader(in)); - List out = new ArrayList(); - try { - while (true) { - String line = reader.readLine(); - if (line == null) { break; } - out.add(line); - } - } catch (IOException e) { } - return out; - } - } - - JavaDStream test = ssc.networkStream( - "localhost", - 12345, - new Converter(), - StorageLevel.MEMORY_ONLY()); - } - - @Test - public void testTextFileStream() { - JavaDStream test = ssc.textFileStream("/tmp/foo"); - } - - @Test - public void testRawNetworkStream() { - JavaDStream test = ssc.rawNetworkStream("localhost", 12345); - } - - @Test - public void testFlumeStream() { - JavaDStream test = ssc.flumeStream("localhost", 12345); - } - - @Test - public void testFileStream() { - JavaPairDStream foo = - ssc.fileStream("/tmp/foo"); - } -} -- cgit v1.2.3 From 12b72b3e73798a5a2cc6c745610e135b1d6825a6 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 17 Jan 2013 22:37:56 -0800 Subject: NetworkWordCount example --- .../streaming/examples/JavaNetworkWordCount.java | 62 ++++++++++++++++++++++ .../streaming/examples/NetworkWordCount.scala | 2 +- 2 files changed, 63 insertions(+), 1 deletion(-) create mode 100644 examples/src/main/scala/spark/streaming/examples/JavaNetworkWordCount.java diff --git a/examples/src/main/scala/spark/streaming/examples/JavaNetworkWordCount.java b/examples/src/main/scala/spark/streaming/examples/JavaNetworkWordCount.java new file mode 100644 index 0000000000..4299febfd6 --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/JavaNetworkWordCount.java @@ -0,0 +1,62 @@ +package spark.streaming.examples; + +import com.google.common.collect.Lists; +import scala.Tuple2; +import spark.api.java.function.FlatMapFunction; +import spark.api.java.function.Function2; +import spark.api.java.function.PairFunction; +import spark.streaming.Duration; +import spark.streaming.api.java.JavaDStream; +import spark.streaming.api.java.JavaPairDStream; +import spark.streaming.api.java.JavaStreamingContext; + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + * Usage: NetworkWordCount + * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * and describe the TCP server that Spark Streaming would connect to receive data. + * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ ./run spark.streaming.examples.JavaNetworkWordCount local[2] localhost 9999` + */ +public class JavaNetworkWordCount { + public static void main(String[] args) { + if (args.length < 2) { + System.err.println("Usage: NetworkWordCount \n" + + "In local mode, should be 'local[n]' with n > 1"); + System.exit(1); + } + + // Create the context with a 1 second batch size + JavaStreamingContext ssc = new JavaStreamingContext( + args[0], "NetworkWordCount", new Duration(1000)); + + // Create a NetworkInputDStream on target ip:port and count the + // words in input stream of \n delimited test (eg. generated by 'nc') + JavaDStream lines = ssc.networkTextStream(args[1], Integer.parseInt(args[2])); + JavaDStream words = lines.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Lists.newArrayList(x.split(" ")); + } + }); + JavaPairDStream wordCounts = words.map( + new PairFunction() { + @Override + public Tuple2 call(String s) throws Exception { + return new Tuple2(s, 1); + } + }).reduceByKey(new Function2() { + @Override + public Integer call(Integer i1, Integer i2) throws Exception { + return i1 + i2; + } + }); + + wordCounts.print(); + ssc.start(); + + } +} diff --git a/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala b/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala index 43c01d5db2..32f7d57bea 100644 --- a/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala +++ b/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala @@ -22,7 +22,7 @@ object NetworkWordCount { System.exit(1) } - // Create the context and set the batch size + // Create the context with a 1 second batch size val ssc = new StreamingContext(args(0), "NetworkWordCount", Seconds(1)) // Create a NetworkInputDStream on target ip:port and count the -- cgit v1.2.3 From ecdff861f7993251163b82e737aba6bb1bb814d8 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sat, 19 Jan 2013 22:59:35 -0800 Subject: Clarifying log directory in EC2 guide --- docs/ec2-scripts.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md index 6e1f7fd3b1..8b069ca9ad 100644 --- a/docs/ec2-scripts.md +++ b/docs/ec2-scripts.md @@ -96,7 +96,8 @@ permissions on your private key file, you can run `launch` with the `spark-ec2` to attach a persistent EBS volume to each node for storing the persistent HDFS. - Finally, if you get errors while running your jobs, look at the slave's logs - for that job using the Mesos web UI (`http://:8080`). + for that job inside of the Mesos work directory (/mnt/mesos-work). Mesos errors + can be found using the Mesos web UI (`http://:8080`). # Configuration -- cgit v1.2.3 From 214345ceace634ec9cc83c4c85b233b699e0d219 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 19 Jan 2013 23:50:17 -0800 Subject: Fixed issue https://spark-project.atlassian.net/browse/STREAMING-29, along with updates to doc comments in SparkContext.checkpoint(). --- core/src/main/scala/spark/RDD.scala | 17 ++++++++--------- core/src/main/scala/spark/RDDCheckpointData.scala | 2 +- core/src/main/scala/spark/SparkContext.scala | 13 +++++++------ streaming/src/main/scala/spark/streaming/DStream.scala | 8 +++++++- 4 files changed, 23 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index a9f2e86455..e0d2eabb1d 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -549,17 +549,16 @@ abstract class RDD[T: ClassManifest]( } /** - * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir` - * (set using setCheckpointDir()) and all references to its parent RDDs will be removed. - * This is used to truncate very long lineages. In the current implementation, Spark will save - * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. - * Hence, it is strongly recommended to use checkpoint() on RDDs when - * (i) checkpoint() is called before the any job has been executed on this RDD. - * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will - * require recomputation. + * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint + * directory set with SparkContext.setCheckpointDir() and all references to its parent + * RDDs will be removed. This function must be called before any job has been + * executed on this RDD. It is strongly recommended that this RDD is persisted in + * memory, otherwise saving it on a file will require recomputation. */ def checkpoint() { - if (checkpointData.isEmpty) { + if (context.checkpointDir.isEmpty) { + throw new Exception("Checkpoint directory has not been set in the SparkContext") + } else if (checkpointData.isEmpty) { checkpointData = Some(new RDDCheckpointData(this)) checkpointData.get.markForCheckpoint() } diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index d845a522e4..18df530b7d 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -63,7 +63,7 @@ extends Logging with Serializable { } // Save to file, and reload it as an RDD - val path = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString + val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id).toString rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _) val newRDD = new CheckpointRDD[T](rdd.context, path) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 88cf357ebf..7f3259d982 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -184,7 +184,7 @@ class SparkContext( private var dagScheduler = new DAGScheduler(taskScheduler) - private[spark] var checkpointDir: String = null + private[spark] var checkpointDir: Option[String] = None // Methods for creating RDDs @@ -595,10 +595,11 @@ class SparkContext( } /** - * Set the directory under which RDDs are going to be checkpointed. This method will - * create this directory and will throw an exception of the path already exists (to avoid - * overwriting existing files may be overwritten). The directory will be deleted on exit - * if indicated. + * Set the directory under which RDDs are going to be checkpointed. The directory must + * be a HDFS path if running on a cluster. If the directory does not exist, it will + * be created. If the directory exists and useExisting is set to true, then the + * exisiting directory will be used. Otherwise an exception will be thrown to + * prevent accidental overriding of checkpoint files in the existing directory. */ def setCheckpointDir(dir: String, useExisting: Boolean = false) { val path = new Path(dir) @@ -610,7 +611,7 @@ class SparkContext( fs.mkdirs(path) } } - checkpointDir = dir + checkpointDir = Some(dir) } /** Default level of parallelism to use when not given by user (e.g. for reduce tasks) */ diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index fbe3cebd6d..c4442b6a0c 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -154,10 +154,16 @@ abstract class DStream[T: ClassManifest] ( assert( !mustCheckpoint || checkpointDuration != null, - "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set. " + + "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set." + " Please use DStream.checkpoint() to set the interval." ) + assert( + checkpointDuration == null || ssc.sc.checkpointDir.isDefined, + "The checkpoint directory has not been set. Please use StreamingContext.checkpoint()" + + " or SparkContext.checkpoint() to set the checkpoint directory." + ) + assert( checkpointDuration == null || checkpointDuration >= slideDuration, "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + -- cgit v1.2.3 From 8e7f098a2c9e5e85cb9435f28d53a3a5847c14aa Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 20 Jan 2013 01:57:44 -0800 Subject: Added accumulators to PySpark --- .../main/scala/spark/api/python/PythonRDD.scala | 83 +++++++++-- python/pyspark/__init__.py | 4 + python/pyspark/accumulators.py | 166 +++++++++++++++++++++ python/pyspark/context.py | 38 +++++ python/pyspark/rdd.py | 2 +- python/pyspark/serializers.py | 7 +- python/pyspark/shell.py | 4 +- python/pyspark/worker.py | 7 +- 8 files changed, 290 insertions(+), 21 deletions(-) create mode 100644 python/pyspark/accumulators.py diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index f431ef28d3..fb13e84658 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -1,7 +1,8 @@ package spark.api.python import java.io._ -import java.util.{List => JList} +import java.net._ +import java.util.{List => JList, ArrayList => JArrayList, Collections} import scala.collection.JavaConversions._ import scala.io.Source @@ -10,25 +11,26 @@ import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import spark.broadcast.Broadcast import spark._ import spark.rdd.PipedRDD -import java.util private[spark] class PythonRDD[T: ClassManifest]( - parent: RDD[T], - command: Seq[String], - envVars: java.util.Map[String, String], - preservePartitoning: Boolean, - pythonExec: String, - broadcastVars: java.util.List[Broadcast[Array[Byte]]]) + parent: RDD[T], + command: Seq[String], + envVars: java.util.Map[String, String], + preservePartitoning: Boolean, + pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], + accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent.context) { // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String], - preservePartitoning: Boolean, pythonExec: String, - broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = + preservePartitoning: Boolean, pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], + accumulator: Accumulator[JList[Array[Byte]]]) = this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec, - broadcastVars) + broadcastVars, accumulator) override def splits = parent.splits @@ -93,18 +95,30 @@ private[spark] class PythonRDD[T: ClassManifest]( // Return an iterator that read lines from the process's stdout val stream = new DataInputStream(proc.getInputStream) return new Iterator[Array[Byte]] { - def next() = { + def next(): Array[Byte] = { val obj = _nextObj _nextObj = read() obj } - private def read() = { + private def read(): Array[Byte] = { try { val length = stream.readInt() - val obj = new Array[Byte](length) - stream.readFully(obj) - obj + if (length != -1) { + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + } else { + // We've finished the data section of the output, but we can still read some + // accumulator updates; let's do that, breaking when we get EOFException + while (true) { + val len2 = stream.readInt() + val update = new Array[Byte](len2) + stream.readFully(update) + accumulator += Collections.singletonList(update) + } + new Array[Byte](0) + } } catch { case eof: EOFException => { val exitStatus = proc.waitFor() @@ -246,3 +260,40 @@ private class ExtractValue extends spark.api.java.function.Function[(Array[Byte] private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] { override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") } + +/** + * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it + * collects a list of pickled strings that we pass to Python through a socket. + */ +class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) + extends AccumulatorParam[JList[Array[Byte]]] { + + override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList + + override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]]) + : JList[Array[Byte]] = { + if (serverHost == null) { + // This happens on the worker node, where we just want to remember all the updates + val1.addAll(val2) + val1 + } else { + // This happens on the master, where we pass the updates to Python through a socket + val socket = new Socket(serverHost, serverPort) + val in = socket.getInputStream + val out = new DataOutputStream(socket.getOutputStream) + out.writeInt(val2.size) + for (array <- val2) { + out.writeInt(array.length) + out.write(array) + } + out.flush() + // Wait for a byte from the Python side as an acknowledgement + val byteRead = in.read() + if (byteRead == -1) { + throw new SparkException("EOF reached before Python server acknowledged") + } + socket.close() + null + } + } +} diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index c595ae0842..00666bc0a3 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -7,6 +7,10 @@ Public classes: Main entry point for Spark functionality. - L{RDD} A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. + - L{Broadcast} + A broadcast variable that gets reused across tasks. + - L{Accumulator} + An "add-only" shared variable that tasks can only add values to. """ import sys import os diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py new file mode 100644 index 0000000000..438af4cfc0 --- /dev/null +++ b/python/pyspark/accumulators.py @@ -0,0 +1,166 @@ +""" +>>> from pyspark.context import SparkContext +>>> sc = SparkContext('local', 'test') +>>> a = sc.accumulator(1) +>>> a.value +1 +>>> a.value = 2 +>>> a.value +2 +>>> a += 5 +>>> a.value +7 + +>>> rdd = sc.parallelize([1,2,3]) +>>> def f(x): +... global a +... a += x +>>> rdd.foreach(f) +>>> a.value +13 + +>>> class VectorAccumulatorParam(object): +... def zero(self, value): +... return [0.0] * len(value) +... def addInPlace(self, val1, val2): +... for i in xrange(len(val1)): +... val1[i] += val2[i] +... return val1 +>>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam()) +>>> va.value +[1.0, 2.0, 3.0] +>>> def g(x): +... global va +... va += [x] * 3 +>>> rdd.foreach(g) +>>> va.value +[7.0, 8.0, 9.0] + +>>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): + ... +Py4JJavaError:... + +>>> def h(x): +... global a +... a.value = 7 +>>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): + ... +Py4JJavaError:... + +>>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): + ... +Exception:... +""" + +import struct +import SocketServer +import threading +from pyspark.cloudpickle import CloudPickler +from pyspark.serializers import read_int, read_with_length, load_pickle + + +# Holds accumulators registered on the current machine, keyed by ID. This is then used to send +# the local accumulator updates back to the driver program at the end of a task. +_accumulatorRegistry = {} + + +def _deserialize_accumulator(aid, zero_value, accum_param): + from pyspark.accumulators import _accumulatorRegistry + accum = Accumulator(aid, zero_value, accum_param) + accum._deserialized = True + _accumulatorRegistry[aid] = accum + return accum + + +class Accumulator(object): + def __init__(self, aid, value, accum_param): + """Create a new Accumulator with a given initial value and AccumulatorParam object""" + from pyspark.accumulators import _accumulatorRegistry + self.aid = aid + self.accum_param = accum_param + self._value = value + self._deserialized = False + _accumulatorRegistry[aid] = self + + def __reduce__(self): + """Custom serialization; saves the zero value from our AccumulatorParam""" + param = self.accum_param + return (_deserialize_accumulator, (self.aid, param.zero(self._value), param)) + + @property + def value(self): + """Get the accumulator's value; only usable in driver program""" + if self._deserialized: + raise Exception("Accumulator.value cannot be accessed inside tasks") + return self._value + + @value.setter + def value(self, value): + """Sets the accumulator's value; only usable in driver program""" + if self._deserialized: + raise Exception("Accumulator.value cannot be accessed inside tasks") + self._value = value + + def __iadd__(self, term): + """The += operator; adds a term to this accumulator's value""" + self._value = self.accum_param.addInPlace(self._value, term) + return self + + def __str__(self): + return str(self._value) + + +class AddingAccumulatorParam(object): + """ + An AccumulatorParam that uses the + operators to add values. Designed for simple types + such as integers, floats, and lists. Requires the zero value for the underlying type + as a parameter. + """ + + def __init__(self, zero_value): + self.zero_value = zero_value + + def zero(self, value): + return self.zero_value + + def addInPlace(self, value1, value2): + value1 += value2 + return value1 + + +# Singleton accumulator params for some standard types +INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0) +DOUBLE_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0) +COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) + + +class _UpdateRequestHandler(SocketServer.StreamRequestHandler): + def handle(self): + from pyspark.accumulators import _accumulatorRegistry + num_updates = read_int(self.rfile) + for _ in range(num_updates): + (aid, update) = load_pickle(read_with_length(self.rfile)) + _accumulatorRegistry[aid] += update + # Write a byte in acknowledgement + self.wfile.write(struct.pack("!b", 1)) + + +def _start_update_server(): + """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" + server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler) + thread = threading.Thread(target=server.serve_forever) + thread.daemon = True + thread.start() + return server + + +def _test(): + import doctest + doctest.testmod() + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/context.py b/python/pyspark/context.py index e486f206b0..1e2f845f9c 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -2,6 +2,8 @@ import os import atexit from tempfile import NamedTemporaryFile +from pyspark import accumulators +from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast from pyspark.java_gateway import launch_gateway from pyspark.serializers import dump_pickle, write_with_length, batched @@ -22,6 +24,7 @@ class SparkContext(object): _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile _takePartition = jvm.PythonRDD.takePartition + _next_accum_id = 0 def __init__(self, master, jobName, sparkHome=None, pyFiles=None, environment=None, batchSize=1024): @@ -52,6 +55,14 @@ class SparkContext(object): self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome, empty_string_array) + # Create a single Accumulator in Java that we'll send all our updates through; + # they will be passed back to us through a TCP server + self._accumulatorServer = accumulators._start_update_server() + (host, port) = self._accumulatorServer.server_address + self._javaAccumulator = self._jsc.accumulator( + self.jvm.java.util.ArrayList(), + self.jvm.PythonAccumulatorParam(host, port)) + self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') # Broadcast's __reduce__ method stores Broadcast instances here. # This allows other code to determine which Broadcast instances have @@ -74,6 +85,8 @@ class SparkContext(object): def __del__(self): if self._jsc: self._jsc.stop() + if self._accumulatorServer: + self._accumulatorServer.shutdown() def stop(self): """ @@ -129,6 +142,31 @@ class SparkContext(object): return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) + def accumulator(self, value, accum_param=None): + """ + Create an C{Accumulator} with the given initial value, using a given + AccumulatorParam helper object to define how to add values of the data + type if provided. Default AccumulatorParams are used for integers and + floating-point numbers if you do not provide one. For other types, the + AccumulatorParam must implement two methods: + - C{zero(value)}: provide a "zero value" for the type, compatible in + dimensions with the provided C{value} (e.g., a zero vector). + - C{addInPlace(val1, val2)}: add two values of the accumulator's data + type, returning a new value; for efficiency, can also update C{val1} + in place and return it. + """ + if accum_param == None: + if isinstance(value, int): + accum_param = accumulators.INT_ACCUMULATOR_PARAM + elif isinstance(value, float): + accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM + elif isinstance(value, complex): + accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM + else: + raise Exception("No default accumulator param for type %s" % type(value)) + SparkContext._next_accum_id += 1 + return Accumulator(SparkContext._next_accum_id - 1, value, accum_param) + def addFile(self, path): """ Add a file to be downloaded into the working directory of this Spark diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 1d36da42b0..d705f0f9e1 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -703,7 +703,7 @@ class PipelinedRDD(RDD): env = MapConverter().convert(env, self.ctx.gateway._gateway_client) python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, class_manifest) + broadcast_vars, self.ctx._javaAccumulator, class_manifest) self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 9a5151ea00..115cf28cc2 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -52,8 +52,13 @@ def read_int(stream): raise EOFError return struct.unpack("!i", length)[0] + +def write_int(value, stream): + stream.write(struct.pack("!i", value)) + + def write_with_length(obj, stream): - stream.write(struct.pack("!i", len(obj))) + write_int(len(obj), stream) stream.write(obj) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 7e6ad3aa76..f6328c561f 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -1,7 +1,7 @@ """ An interactive shell. -This fle is designed to be launched as a PYTHONSTARTUP script. +This file is designed to be launched as a PYTHONSTARTUP script. """ import os from pyspark.context import SparkContext @@ -14,4 +14,4 @@ print "Spark context avaiable as sc." # which allows us to execute the user's PYTHONSTARTUP file: _pythonstartup = os.environ.get('OLD_PYTHONSTARTUP') if _pythonstartup and os.path.isfile(_pythonstartup): - execfile(_pythonstartup) + execfile(_pythonstartup) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 3d792bbaa2..b2b9288089 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -5,9 +5,10 @@ import sys from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. +from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler -from pyspark.serializers import write_with_length, read_with_length, \ +from pyspark.serializers import write_with_length, read_with_length, write_int, \ read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file @@ -36,6 +37,10 @@ def main(): iterator = read_from_pickle_file(sys.stdin) for obj in func(split_index, iterator): write_with_length(dumps(obj), old_stdout) + # Mark the beginning of the accumulators section of the output + write_int(-1, old_stdout) + for aid, accum in _accumulatorRegistry.items(): + write_with_length(dump_pickle((aid, accum._value)), old_stdout) if __name__ == '__main__': -- cgit v1.2.3 From 61b6382a352f3e801643529198b867e13debf470 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 20 Jan 2013 01:59:07 -0800 Subject: Launch accumulator tests in run-tests --- python/run-tests | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/run-tests b/python/run-tests index fcdd1e27a7..32470911f9 100755 --- a/python/run-tests +++ b/python/run-tests @@ -11,6 +11,9 @@ FAILED=$(($?||$FAILED)) $FWDIR/pyspark -m doctest pyspark/broadcast.py FAILED=$(($?||$FAILED)) +$FWDIR/pyspark -m doctest pyspark/accumulators.py +FAILED=$(($?||$FAILED)) + if [[ $FAILED != 0 ]]; then echo -en "\033[31m" # Red echo "Had test failures; see logs." -- cgit v1.2.3 From a23ed25f3cd6e76784f831d0ab7de7d3e193b59f Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 20 Jan 2013 02:10:25 -0800 Subject: Add a class comment to Accumulator --- python/pyspark/accumulators.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 438af4cfc0..c00c3a37af 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -76,6 +76,18 @@ def _deserialize_accumulator(aid, zero_value, accum_param): class Accumulator(object): + """ + A shared variable that can be accumulated, i.e., has a commutative and associative "add" + operation. Worker tasks on a Spark cluster can add values to an Accumulator with the C{+=} + operator, but only the driver program is allowed to access its value, using C{value}. + Updates from the workers get propagated automatically to the driver program. + + While C{SparkContext} supports accumulators for primitive data types like C{int} and + C{float}, users can also define accumulators for custom types by providing a custom + C{AccumulatorParam} object with a C{zero} and C{addInPlace} method. Refer to the doctest + of this module for an example. + """ + def __init__(self, aid, value, accum_param): """Create a new Accumulator with a given initial value and AccumulatorParam object""" from pyspark.accumulators import _accumulatorRegistry -- cgit v1.2.3 From ee5a07955c222dce16d0ffb9bde7f61033763c16 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 20 Jan 2013 02:11:58 -0800 Subject: Fix Python guide to say accumulators are available --- docs/python-programming-guide.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index 78ef310a00..a840b9b34b 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -16,7 +16,6 @@ There are a few key differences between the Python and Scala APIs: * Python is dynamically typed, so RDDs can hold objects of different types. * PySpark does not currently support the following Spark features: - - Accumulators - Special functions on RDDs of doubles, such as `mean` and `stdev` - `lookup` - `persist` at storage levels other than `MEMORY_ONLY` -- cgit v1.2.3 From 33bad85bb9143d41bc5de2068f7e8a8c39928225 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 20 Jan 2013 03:51:11 -0800 Subject: Fixed streaming testsuite bugs --- streaming/src/test/java/JavaAPISuite.java | 2 ++ streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala | 5 +++++ streaming/src/test/scala/spark/streaming/CheckpointSuite.scala | 6 +++--- streaming/src/test/scala/spark/streaming/FailureSuite.scala | 3 +++ streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala | 3 +++ streaming/src/test/scala/spark/streaming/TestSuiteBase.scala | 6 +++--- .../src/test/scala/spark/streaming/WindowOperationsSuite.scala | 5 +++++ 7 files changed, 24 insertions(+), 6 deletions(-) diff --git a/streaming/src/test/java/JavaAPISuite.java b/streaming/src/test/java/JavaAPISuite.java index 8c94e13e65..c84e7331c7 100644 --- a/streaming/src/test/java/JavaAPISuite.java +++ b/streaming/src/test/java/JavaAPISuite.java @@ -34,12 +34,14 @@ public class JavaAPISuite implements Serializable { @Before public void setUp() { ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + ssc.checkpoint("checkpoint", new Duration(1000)); } @After public void tearDown() { ssc.stop(); ssc = null; + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.master.port"); } diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala index f73f9b1823..bfdf32c73e 100644 --- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala @@ -8,6 +8,11 @@ class BasicOperationsSuite extends TestSuiteBase { override def framework() = "BasicOperationsSuite" + after { + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") + } + test("map") { val input = Seq(1 to 4, 5 to 8, 9 to 12) testOperation( diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index 920388bba9..d2f32c189b 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -15,9 +15,11 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { } after { - if (ssc != null) ssc.stop() FileUtils.deleteDirectory(new File(checkpointDir)) + + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") } var ssc: StreamingContext = null @@ -26,8 +28,6 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { override def batchDuration = Milliseconds(500) - override def checkpointDir = "checkpoint" - override def checkpointInterval = batchDuration override def actuallyWait = true diff --git a/streaming/src/test/scala/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/spark/streaming/FailureSuite.scala index 4aa428bf64..7493ac1207 100644 --- a/streaming/src/test/scala/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/spark/streaming/FailureSuite.scala @@ -22,6 +22,9 @@ class FailureSuite extends TestSuiteBase with BeforeAndAfter { after { FailureSuite.reset() FileUtils.deleteDirectory(new File(checkpointDir)) + + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") } override def framework = "CheckpointSuite" diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index e71ba6ddc1..d7ba7a5d17 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -40,6 +40,9 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { FileUtils.deleteDirectory(testDir) testDir = null } + + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") } test("network input stream") { diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index a76f61d4ad..49129f3964 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -10,7 +10,7 @@ import collection.mutable.SynchronizedBuffer import java.io.{ObjectInputStream, IOException} -import org.scalatest.FunSuite +import org.scalatest.{BeforeAndAfter, FunSuite} /** * This is a input stream just for the testsuites. This is equivalent to a checkpointable, @@ -56,7 +56,7 @@ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBu * This is the base trait for Spark Streaming testsuites. This provides basic functionality * to run user-defined set of input on user-defined stream operations, and verify the output. */ -trait TestSuiteBase extends FunSuite with Logging { +trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { def framework = "TestSuiteBase" @@ -64,7 +64,7 @@ trait TestSuiteBase extends FunSuite with Logging { def batchDuration = Seconds(1) - def checkpointDir = null.asInstanceOf[String] + def checkpointDir = "checkpoint" def checkpointInterval = batchDuration diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index f9ba1f20f0..0c6e928835 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -11,6 +11,11 @@ class WindowOperationsSuite extends TestSuiteBase { override def batchDuration = Seconds(1) + after { + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") + } + val largerSlideInput = Seq( Seq(("a", 1)), Seq(("a", 2)), // 1st window from here -- cgit v1.2.3 From 5f74ead63643df83b04646c08e9bfc6b4b4a9ca9 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sun, 20 Jan 2013 08:59:20 -0800 Subject: Changes based on Matei's comment --- docs/ec2-scripts.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md index 8b069ca9ad..931b7a66bd 100644 --- a/docs/ec2-scripts.md +++ b/docs/ec2-scripts.md @@ -96,8 +96,9 @@ permissions on your private key file, you can run `launch` with the `spark-ec2` to attach a persistent EBS volume to each node for storing the persistent HDFS. - Finally, if you get errors while running your jobs, look at the slave's logs - for that job inside of the Mesos work directory (/mnt/mesos-work). Mesos errors - can be found using the Mesos web UI (`http://:8080`). + for that job inside of the Mesos work directory (/mnt/mesos-work). You can + also view the status of the cluster using the Mesos web UI + (`http://:8080`). # Configuration -- cgit v1.2.3 From 2a8c2a67909c4878ea24ec94f203287e55dd3782 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 20 Jan 2013 10:24:53 -0800 Subject: Minor formatting fixes --- examples/src/main/scala/spark/examples/SparkALS.scala | 4 ++-- python/examples/als.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/src/main/scala/spark/examples/SparkALS.scala b/examples/src/main/scala/spark/examples/SparkALS.scala index 2766ad1702..5e01885dbb 100644 --- a/examples/src/main/scala/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/spark/examples/SparkALS.scala @@ -75,8 +75,8 @@ object SparkALS { (0 to 5).map(i => { i match { - case a if a < args.length => Option(args(a)) - case _ => Option(null) + case a if a < args.length => Some(args(a)) + case _ => None } }).toArray match { case Array(host_, m, u, f, iters, slices_) => { diff --git a/python/examples/als.py b/python/examples/als.py index 284cf0d3a2..010f80097f 100755 --- a/python/examples/als.py +++ b/python/examples/als.py @@ -68,4 +68,4 @@ if __name__ == "__main__": error = rmse(R, ms, us) print "Iteration %d:" % i - print "\nRMSE: %5.4f\n" % error \ No newline at end of file + print "\nRMSE: %5.4f\n" % error -- cgit v1.2.3 From 17035db159e191a11cd86882c97078581073deb2 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 20 Jan 2013 11:22:38 -0800 Subject: Add __repr__ to Accumulator; fix bug in sc.accumulator --- python/pyspark/accumulators.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index c00c3a37af..8011779ddc 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -11,6 +11,12 @@ >>> a.value 7 +>>> sc.accumulator(1.0).value +1.0 + +>>> sc.accumulator(1j).value +1j + >>> rdd = sc.parallelize([1,2,3]) >>> def f(x): ... global a @@ -124,6 +130,9 @@ class Accumulator(object): def __str__(self): return str(self._value) + def __repr__(self): + return "Accumulator" % (self.aid, self._value) + class AddingAccumulatorParam(object): """ @@ -145,7 +154,7 @@ class AddingAccumulatorParam(object): # Singleton accumulator params for some standard types INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0) -DOUBLE_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0) +FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0) COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) -- cgit v1.2.3 From 7ed1bf4b485131d58ea6728e7247b79320aca9e6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 16 Jan 2013 19:15:14 -0800 Subject: Add RDD checkpointing to Python API. --- .../main/scala/spark/api/python/PythonRDD.scala | 3 -- python/epydoc.conf | 2 +- python/pyspark/context.py | 9 +++++ python/pyspark/rdd.py | 34 ++++++++++++++++ python/pyspark/tests.py | 46 ++++++++++++++++++++++ python/run-tests | 3 ++ 6 files changed, 93 insertions(+), 4 deletions(-) create mode 100644 python/pyspark/tests.py diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 89f7c316dc..8c38262dd8 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -135,8 +135,6 @@ private[spark] class PythonRDD[T: ClassManifest]( } } - override def checkpoint() { } - val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } @@ -152,7 +150,6 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends case Seq(a, b) => (a, b) case x => throw new Exception("PairwiseRDD: unexpected value: " + x) } - override def checkpoint() { } val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) } diff --git a/python/epydoc.conf b/python/epydoc.conf index 91ac984ba2..45102cd9fe 100644 --- a/python/epydoc.conf +++ b/python/epydoc.conf @@ -16,4 +16,4 @@ target: docs/ private: no exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers - pyspark.java_gateway pyspark.examples pyspark.shell + pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 1e2f845f9c..a438b43fdc 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -195,3 +195,12 @@ class SparkContext(object): filename = path.split("/")[-1] os.environ["PYTHONPATH"] = \ "%s:%s" % (filename, os.environ["PYTHONPATH"]) + + def setCheckpointDir(self, dirName, useExisting=False): + """ + Set the directory under which RDDs are going to be checkpointed. This + method will create this directory and will throw an exception of the + path already exists (to avoid overwriting existing files may be + overwritten). The directory will be deleted on exit if indicated. + """ + self._jsc.sc().setCheckpointDir(dirName, useExisting) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d705f0f9e1..9b676cae4a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -49,6 +49,40 @@ class RDD(object): self._jrdd.cache() return self + def checkpoint(self): + """ + Mark this RDD for checkpointing. The RDD will be saved to a file inside + `checkpointDir` (set using setCheckpointDir()) and all references to + its parent RDDs will be removed. This is used to truncate very long + lineages. In the current implementation, Spark will save this RDD to + a file (using saveAsObjectFile()) after the first job using this RDD is + done. Hence, it is strongly recommended to use checkpoint() on RDDs + when + + (i) checkpoint() is called before the any job has been executed on this + RDD. + + (ii) This RDD has been made to persist in memory. Otherwise saving it + on a file will require recomputation. + """ + self._jrdd.rdd().checkpoint() + + def isCheckpointed(self): + """ + Return whether this RDD has been checkpointed or not + """ + return self._jrdd.rdd().isCheckpointed() + + def getCheckpointFile(self): + """ + Gets the name of the file to which this RDD was checkpointed + """ + checkpointFile = self._jrdd.rdd().getCheckpointFile() + if checkpointFile.isDefined(): + return checkpointFile.get() + else: + return None + # TODO persist(self, storageLevel) def map(self, f, preservesPartitioning=False): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py new file mode 100644 index 0000000000..c959d5dec7 --- /dev/null +++ b/python/pyspark/tests.py @@ -0,0 +1,46 @@ +""" +Unit tests for PySpark; additional tests are implemented as doctests in +individual modules. +""" +import atexit +import os +import shutil +from tempfile import NamedTemporaryFile +import time +import unittest + +from pyspark.context import SparkContext + + +class TestCheckpoint(unittest.TestCase): + + def setUp(self): + self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2) + + def tearDown(self): + self.sc.stop() + + def test_basic_checkpointing(self): + checkpointDir = NamedTemporaryFile(delete=False) + os.unlink(checkpointDir.name) + self.sc.setCheckpointDir(checkpointDir.name) + + parCollection = self.sc.parallelize([1, 2, 3, 4]) + flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) + + self.assertFalse(flatMappedRDD.isCheckpointed()) + self.assertIsNone(flatMappedRDD.getCheckpointFile()) + + flatMappedRDD.checkpoint() + result = flatMappedRDD.collect() + time.sleep(1) # 1 second + self.assertTrue(flatMappedRDD.isCheckpointed()) + self.assertEqual(flatMappedRDD.collect(), result) + self.assertEqual(checkpointDir.name, + os.path.dirname(flatMappedRDD.getCheckpointFile())) + + atexit.register(lambda: shutil.rmtree(checkpointDir.name)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/run-tests b/python/run-tests index 32470911f9..ce214e98a8 100755 --- a/python/run-tests +++ b/python/run-tests @@ -14,6 +14,9 @@ FAILED=$(($?||$FAILED)) $FWDIR/pyspark -m doctest pyspark/accumulators.py FAILED=$(($?||$FAILED)) +$FWDIR/pyspark -m unittest pyspark.tests +FAILED=$(($?||$FAILED)) + if [[ $FAILED != 0 ]]; then echo -en "\033[31m" # Red echo "Had test failures; see logs." -- cgit v1.2.3 From d0ba80dc727d00b2b7627dcefd2c77009af55f7d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 20 Jan 2013 13:59:45 -0800 Subject: Add checkpointFile() and more tests to PySpark. --- python/pyspark/context.py | 6 +++++- python/pyspark/rdd.py | 9 ++++++++- python/pyspark/tests.py | 24 ++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 2 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a438b43fdc..8beb8e2ae9 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -123,6 +123,10 @@ class SparkContext(object): jrdd = self._jsc.textFile(name, minSplits) return RDD(jrdd, self) + def _checkpointFile(self, name): + jrdd = self._jsc.checkpointFile(name) + return RDD(jrdd, self) + def union(self, rdds): """ Build the union of a list of RDDs. @@ -145,7 +149,7 @@ class SparkContext(object): def accumulator(self, value, accum_param=None): """ Create an C{Accumulator} with the given initial value, using a given - AccumulatorParam helper object to define how to add values of the data + AccumulatorParam helper object to define how to add values of the data type if provided. Default AccumulatorParams are used for integers and floating-point numbers if you do not provide one. For other types, the AccumulatorParam must implement two methods: diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 9b676cae4a..2a2ff9b271 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -32,6 +32,7 @@ class RDD(object): def __init__(self, jrdd, ctx): self._jrdd = jrdd self.is_cached = False + self.is_checkpointed = False self.ctx = ctx @property @@ -65,6 +66,7 @@ class RDD(object): (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will require recomputation. """ + self.is_checkpointed = True self._jrdd.rdd().checkpoint() def isCheckpointed(self): @@ -696,7 +698,7 @@ class PipelinedRDD(RDD): 20 """ def __init__(self, prev, func, preservesPartitioning=False): - if isinstance(prev, PipelinedRDD) and not prev.is_cached: + if isinstance(prev, PipelinedRDD) and prev._is_pipelinable: prev_func = prev.func def pipeline_func(split, iterator): return func(split, prev_func(split, iterator)) @@ -709,6 +711,7 @@ class PipelinedRDD(RDD): self.preservesPartitioning = preservesPartitioning self._prev_jrdd = prev._jrdd self.is_cached = False + self.is_checkpointed = False self.ctx = prev.ctx self.prev = prev self._jrdd_val = None @@ -741,6 +744,10 @@ class PipelinedRDD(RDD): self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val + @property + def _is_pipelinable(self): + return not (self.is_cached or self.is_checkpointed) + def _test(): import doctest diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index c959d5dec7..83283fca4f 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -19,6 +19,9 @@ class TestCheckpoint(unittest.TestCase): def tearDown(self): self.sc.stop() + # To avoid Akka rebinding to the same port, since it doesn't unbind + # immediately on shutdown + self.sc.jvm.System.clearProperty("spark.master.port") def test_basic_checkpointing(self): checkpointDir = NamedTemporaryFile(delete=False) @@ -41,6 +44,27 @@ class TestCheckpoint(unittest.TestCase): atexit.register(lambda: shutil.rmtree(checkpointDir.name)) + def test_checkpoint_and_restore(self): + checkpointDir = NamedTemporaryFile(delete=False) + os.unlink(checkpointDir.name) + self.sc.setCheckpointDir(checkpointDir.name) + + parCollection = self.sc.parallelize([1, 2, 3, 4]) + flatMappedRDD = parCollection.flatMap(lambda x: [x]) + + self.assertFalse(flatMappedRDD.isCheckpointed()) + self.assertIsNone(flatMappedRDD.getCheckpointFile()) + + flatMappedRDD.checkpoint() + flatMappedRDD.count() # forces a checkpoint to be computed + time.sleep(1) # 1 second + + self.assertIsNotNone(flatMappedRDD.getCheckpointFile()) + recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile()) + self.assertEquals([1, 2, 3, 4], recovered.collect()) + + atexit.register(lambda: shutil.rmtree(checkpointDir.name)) + if __name__ == "__main__": unittest.main() -- cgit v1.2.3 From 5b6ea9e9a04994553d0319c541ca356e2e3064a7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 20 Jan 2013 15:31:41 -0800 Subject: Update checkpointing API docs in Python/Java. --- core/src/main/scala/spark/api/java/JavaRDDLike.scala | 17 +++++++---------- .../main/scala/spark/api/java/JavaSparkContext.scala | 17 +++++++++-------- python/pyspark/context.py | 11 +++++++---- python/pyspark/rdd.py | 17 +++++------------ 4 files changed, 28 insertions(+), 34 deletions(-) diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 087270e46d..b3698ffa44 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -307,16 +307,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] JavaPairRDD.fromRDD(rdd.keyBy(f)) } - - /** - * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir` - * (set using setCheckpointDir()) and all references to its parent RDDs will be removed. - * This is used to truncate very long lineages. In the current implementation, Spark will save - * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. - * Hence, it is strongly recommended to use checkpoint() on RDDs when - * (i) checkpoint() is called before the any job has been executed on this RDD. - * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will - * require recomputation. + + /** + * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint + * directory set with SparkContext.setCheckpointDir() and all references to its parent + * RDDs will be removed. This function must be called before any job has been + * executed on this RDD. It is strongly recommended that this RDD is persisted in + * memory, otherwise saving it on a file will require recomputation. */ def checkpoint() = rdd.checkpoint() diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index fa2f14113d..14699961ad 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -357,20 +357,21 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork } /** - * Set the directory under which RDDs are going to be checkpointed. This method will - * create this directory and will throw an exception of the path already exists (to avoid - * overwriting existing files may be overwritten). The directory will be deleted on exit - * if indicated. + * Set the directory under which RDDs are going to be checkpointed. The directory must + * be a HDFS path if running on a cluster. If the directory does not exist, it will + * be created. If the directory exists and useExisting is set to true, then the + * exisiting directory will be used. Otherwise an exception will be thrown to + * prevent accidental overriding of checkpoint files in the existing directory. */ def setCheckpointDir(dir: String, useExisting: Boolean) { sc.setCheckpointDir(dir, useExisting) } /** - * Set the directory under which RDDs are going to be checkpointed. This method will - * create this directory and will throw an exception of the path already exists (to avoid - * overwriting existing files may be overwritten). The directory will be deleted on exit - * if indicated. + * Set the directory under which RDDs are going to be checkpointed. The directory must + * be a HDFS path if running on a cluster. If the directory does not exist, it will + * be created. If the directory exists, an exception will be thrown to prevent accidental + * overriding of checkpoint files. */ def setCheckpointDir(dir: String) { sc.setCheckpointDir(dir) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 8beb8e2ae9..dcbed37270 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -202,9 +202,12 @@ class SparkContext(object): def setCheckpointDir(self, dirName, useExisting=False): """ - Set the directory under which RDDs are going to be checkpointed. This - method will create this directory and will throw an exception of the - path already exists (to avoid overwriting existing files may be - overwritten). The directory will be deleted on exit if indicated. + Set the directory under which RDDs are going to be checkpointed. The + directory must be a HDFS path if running on a cluster. + + If the directory does not exist, it will be created. If the directory + exists and C{useExisting} is set to true, then the exisiting directory + will be used. Otherwise an exception will be thrown to prevent + accidental overriding of checkpoint files in the existing directory. """ self._jsc.sc().setCheckpointDir(dirName, useExisting) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 2a2ff9b271..7b6ab956ee 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -52,18 +52,11 @@ class RDD(object): def checkpoint(self): """ - Mark this RDD for checkpointing. The RDD will be saved to a file inside - `checkpointDir` (set using setCheckpointDir()) and all references to - its parent RDDs will be removed. This is used to truncate very long - lineages. In the current implementation, Spark will save this RDD to - a file (using saveAsObjectFile()) after the first job using this RDD is - done. Hence, it is strongly recommended to use checkpoint() on RDDs - when - - (i) checkpoint() is called before the any job has been executed on this - RDD. - - (ii) This RDD has been made to persist in memory. Otherwise saving it + Mark this RDD for checkpointing. It will be saved to a file inside the + checkpoint directory set with L{SparkContext.setCheckpointDir()} and + all references to its parent RDDs will be removed. This function must + be called before any job has been executed on this RDD. It is strongly + recommended that this RDD is persisted in memory, otherwise saving it on a file will require recomputation. """ self.is_checkpointed = True -- cgit v1.2.3 From 00d70cd6602d5ff2718e319ec04defbdd486237e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 20 Jan 2013 15:38:11 -0800 Subject: Clean up setup code in PySpark checkpointing tests --- python/pyspark/rdd.py | 3 +-- python/pyspark/tests.py | 19 +++++-------------- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7b6ab956ee..097cdb13b4 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -691,7 +691,7 @@ class PipelinedRDD(RDD): 20 """ def __init__(self, prev, func, preservesPartitioning=False): - if isinstance(prev, PipelinedRDD) and prev._is_pipelinable: + if isinstance(prev, PipelinedRDD) and prev._is_pipelinable(): prev_func = prev.func def pipeline_func(split, iterator): return func(split, prev_func(split, iterator)) @@ -737,7 +737,6 @@ class PipelinedRDD(RDD): self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val - @property def _is_pipelinable(self): return not (self.is_cached or self.is_checkpointed) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 83283fca4f..b0a403b580 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2,7 +2,6 @@ Unit tests for PySpark; additional tests are implemented as doctests in individual modules. """ -import atexit import os import shutil from tempfile import NamedTemporaryFile @@ -16,18 +15,18 @@ class TestCheckpoint(unittest.TestCase): def setUp(self): self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2) + self.checkpointDir = NamedTemporaryFile(delete=False) + os.unlink(self.checkpointDir.name) + self.sc.setCheckpointDir(self.checkpointDir.name) def tearDown(self): self.sc.stop() # To avoid Akka rebinding to the same port, since it doesn't unbind # immediately on shutdown self.sc.jvm.System.clearProperty("spark.master.port") + shutil.rmtree(self.checkpointDir.name) def test_basic_checkpointing(self): - checkpointDir = NamedTemporaryFile(delete=False) - os.unlink(checkpointDir.name) - self.sc.setCheckpointDir(checkpointDir.name) - parCollection = self.sc.parallelize([1, 2, 3, 4]) flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) @@ -39,16 +38,10 @@ class TestCheckpoint(unittest.TestCase): time.sleep(1) # 1 second self.assertTrue(flatMappedRDD.isCheckpointed()) self.assertEqual(flatMappedRDD.collect(), result) - self.assertEqual(checkpointDir.name, + self.assertEqual(self.checkpointDir.name, os.path.dirname(flatMappedRDD.getCheckpointFile())) - atexit.register(lambda: shutil.rmtree(checkpointDir.name)) - def test_checkpoint_and_restore(self): - checkpointDir = NamedTemporaryFile(delete=False) - os.unlink(checkpointDir.name) - self.sc.setCheckpointDir(checkpointDir.name) - parCollection = self.sc.parallelize([1, 2, 3, 4]) flatMappedRDD = parCollection.flatMap(lambda x: [x]) @@ -63,8 +56,6 @@ class TestCheckpoint(unittest.TestCase): recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile()) self.assertEquals([1, 2, 3, 4], recovered.collect()) - atexit.register(lambda: shutil.rmtree(checkpointDir.name)) - if __name__ == "__main__": unittest.main() -- cgit v1.2.3 From 9f211dd3f0132daf72fb39883fa4b28e4fd547ca Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 14 Jan 2013 15:30:42 -0800 Subject: Fix PythonPartitioner equality; see SPARK-654. PythonPartitioner did not take the Python-side partitioning function into account when checking for equality, which might cause problems in the future. --- .../main/scala/spark/api/python/PythonPartitioner.scala | 13 +++++++++++-- core/src/main/scala/spark/api/python/PythonRDD.scala | 5 ----- python/pyspark/rdd.py | 17 +++++++++++------ 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala index 648d9402b0..519e310323 100644 --- a/core/src/main/scala/spark/api/python/PythonPartitioner.scala +++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala @@ -6,8 +6,17 @@ import java.util.Arrays /** * A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API. + * + * Stores the unique id() of the Python-side partitioning function so that it is incorporated into + * equality comparisons. Correctness requires that the id is a unique identifier for the + * lifetime of the job (i.e. that it is not re-used as the id of a different partitioning + * function). This can be ensured by using the Python id() function and maintaining a reference + * to the Python partitioning function so that its id() is not reused. */ -private[spark] class PythonPartitioner(override val numPartitions: Int) extends Partitioner { +private[spark] class PythonPartitioner( + override val numPartitions: Int, + val pyPartitionFunctionId: Long) + extends Partitioner { override def getPartition(key: Any): Int = { if (key == null) { @@ -32,7 +41,7 @@ private[spark] class PythonPartitioner(override val numPartitions: Int) extends override def equals(other: Any): Boolean = other match { case h: PythonPartitioner => - h.numPartitions == numPartitions + h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId case _ => false } diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 89f7c316dc..e4c0530241 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -252,11 +252,6 @@ private object Pickle { val APPENDS: Byte = 'e' } -private class ExtractValue extends spark.api.java.function.Function[(Array[Byte], - Array[Byte]), Array[Byte]] { - override def call(pair: (Array[Byte], Array[Byte])) : Array[Byte] = pair._2 -} - private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] { override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d705f0f9e1..b58bf24e3e 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -33,6 +33,7 @@ class RDD(object): self._jrdd = jrdd self.is_cached = False self.ctx = ctx + self._partitionFunc = None @property def context(self): @@ -497,7 +498,7 @@ class RDD(object): return python_right_outer_join(self, other, numSplits) # TODO: add option to control map-side combining - def partitionBy(self, numSplits, hashFunc=hash): + def partitionBy(self, numSplits, partitionFunc=hash): """ Return a copy of the RDD partitioned using the specified partitioner. @@ -514,17 +515,21 @@ class RDD(object): def add_shuffle_key(split, iterator): buckets = defaultdict(list) for (k, v) in iterator: - buckets[hashFunc(k) % numSplits].append((k, v)) + buckets[partitionFunc(k) % numSplits].append((k, v)) for (split, items) in buckets.iteritems(): yield str(split) yield dump_pickle(Batch(items)) keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() - partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits) - jrdd = pairRDD.partitionBy(partitioner) - jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) - return RDD(jrdd, self.ctx) + partitioner = self.ctx.jvm.PythonPartitioner(numSplits, + id(partitionFunc)) + jrdd = pairRDD.partitionBy(partitioner).values() + rdd = RDD(jrdd, self.ctx) + # This is required so that id(partitionFunc) remains unique, even if + # partitionFunc is a lambda: + rdd._partitionFunc = partitionFunc + return rdd # TODO: add control over map-side aggregation def combineByKey(self, createCombiner, mergeValue, mergeCombiners, -- cgit v1.2.3 From 6e3754bf4759ab3e1e1be978b6b84e6f17742106 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 20 Jan 2013 19:22:24 -0800 Subject: Add Maven build file for streaming, and fix some issues in SBT file As part of this, changed our Scala 2.9.2 Kafka library to be available as a local Maven repository, following the example in (http://blog.dub.podval.org/2010/01/maven-in-project-repository.html) --- examples/pom.xml | 17 +++ pom.xml | 12 ++ project/SparkBuild.scala | 16 ++- repl/pom.xml | 14 ++ streaming/lib/kafka-0.7.2.jar | Bin 1358063 -> 0 bytes .../kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar | Bin 0 -> 1358063 bytes .../kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5 | 1 + .../kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1 | 1 + .../kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom | 9 ++ .../kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5 | 1 + .../kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1 | 1 + .../apache/kafka/kafka/maven-metadata-local.xml | 12 ++ .../kafka/kafka/maven-metadata-local.xml.md5 | 1 + .../kafka/kafka/maven-metadata-local.xml.sha1 | 1 + streaming/pom.xml | 155 +++++++++++++++++++++ 15 files changed, 234 insertions(+), 7 deletions(-) delete mode 100644 streaming/lib/kafka-0.7.2.jar create mode 100644 streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar create mode 100644 streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5 create mode 100644 streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1 create mode 100644 streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom create mode 100644 streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5 create mode 100644 streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1 create mode 100644 streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml create mode 100644 streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5 create mode 100644 streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1 create mode 100644 streaming/pom.xml diff --git a/examples/pom.xml b/examples/pom.xml index 3355deb6b7..4d43103475 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -19,6 +19,11 @@ org.eclipse.jetty jetty-server + + org.twitter4j + twitter4j-stream + 3.0.3 + org.scalatest @@ -57,6 +62,12 @@ ${project.version} hadoop1 + + org.spark-project + spark-streaming + ${project.version} + hadoop1 + org.apache.hadoop hadoop-core @@ -90,6 +101,12 @@ ${project.version} hadoop2 + + org.spark-project + spark-streaming + ${project.version} + hadoop2 + org.apache.hadoop hadoop-core diff --git a/pom.xml b/pom.xml index 751189a9d8..483b0f9595 100644 --- a/pom.xml +++ b/pom.xml @@ -41,6 +41,7 @@ core bagel examples + streaming repl repl-bin @@ -104,6 +105,17 @@ false + + twitter4j-repo + Twitter4J Repository + http://twitter4j.org/maven2/ + + true + + + false + + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 3dbb993f9c..03b8094f7d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -21,7 +21,7 @@ object SparkBuild extends Build { lazy val core = Project("core", file("core"), settings = coreSettings) - lazy val repl = Project("repl", file("repl"), settings = replSettings) dependsOn (core) + lazy val repl = Project("repl", file("repl"), settings = replSettings) dependsOn (core) dependsOn (streaming) lazy val examples = Project("examples", file("examples"), settings = examplesSettings) dependsOn (core) dependsOn (streaming) @@ -92,8 +92,7 @@ object SparkBuild extends Build { "org.eclipse.jetty" % "jetty-server" % "7.5.3.v20111011", "org.scalatest" %% "scalatest" % "1.8" % "test", "org.scalacheck" %% "scalacheck" % "1.9" % "test", - "com.novocode" % "junit-interface" % "0.8" % "test", - "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" + "com.novocode" % "junit-interface" % "0.8" % "test" ), parallelExecution := false, /* Workaround for issue #206 (fixed after SBT 0.11.0) */ @@ -136,8 +135,6 @@ object SparkBuild extends Build { "com.typesafe.akka" % "akka-slf4j" % "2.0.3", "it.unimi.dsi" % "fastutil" % "6.4.4", "colt" % "colt" % "1.2.0", - "org.twitter4j" % "twitter4j-core" % "3.0.2", - "org.twitter4j" % "twitter4j-stream" % "3.0.2", "cc.spray" % "spray-can" % "1.0-M2.1", "cc.spray" % "spray-server" % "1.0-M2.1", "cc.spray" %% "spray-json" % "1.1.1", @@ -156,7 +153,10 @@ object SparkBuild extends Build { ) def examplesSettings = sharedSettings ++ Seq( - name := "spark-examples" + name := "spark-examples", + libraryDependencies ++= Seq( + "org.twitter4j" % "twitter4j-stream" % "3.0.3" + ) ) def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel") @@ -164,7 +164,9 @@ object SparkBuild extends Build { def streamingSettings = sharedSettings ++ Seq( name := "spark-streaming", libraryDependencies ++= Seq( - "com.github.sgroschupf" % "zkclient" % "0.1") + "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile", + "com.github.sgroschupf" % "zkclient" % "0.1" + ) ) ++ assemblySettings ++ extraAssemblySettings def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq( diff --git a/repl/pom.xml b/repl/pom.xml index 38e883c7f8..2fc9692969 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -101,6 +101,13 @@ hadoop1 runtime + + org.spark-project + spark-streaming + ${project.version} + hadoop1 + runtime + org.apache.hadoop hadoop-core @@ -151,6 +158,13 @@ hadoop2 runtime + + org.spark-project + spark-streaming + ${project.version} + hadoop2 + runtime + org.apache.hadoop hadoop-core diff --git a/streaming/lib/kafka-0.7.2.jar b/streaming/lib/kafka-0.7.2.jar deleted file mode 100644 index 65f79925a4..0000000000 Binary files a/streaming/lib/kafka-0.7.2.jar and /dev/null differ diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar new file mode 100644 index 0000000000..65f79925a4 Binary files /dev/null and b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar differ diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5 new file mode 100644 index 0000000000..29f45f4adb --- /dev/null +++ b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5 @@ -0,0 +1 @@ +18876b8bc2e4cef28b6d191aa49d963f \ No newline at end of file diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1 new file mode 100644 index 0000000000..e3bd62bac0 --- /dev/null +++ b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1 @@ -0,0 +1 @@ +06b27270ffa52250a2c08703b397c99127b72060 \ No newline at end of file diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom new file mode 100644 index 0000000000..082d35726a --- /dev/null +++ b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom @@ -0,0 +1,9 @@ + + + 4.0.0 + org.apache.kafka + kafka + 0.7.2-spark + POM was created from install:install-file + diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5 new file mode 100644 index 0000000000..92c4132b5b --- /dev/null +++ b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5 @@ -0,0 +1 @@ +7bc4322266e6032bdf9ef6eebdd8097d \ No newline at end of file diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1 new file mode 100644 index 0000000000..8a1d8a097a --- /dev/null +++ b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1 @@ -0,0 +1 @@ +d0f79e8eff0db43ca7bcf7dce2c8cd2972685c9d \ No newline at end of file diff --git a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml new file mode 100644 index 0000000000..720cd51c2f --- /dev/null +++ b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml @@ -0,0 +1,12 @@ + + + org.apache.kafka + kafka + + 0.7.2-spark + + 0.7.2-spark + + 20130121015225 + + diff --git a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5 b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5 new file mode 100644 index 0000000000..a4ce5dc9e8 --- /dev/null +++ b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5 @@ -0,0 +1 @@ +e2b9c7c5f6370dd1d21a0aae5e8dcd77 \ No newline at end of file diff --git a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1 b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1 new file mode 100644 index 0000000000..b869eaf2a6 --- /dev/null +++ b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1 @@ -0,0 +1 @@ +2a4341da936b6c07a09383d17ffb185ac558ee91 \ No newline at end of file diff --git a/streaming/pom.xml b/streaming/pom.xml new file mode 100644 index 0000000000..3dae815e1a --- /dev/null +++ b/streaming/pom.xml @@ -0,0 +1,155 @@ + + + 4.0.0 + + org.spark-project + parent + 0.7.0-SNAPSHOT + ../pom.xml + + + org.spark-project + spark-streaming + jar + Spark Project Streaming + http://spark-project.org/ + + + + + lib + file://${project.basedir}/lib + + + + + + org.eclipse.jetty + jetty-server + + + org.codehaus.jackson + jackson-mapper-asl + 1.9.11 + + + org.apache.kafka + kafka + 0.7.2-spark + + + org.apache.flume + flume-ng-sdk + 1.2.0 + + + com.github.sgroschupf + zkclient + 0.1 + + + + org.scalatest + scalatest_${scala.version} + test + + + org.scalacheck + scalacheck_${scala.version} + test + + + com.novocode + junit-interface + test + + + org.slf4j + slf4j-log4j12 + test + + + + target/scala-${scala.version}/classes + target/scala-${scala.version}/test-classes + + + org.scalatest + scalatest-maven-plugin + + + + + + + hadoop1 + + + !hadoopVersion + + + + + org.spark-project + spark-core + ${project.version} + hadoop1 + + + org.apache.hadoop + hadoop-core + provided + + + + + + org.apache.maven.plugins + maven-jar-plugin + + hadoop1 + + + + + + + hadoop2 + + + hadoopVersion + 2 + + + + + org.spark-project + spark-core + ${project.version} + hadoop2 + + + org.apache.hadoop + hadoop-core + provided + + + org.apache.hadoop + hadoop-client + provided + + + + + + org.apache.maven.plugins + maven-jar-plugin + + hadoop2 + + + + + + + -- cgit v1.2.3 From 4750907c3dad4c275d0f51937a098ba856098b96 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 20 Jan 2013 21:05:17 -0800 Subject: Update run script to deal with change to build of REPL shaded JAR --- run | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/run b/run index 060856007f..a094629449 100755 --- a/run +++ b/run @@ -89,9 +89,11 @@ if [ -e "$FWDIR/lib_managed" ]; then CLASSPATH+=":$FWDIR/lib_managed/bundles/*" fi CLASSPATH+=":$REPL_DIR/lib/*" -for jar in `find "$REPL_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do - CLASSPATH+=":$jar" -done +if [ -e repl-bin/target ]; then + for jar in `find "repl-bin/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do + CLASSPATH+=":$jar" + done +fi CLASSPATH+=":$BAGEL_DIR/target/scala-$SCALA_VERSION/classes" for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do CLASSPATH+=":$jar" -- cgit v1.2.3 From c0b9ceb8c3d56c6d6f6f6b5925c87abad06be646 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 21 Jan 2013 00:23:53 -0800 Subject: Log remote lifecycle events in Akka for easier debugging --- core/src/main/scala/spark/util/AkkaUtils.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index e67cb0336d..fbd0ff46bf 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -32,6 +32,7 @@ private[spark] object AkkaUtils { akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"] akka.actor.provider = "akka.remote.RemoteActorRefProvider" akka.remote.transport = "akka.remote.netty.NettyRemoteTransport" + akka.remote.log-remote-lifecycle-events = on akka.remote.netty.hostname = "%s" akka.remote.netty.port = %d akka.remote.netty.connection-timeout = %ds -- cgit v1.2.3 From 69a417858bf1627de5220d41afba64853d4bf64d Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Mon, 21 Jan 2013 12:42:11 -0600 Subject: Also use hadoopConfiguration in newAPI methods. --- core/src/main/scala/spark/PairRDDFunctions.scala | 4 ++-- core/src/main/scala/spark/SparkContext.scala | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 51c15837c4..1c18736805 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -494,7 +494,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) { - saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, new Configuration) + saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass) } /** @@ -506,7 +506,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: NewOutputFormat[_, _]], - conf: Configuration) { + conf: Configuration = self.context.hadoopConfiguration) { val job = new NewAPIHadoopJob(conf) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index f6b98c41bc..303e5081a4 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -293,8 +293,7 @@ class SparkContext( path, fm.erasure.asInstanceOf[Class[F]], km.erasure.asInstanceOf[Class[K]], - vm.erasure.asInstanceOf[Class[V]], - new Configuration(hadoopConfiguration)) + vm.erasure.asInstanceOf[Class[V]]) } /** @@ -306,7 +305,7 @@ class SparkContext( fClass: Class[F], kClass: Class[K], vClass: Class[V], - conf: Configuration): RDD[(K, V)] = { + conf: Configuration = hadoopConfiguration): RDD[(K, V)] = { val job = new NewHadoopJob(conf) NewFileInputFormat.addInputPath(job, new Path(path)) val updatedConf = job.getConfiguration @@ -318,7 +317,7 @@ class SparkContext( * and extra configuration options to pass to the input format. */ def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]]( - conf: Configuration, + conf: Configuration = hadoopConfiguration, fClass: Class[F], kClass: Class[K], vClass: Class[V]): RDD[(K, V)] = { -- cgit v1.2.3 From f116d6b5c6029c2f96160bd84829a6fe8b73cccf Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 18 Jan 2013 13:24:37 -0800 Subject: executor can use a different sparkHome from Worker --- core/src/main/scala/spark/deploy/DeployMessage.scala | 4 +++- core/src/main/scala/spark/deploy/JobDescription.scala | 5 ++++- core/src/main/scala/spark/deploy/client/TestClient.scala | 3 ++- core/src/main/scala/spark/deploy/master/Master.scala | 9 +++++---- core/src/main/scala/spark/deploy/worker/Worker.scala | 4 ++-- .../spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 3 ++- 6 files changed, 18 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala index 457122745b..7ee3e63429 100644 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/spark/deploy/DeployMessage.scala @@ -5,6 +5,7 @@ import spark.deploy.master.{WorkerInfo, JobInfo} import spark.deploy.worker.ExecutorRunner import scala.collection.immutable.List import scala.collection.mutable.HashMap +import java.io.File private[spark] sealed trait DeployMessage extends Serializable @@ -42,7 +43,8 @@ private[spark] case class LaunchExecutor( execId: Int, jobDesc: JobDescription, cores: Int, - memory: Int) + memory: Int, + sparkHome: File) extends DeployMessage diff --git a/core/src/main/scala/spark/deploy/JobDescription.scala b/core/src/main/scala/spark/deploy/JobDescription.scala index 20879c5f11..7f8f9af417 100644 --- a/core/src/main/scala/spark/deploy/JobDescription.scala +++ b/core/src/main/scala/spark/deploy/JobDescription.scala @@ -1,10 +1,13 @@ package spark.deploy +import java.io.File + private[spark] class JobDescription( val name: String, val cores: Int, val memoryPerSlave: Int, - val command: Command) + val command: Command, + val sparkHome: File) extends Serializable { val user = System.getProperty("user.name", "") diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala index 57a7e123b7..dc743b1fbf 100644 --- a/core/src/main/scala/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/spark/deploy/client/TestClient.scala @@ -3,6 +3,7 @@ package spark.deploy.client import spark.util.AkkaUtils import spark.{Logging, Utils} import spark.deploy.{Command, JobDescription} +import java.io.File private[spark] object TestClient { @@ -25,7 +26,7 @@ private[spark] object TestClient { val url = args(0) val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0) val desc = new JobDescription( - "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map())) + "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), new File("dummy-spark-home")) val listener = new TestListener val client = new Client(actorSystem, url, desc, listener) client.start() diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 6ecebe626a..f0bee67159 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -6,6 +6,7 @@ import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, Remote import java.text.SimpleDateFormat import java.util.Date +import java.io.File import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -173,7 +174,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor for (pos <- 0 until numUsable) { if (assigned(pos) > 0) { val exec = job.addExecutor(usableWorkers(pos), assigned(pos)) - launchExecutor(usableWorkers(pos), exec) + launchExecutor(usableWorkers(pos), exec, job.desc.sparkHome) job.state = JobState.RUNNING } } @@ -186,7 +187,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor val coresToUse = math.min(worker.coresFree, job.coresLeft) if (coresToUse > 0) { val exec = job.addExecutor(worker, coresToUse) - launchExecutor(worker, exec) + launchExecutor(worker, exec, job.desc.sparkHome) job.state = JobState.RUNNING } } @@ -195,10 +196,10 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor } } - def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo) { + def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: File) { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory) + worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome) exec.job.actor ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) } diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 7c9e588ea2..078b2d8037 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -119,10 +119,10 @@ private[spark] class Worker( logError("Worker registration failed: " + message) System.exit(1) - case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_) => + case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_, execSparkHome_) => logInfo("Asked to launch executor %s/%d for %s".format(jobId, execId, jobDesc.name)) val manager = new ExecutorRunner( - jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, sparkHome, workDir) + jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, execSparkHome_, workDir) executors(jobId + "/" + execId) = manager manager.start() coresUsed += cores_ diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index e2301347e5..0dcc2efaca 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -4,6 +4,7 @@ import spark.{Utils, Logging, SparkContext} import spark.deploy.client.{Client, ClientListener} import spark.deploy.{Command, JobDescription} import scala.collection.mutable.HashMap +import java.io.File private[spark] class SparkDeploySchedulerBackend( scheduler: ClusterScheduler, @@ -39,7 +40,7 @@ private[spark] class SparkDeploySchedulerBackend( StandaloneSchedulerBackend.ACTOR_NAME) val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) - val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command) + val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sc.sparkHome)) client = new Client(sc.env.actorSystem, master, jobDesc, this) client.start() -- cgit v1.2.3 From aae5a920a4db0c31918a65a03ce7d2087826fd65 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 18 Jan 2013 13:28:50 -0800 Subject: get sparkHome the correct way --- .../scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 0dcc2efaca..08b9d6ff47 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -40,7 +40,7 @@ private[spark] class SparkDeploySchedulerBackend( StandaloneSchedulerBackend.ACTOR_NAME) val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) - val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sc.sparkHome)) + val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sc.getSparkHome())) client = new Client(sc.env.actorSystem, master, jobDesc, this) client.start() -- cgit v1.2.3 From 5bf73df7f08b17719711a5f05f0b3390b4951272 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Sat, 19 Jan 2013 13:26:15 -0800 Subject: oops, fix stupid compile error --- .../scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 08b9d6ff47..94886d3941 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -40,7 +40,8 @@ private[spark] class SparkDeploySchedulerBackend( StandaloneSchedulerBackend.ACTOR_NAME) val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) - val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sc.getSparkHome())) + val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone")) + val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sparkHome)) client = new Client(sc.env.actorSystem, master, jobDesc, this) client.start() -- cgit v1.2.3 From c73107500e0a5b6c5f0b4aba8c4504ee4c2adbaf Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Sun, 20 Jan 2013 21:55:50 -0800 Subject: send sparkHome as String instead of File over network --- core/src/main/scala/spark/deploy/DeployMessage.scala | 2 +- core/src/main/scala/spark/deploy/master/Master.scala | 2 +- core/src/main/scala/spark/deploy/worker/Worker.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala index 7ee3e63429..a4081ef89c 100644 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/spark/deploy/DeployMessage.scala @@ -44,7 +44,7 @@ private[spark] case class LaunchExecutor( jobDesc: JobDescription, cores: Int, memory: Int, - sparkHome: File) + sparkHome: String) extends DeployMessage diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index f0bee67159..1b6f808a51 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -199,7 +199,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: File) { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome) + worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome.getAbsolutePath) exec.job.actor ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) } diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 078b2d8037..19bf2be118 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -122,7 +122,7 @@ private[spark] class Worker( case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_, execSparkHome_) => logInfo("Asked to launch executor %s/%d for %s".format(jobId, execId, jobDesc.name)) val manager = new ExecutorRunner( - jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, execSparkHome_, workDir) + jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir) executors(jobId + "/" + execId) = manager manager.start() coresUsed += cores_ -- cgit v1.2.3 From fe26acc482f358bf87700f5e80160f7ce558cea7 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Sun, 20 Jan 2013 21:57:44 -0800 Subject: remove unused imports --- core/src/main/scala/spark/deploy/DeployMessage.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala index a4081ef89c..35f40c6e91 100644 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/spark/deploy/DeployMessage.scala @@ -4,8 +4,6 @@ import spark.deploy.ExecutorState.ExecutorState import spark.deploy.master.{WorkerInfo, JobInfo} import spark.deploy.worker.ExecutorRunner import scala.collection.immutable.List -import scala.collection.mutable.HashMap -import java.io.File private[spark] sealed trait DeployMessage extends Serializable -- cgit v1.2.3 From a3f571b539ffd126e9f3bc3e9c7bedfcb6f4d2d4 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 21 Jan 2013 10:52:17 -0800 Subject: more File -> String changes --- core/src/main/scala/spark/deploy/JobDescription.scala | 4 +--- core/src/main/scala/spark/deploy/client/TestClient.scala | 3 +-- core/src/main/scala/spark/deploy/master/Master.scala | 5 ++--- .../scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 4 +--- 4 files changed, 5 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/spark/deploy/JobDescription.scala b/core/src/main/scala/spark/deploy/JobDescription.scala index 7f8f9af417..7160fc05fc 100644 --- a/core/src/main/scala/spark/deploy/JobDescription.scala +++ b/core/src/main/scala/spark/deploy/JobDescription.scala @@ -1,13 +1,11 @@ package spark.deploy -import java.io.File - private[spark] class JobDescription( val name: String, val cores: Int, val memoryPerSlave: Int, val command: Command, - val sparkHome: File) + val sparkHome: String) extends Serializable { val user = System.getProperty("user.name", "") diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala index dc743b1fbf..8764c400e2 100644 --- a/core/src/main/scala/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/spark/deploy/client/TestClient.scala @@ -3,7 +3,6 @@ package spark.deploy.client import spark.util.AkkaUtils import spark.{Logging, Utils} import spark.deploy.{Command, JobDescription} -import java.io.File private[spark] object TestClient { @@ -26,7 +25,7 @@ private[spark] object TestClient { val url = args(0) val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0) val desc = new JobDescription( - "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), new File("dummy-spark-home")) + "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home") val listener = new TestListener val client = new Client(actorSystem, url, desc, listener) client.start() diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 1b6f808a51..2c2cd0231b 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -6,7 +6,6 @@ import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, Remote import java.text.SimpleDateFormat import java.util.Date -import java.io.File import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -196,10 +195,10 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor } } - def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: File) { + def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome.getAbsolutePath) + worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome) exec.job.actor ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) } diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 94886d3941..a21a5b2f3d 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -3,8 +3,6 @@ package spark.scheduler.cluster import spark.{Utils, Logging, SparkContext} import spark.deploy.client.{Client, ClientListener} import spark.deploy.{Command, JobDescription} -import scala.collection.mutable.HashMap -import java.io.File private[spark] class SparkDeploySchedulerBackend( scheduler: ClusterScheduler, @@ -41,7 +39,7 @@ private[spark] class SparkDeploySchedulerBackend( val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone")) - val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sparkHome)) + val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, sparkHome) client = new Client(sc.env.actorSystem, master, jobDesc, this) client.start() -- cgit v1.2.3 From 4d34c7fc3ecd7a4d035005f84c01e6990c0c345e Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 21 Jan 2013 11:33:48 -0800 Subject: Fix compile error caused by cherry-pick --- .../main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index a21a5b2f3d..4f82cd96dd 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -3,6 +3,7 @@ package spark.scheduler.cluster import spark.{Utils, Logging, SparkContext} import spark.deploy.client.{Client, ClientListener} import spark.deploy.{Command, JobDescription} +import scala.collection.mutable.HashMap private[spark] class SparkDeploySchedulerBackend( scheduler: ClusterScheduler, -- cgit v1.2.3 From a88b44ed3b670633549049e9ccf990ea455e9720 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 21 Jan 2013 11:59:21 -0800 Subject: Only bind to IPv4 addresses when trying to auto-detect external IP --- core/src/main/scala/spark/Utils.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index b3421df27c..692a3f4050 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -1,7 +1,7 @@ package spark import java.io._ -import java.net.{NetworkInterface, InetAddress, URL, URI} +import java.net.{NetworkInterface, InetAddress, Inet4Address, URL, URI} import java.util.{Locale, Random, UUID} import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} import org.apache.hadoop.conf.Configuration @@ -251,7 +251,8 @@ private object Utils extends Logging { // Address resolves to something like 127.0.1.1, which happens on Debian; try to find // a better address using the local network interfaces for (ni <- NetworkInterface.getNetworkInterfaces) { - for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && !addr.isLoopbackAddress) { + for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && + !addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) { // We've found an address that looks reasonable! logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + " a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress + -- cgit v1.2.3 From 2173f6c7cac877a3b756d63aabf7bdd06a18e6d9 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 21 Jan 2013 13:02:40 -0800 Subject: Clarify the documentation on env variables for standalone mode --- docs/spark-standalone.md | 43 +++++++++++++++++++++---------------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index e0ba7c35cb..bf296221b8 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -51,11 +51,11 @@ Finally, the following configuration options can be passed to the master and wor -c CORES, --cores CORES - Number of CPU cores to use (default: all available); only on worker + Total CPU cores to allow Spark jobs to use on the machine (default: all available); only on worker -m MEM, --memory MEM - Amount of memory to use, in a format like 1000M or 2G (default: your machine's total RAM minus 1 GB); only on worker + Total amount of memory to allow Spark jobs to use on the machine, in a format like 1000M or 2G (default: your machine's total RAM minus 1 GB); only on worker -d DIR, --work-dir DIR @@ -66,9 +66,20 @@ Finally, the following configuration options can be passed to the master and wor # Cluster Launch Scripts -To launch a Spark standalone cluster with the deploy scripts, you need to set up two files, `conf/spark-env.sh` and `conf/slaves`. The `conf/spark-env.sh` file lets you specify global settings for the master and slave instances, such as memory, or port numbers to bind to, while `conf/slaves` is a list of slave nodes. The system requires that all the slave machines have the same configuration files, so *copy these files to each machine*. +To launch a Spark standalone cluster with the deploy scripts, you need to create a file called `conf/slaves` in your Spark directory, which should contain the hostnames of all the machines where you would like to start Spark workers, one per line. The master machine must be able to access each of the slave machines via password-less `ssh` (using a private key). For testing, you can just put `localhost` in this file. -In `conf/spark-env.sh`, you can set the following parameters, in addition to the [standard Spark configuration settings](configuration.html): +Once you've set up this fine, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/bin`: + +- `bin/start-master.sh` - Starts a master instance on the machine the script is executed on. +- `bin/start-slaves.sh` - Starts a slave instance on each machine specified in the `conf/slaves` file. +- `bin/start-all.sh` - Starts both a master and a number of slaves as described above. +- `bin/stop-master.sh` - Stops the master that was started via the `bin/start-master.sh` script. +- `bin/stop-slaves.sh` - Stops the slave instances that were started via `bin/start-slaves.sh`. +- `bin/stop-all.sh` - Stops both the master and the slaves as described above. + +Note that these scripts must be executed on the machine you want to run the Spark master on, not your local machine. + +You can optionally configure the cluster further by setting environment variables in `conf/spark-env.sh`. Create this file by starting with the `conf/spark-env.sh.template`, and _copy it to all your worker machines_ for the settings to take effect. The following settings are available: @@ -88,36 +99,24 @@ In `conf/spark-env.sh`, you can set the following parameters, in addition to the + + + + - + - + - - - -
    Environment VariableMeaning
    SPARK_WORKER_PORT Start the Spark worker on a specific port (default: random)
    SPARK_WORKER_DIRDirectory to run jobs in, which will include both logs and scratch space (default: SPARK_HOME/work)
    SPARK_WORKER_CORESNumber of cores to use (default: all available cores)Total number of cores to allow Spark jobs to use on the machine (default: all available cores)
    SPARK_WORKER_MEMORYHow much memory to use, e.g. 1000M, 2G (default: total memory minus 1 GB)Total amount of memory to allow Spark jobs to use on the machine, e.g. 1000M, 2G (default: total memory minus 1 GB); note that each job's individual memory is configured using SPARK_MEM
    SPARK_WORKER_WEBUI_PORT Port for the worker web UI (default: 8081)
    SPARK_WORKER_DIRDirectory to run jobs in, which will include both logs and scratch space (default: SPARK_HOME/work)
    -In `conf/slaves`, include a list of all machines where you would like to start a Spark worker, one per line. The master machine must be able to access each of the slave machines via password-less `ssh` (using a private key). For testing purposes, you can have a single `localhost` entry in the slaves file. - -Once you've set up these configuration files, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/bin`: - -- `bin/start-master.sh` - Starts a master instance on the machine the script is executed on. -- `bin/start-slaves.sh` - Starts a slave instance on each machine specified in the `conf/slaves` file. -- `bin/start-all.sh` - Starts both a master and a number of slaves as described above. -- `bin/stop-master.sh` - Stops the master that was started via the `bin/start-master.sh` script. -- `bin/stop-slaves.sh` - Stops the slave instances that were started via `bin/start-slaves.sh`. -- `bin/stop-all.sh` - Stops both the master and the slaves as described above. - -Note that the scripts must be executed on the machine you want to run the Spark master on, not your local machine. # Connecting a Job to the Cluster -- cgit v1.2.3 From 76d7c0ce2bd9c4d5782fec320279e0a011230625 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 21 Jan 2013 13:10:02 -0800 Subject: Add more Akka settings to docs --- docs/configuration.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index 87cb4a6797..036a0df480 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -197,6 +197,15 @@ Apart from these, the following properties are also available, and may be useful poor data locality, but the default generally works well. + + spark.akka.frameSize + 10 + + Maximum message size to allow in "control plane" communication (for serialized tasks and task + results), in MB. Increase this if your tasks need to send back large results to the master + (e.g. using collect() on a large dataset). + + spark.akka.threads 4 @@ -205,6 +214,13 @@ Apart from these, the following properties are also available, and may be useful when the master has a lot of CPU cores. + + spark.akka.timeout + 20 + + Communication timeout between Spark nodes. + + spark.master.host (local hostname) -- cgit v1.2.3 From ffd1623595cdce4080ad1e4e676e65898ebdd6dd Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Mon, 21 Jan 2013 15:55:46 -0600 Subject: Minor cleanup. --- core/src/main/scala/spark/Accumulators.scala | 3 +-- core/src/main/scala/spark/Logging.scala | 3 +-- core/src/main/scala/spark/ParallelCollection.scala | 15 +++++---------- core/src/main/scala/spark/TaskContext.scala | 3 +-- core/src/main/scala/spark/rdd/BlockRDD.scala | 6 ++---- core/src/main/scala/spark/rdd/CartesianRDD.scala | 3 +-- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 6 ++---- core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 6 ++---- core/src/main/scala/spark/rdd/SampledRDD.scala | 5 ++--- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 3 +-- core/src/main/scala/spark/rdd/UnionRDD.scala | 3 +-- core/src/main/scala/spark/rdd/ZippedRDD.scala | 3 +-- .../scala/spark/scheduler/local/LocalScheduler.scala | 4 ++-- .../scheduler/mesos/CoarseMesosSchedulerBackend.scala | 16 ++++++---------- .../spark/scheduler/mesos/MesosSchedulerBackend.scala | 10 +++------- core/src/test/scala/spark/FileServerSuite.scala | 4 ++-- 16 files changed, 33 insertions(+), 60 deletions(-) diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala index b644aba5f8..57c6df35be 100644 --- a/core/src/main/scala/spark/Accumulators.scala +++ b/core/src/main/scala/spark/Accumulators.scala @@ -25,8 +25,7 @@ class Accumulable[R, T] ( extends Serializable { val id = Accumulators.newId - @transient - private var value_ = initialValue // Current value on master + @transient private var value_ = initialValue // Current value on master val zero = param.zero(initialValue) // Zero value to be passed to workers var deserialized = false diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala index 90bae26202..7c1c1bb144 100644 --- a/core/src/main/scala/spark/Logging.scala +++ b/core/src/main/scala/spark/Logging.scala @@ -11,8 +11,7 @@ import org.slf4j.LoggerFactory trait Logging { // Make the log field transient so that objects with Logging can // be serialized and used on another machine - @transient - private var log_ : Logger = null + @transient private var log_ : Logger = null // Method to get or create the logger for this object protected def log: Logger = { diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index ede933c9e9..ad23e5bec8 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -23,32 +23,28 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest]( } private[spark] class ParallelCollection[T: ClassManifest]( - @transient sc : SparkContext, + @transient sc: SparkContext, @transient data: Seq[T], numSlices: Int, - locationPrefs : Map[Int,Seq[String]]) + locationPrefs: Map[Int,Seq[String]]) extends RDD[T](sc, Nil) { // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split // instead. // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal. - @transient - var splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { val slices = ParallelCollection.slice(data, numSlices).toArray slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray } - override def getSplits = splits_.asInstanceOf[Array[Split]] + override def getSplits = splits_ override def compute(s: Split, context: TaskContext) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator override def getPreferredLocations(s: Split): Seq[String] = { - locationPrefs.get(s.index) match { - case Some(s) => s - case _ => Nil - } + locationPrefs.get(s.index) getOrElse Nil } override def clearDependencies() { @@ -56,7 +52,6 @@ private[spark] class ParallelCollection[T: ClassManifest]( } } - private object ParallelCollection { /** * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala index d2746b26b3..eab85f85a2 100644 --- a/core/src/main/scala/spark/TaskContext.scala +++ b/core/src/main/scala/spark/TaskContext.scala @@ -5,8 +5,7 @@ import scala.collection.mutable.ArrayBuffer class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable { - @transient - val onCompleteCallbacks = new ArrayBuffer[() => Unit] + @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit] // Add a callback function to be executed on task completion. An example use // is for HadoopRDD to register a callback to close the input stream. diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index b1095a52b4..2c022f88e0 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -11,13 +11,11 @@ private[spark] class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String]) extends RDD[T](sc, Nil) { - @transient - var splits_ : Array[Split] = (0 until blockIds.size).map(i => { + @transient var splits_ : Array[Split] = (0 until blockIds.size).map(i => { new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split] }).toArray - @transient - lazy val locations_ = { + @transient lazy val locations_ = { val blockManager = SparkEnv.get.blockManager /*val locations = blockIds.map(id => blockManager.getLocations(id))*/ val locations = blockManager.getLocations(blockIds) diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 79e7c24e7c..453d410ad4 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -35,8 +35,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( val numSplitsInRdd2 = rdd2.splits.size - @transient - var splits_ = { + @transient var splits_ = { // create the cross product split val array = new Array[Split](rdd1.splits.size * rdd2.splits.size) for (s1 <- rdd1.splits; s2 <- rdd2.splits) { diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 1d528be2aa..8fafd27bb6 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -45,8 +45,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) val aggr = new CoGroupAggregator - @transient - var deps_ = { + @transient var deps_ = { val deps = new ArrayBuffer[Dependency[_]] for ((rdd, index) <- rdds.zipWithIndex) { if (rdd.partitioner == Some(part)) { @@ -63,8 +62,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) override def getDependencies = deps_ - @transient - var splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { val array = new Array[Split](part.numPartitions) for (i <- 0 until array.size) { array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) => diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index bb22db073c..c3b155fcbd 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -37,11 +37,9 @@ class NewHadoopRDD[K, V]( formatter.format(new Date()) } - @transient - private val jobId = new JobID(jobtrackerId, id) + @transient private val jobId = new JobID(jobtrackerId, id) - @transient - private val splits_ : Array[Split] = { + @transient private val splits_ : Array[Split] = { val inputFormat = inputFormatClass.newInstance val jobContext = newJobContext(conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index 1bc9c96112..e24ad23b21 100644 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -19,13 +19,12 @@ class SampledRDD[T: ClassManifest]( seed: Int) extends RDD[T](prev) { - @transient - var splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { val rg = new Random(seed) firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt)) } - override def getSplits = splits_.asInstanceOf[Array[Split]] + override def getSplits = splits_ override def getPreferredLocations(split: Split) = firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev) diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 1b219473e0..28ff19876d 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -22,8 +22,7 @@ class ShuffledRDD[K, V]( override val partitioner = Some(part) - @transient - var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) + @transient var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) override def getSplits = splits_ diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index 24a085df02..82f0a44ecd 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -28,8 +28,7 @@ class UnionRDD[T: ClassManifest]( @transient var rdds: Seq[RDD[T]]) extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs - @transient - var splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { val array = new Array[Split](rdds.map(_.splits.size).sum) var pos = 0 for (rdd <- rdds; split <- rdd.splits) { diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala index 16e6cc0f1b..d950b06c85 100644 --- a/core/src/main/scala/spark/rdd/ZippedRDD.scala +++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala @@ -34,8 +34,7 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest]( // TODO: FIX THIS. - @transient - var splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { if (rdd1.splits.size != rdd2.splits.size) { throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") } diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index dff550036d..21d255debd 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -19,8 +19,8 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon extends TaskScheduler with Logging { - var attemptId = new AtomicInteger(0) - var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory) + val attemptId = new AtomicInteger(0) + val threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory) val env = SparkEnv.get var listener: TaskSchedulerListener = null diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala index c45c7df69c..014906b028 100644 --- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala @@ -64,13 +64,9 @@ private[spark] class CoarseMesosSchedulerBackend( val taskIdToSlaveId = new HashMap[Int, String] val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed - val sparkHome = sc.getSparkHome() match { - case Some(path) => - path - case None => - throw new SparkException("Spark home is not set; set it through the spark.home system " + - "property, the SPARK_HOME environment variable or the SparkContext constructor") - } + val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException( + "Spark home is not set; set it through the spark.home system " + + "property, the SPARK_HOME environment variable or the SparkContext constructor")) val extraCoresPerSlave = System.getProperty("spark.mesos.extra.cores", "0").toInt @@ -184,7 +180,7 @@ private[spark] class CoarseMesosSchedulerBackend( } /** Helper function to pull out a resource from a Mesos Resources protobuf */ - def getResource(res: JList[Resource], name: String): Double = { + private def getResource(res: JList[Resource], name: String): Double = { for (r <- res if r.getName == name) { return r.getScalar.getValue } @@ -193,7 +189,7 @@ private[spark] class CoarseMesosSchedulerBackend( } /** Build a Mesos resource protobuf object */ - def createResource(resourceName: String, quantity: Double): Protos.Resource = { + private def createResource(resourceName: String, quantity: Double): Protos.Resource = { Resource.newBuilder() .setName(resourceName) .setType(Value.Type.SCALAR) @@ -202,7 +198,7 @@ private[spark] class CoarseMesosSchedulerBackend( } /** Check whether a Mesos task state represents a finished task */ - def isFinished(state: MesosTaskState) = { + private def isFinished(state: MesosTaskState) = { state == MesosTaskState.TASK_FINISHED || state == MesosTaskState.TASK_FAILED || state == MesosTaskState.TASK_KILLED || diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala index 8c7a1dfbc0..2989e31f5e 100644 --- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala @@ -76,13 +76,9 @@ private[spark] class MesosSchedulerBackend( } def createExecutorInfo(): ExecutorInfo = { - val sparkHome = sc.getSparkHome() match { - case Some(path) => - path - case None => - throw new SparkException("Spark home is not set; set it through the spark.home system " + - "property, the SPARK_HOME environment variable or the SparkContext constructor") - } + val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException( + "Spark home is not set; set it through the spark.home system " + + "property, the SPARK_HOME environment variable or the SparkContext constructor")) val execScript = new File(sparkHome, "spark-executor").getCanonicalPath val environment = Environment.newBuilder() sc.executorEnvs.foreach { case (key, value) => diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala index b4283d9604..fe964bd893 100644 --- a/core/src/test/scala/spark/FileServerSuite.scala +++ b/core/src/test/scala/spark/FileServerSuite.scala @@ -9,8 +9,8 @@ import SparkContext._ class FileServerSuite extends FunSuite with BeforeAndAfter { @transient var sc: SparkContext = _ - @transient var tmpFile : File = _ - @transient var testJarFile : File = _ + @transient var tmpFile: File = _ + @transient var testJarFile: File = _ before { // Create a sample text file -- cgit v1.2.3 From e5ca2413352510297092384eda73049ad601fd8a Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Mon, 21 Jan 2013 16:06:58 -0600 Subject: Move JavaAPISuite into spark.streaming. --- streaming/src/test/java/JavaAPISuite.java | 1029 -------------------- streaming/src/test/java/JavaTestUtils.scala | 65 -- .../test/java/spark/streaming/JavaAPISuite.java | 1029 ++++++++++++++++++++ .../test/java/spark/streaming/JavaTestUtils.scala | 65 ++ 4 files changed, 1094 insertions(+), 1094 deletions(-) delete mode 100644 streaming/src/test/java/JavaAPISuite.java delete mode 100644 streaming/src/test/java/JavaTestUtils.scala create mode 100644 streaming/src/test/java/spark/streaming/JavaAPISuite.java create mode 100644 streaming/src/test/java/spark/streaming/JavaTestUtils.scala diff --git a/streaming/src/test/java/JavaAPISuite.java b/streaming/src/test/java/JavaAPISuite.java deleted file mode 100644 index c84e7331c7..0000000000 --- a/streaming/src/test/java/JavaAPISuite.java +++ /dev/null @@ -1,1029 +0,0 @@ -package spark.streaming; - -import com.google.common.base.Optional; -import com.google.common.collect.Lists; -import com.google.common.collect.Maps; -import com.google.common.io.Files; -import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import scala.Tuple2; -import spark.HashPartitioner; -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.*; -import spark.storage.StorageLevel; -import spark.streaming.api.java.JavaDStream; -import spark.streaming.api.java.JavaPairDStream; -import spark.streaming.api.java.JavaStreamingContext; -import spark.streaming.JavaTestUtils; -import spark.streaming.JavaCheckpointTestUtils; -import spark.streaming.dstream.KafkaPartitionKey; - -import java.io.*; -import java.util.*; - -// The test suite itself is Serializable so that anonymous Function implementations can be -// serialized, as an alternative to converting these anonymous classes to static inner classes; -// see http://stackoverflow.com/questions/758570/. -public class JavaAPISuite implements Serializable { - private transient JavaStreamingContext ssc; - - @Before - public void setUp() { - ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); - ssc.checkpoint("checkpoint", new Duration(1000)); - } - - @After - public void tearDown() { - ssc.stop(); - ssc = null; - - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port"); - } - - @Test - public void testCount() { - List> inputData = Arrays.asList( - Arrays.asList(1,2,3,4), - Arrays.asList(3,4,5), - Arrays.asList(3)); - - List> expected = Arrays.asList( - Arrays.asList(4L), - Arrays.asList(3L), - Arrays.asList(1L)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream count = stream.count(); - JavaTestUtils.attachTestOutputStream(count); - List> result = JavaTestUtils.runStreams(ssc, 3, 3); - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testMap() { - List> inputData = Arrays.asList( - Arrays.asList("hello", "world"), - Arrays.asList("goodnight", "moon")); - - List> expected = Arrays.asList( - Arrays.asList(5,5), - Arrays.asList(9,4)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream letterCount = stream.map(new Function() { - @Override - public Integer call(String s) throws Exception { - return s.length(); - } - }); - JavaTestUtils.attachTestOutputStream(letterCount); - List> result = JavaTestUtils.runStreams(ssc, 2, 2); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testWindow() { - List> inputData = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6), - Arrays.asList(7,8,9)); - - List> expected = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6,1,2,3), - Arrays.asList(7,8,9,4,5,6), - Arrays.asList(7,8,9)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream windowed = stream.window(new Duration(2000)); - JavaTestUtils.attachTestOutputStream(windowed); - List> result = JavaTestUtils.runStreams(ssc, 4, 4); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testWindowWithSlideDuration() { - List> inputData = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6), - Arrays.asList(7,8,9), - Arrays.asList(10,11,12), - Arrays.asList(13,14,15), - Arrays.asList(16,17,18)); - - List> expected = Arrays.asList( - Arrays.asList(1,2,3,4,5,6), - Arrays.asList(1,2,3,4,5,6,7,8,9,10,11,12), - Arrays.asList(7,8,9,10,11,12,13,14,15,16,17,18), - Arrays.asList(13,14,15,16,17,18)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream windowed = stream.window(new Duration(4000), new Duration(2000)); - JavaTestUtils.attachTestOutputStream(windowed); - List> result = JavaTestUtils.runStreams(ssc, 8, 4); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testTumble() { - List> inputData = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6), - Arrays.asList(7,8,9), - Arrays.asList(10,11,12), - Arrays.asList(13,14,15), - Arrays.asList(16,17,18)); - - List> expected = Arrays.asList( - Arrays.asList(1,2,3,4,5,6), - Arrays.asList(7,8,9,10,11,12), - Arrays.asList(13,14,15,16,17,18)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream windowed = stream.tumble(new Duration(2000)); - JavaTestUtils.attachTestOutputStream(windowed); - List> result = JavaTestUtils.runStreams(ssc, 6, 3); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testFilter() { - List> inputData = Arrays.asList( - Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); - - List> expected = Arrays.asList( - Arrays.asList("giants"), - Arrays.asList("yankees")); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream filtered = stream.filter(new Function() { - @Override - public Boolean call(String s) throws Exception { - return s.contains("a"); - } - }); - JavaTestUtils.attachTestOutputStream(filtered); - List> result = JavaTestUtils.runStreams(ssc, 2, 2); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testGlom() { - List> inputData = Arrays.asList( - Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); - - List>> expected = Arrays.asList( - Arrays.asList(Arrays.asList("giants", "dodgers")), - Arrays.asList(Arrays.asList("yankees", "red socks"))); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream glommed = stream.glom(); - JavaTestUtils.attachTestOutputStream(glommed); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testMapPartitions() { - List> inputData = Arrays.asList( - Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); - - List> expected = Arrays.asList( - Arrays.asList("GIANTSDODGERS"), - Arrays.asList("YANKEESRED SOCKS")); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream mapped = stream.mapPartitions(new FlatMapFunction, String>() { - @Override - public Iterable call(Iterator in) { - String out = ""; - while (in.hasNext()) { - out = out + in.next().toUpperCase(); - } - return Lists.newArrayList(out); - } - }); - JavaTestUtils.attachTestOutputStream(mapped); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - private class IntegerSum extends Function2 { - @Override - public Integer call(Integer i1, Integer i2) throws Exception { - return i1 + i2; - } - } - - private class IntegerDifference extends Function2 { - @Override - public Integer call(Integer i1, Integer i2) throws Exception { - return i1 - i2; - } - } - - @Test - public void testReduce() { - List> inputData = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6), - Arrays.asList(7,8,9)); - - List> expected = Arrays.asList( - Arrays.asList(6), - Arrays.asList(15), - Arrays.asList(24)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream reduced = stream.reduce(new IntegerSum()); - JavaTestUtils.attachTestOutputStream(reduced); - List> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testReduceByWindow() { - List> inputData = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6), - Arrays.asList(7,8,9)); - - List> expected = Arrays.asList( - Arrays.asList(6), - Arrays.asList(21), - Arrays.asList(39), - Arrays.asList(24)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream reducedWindowed = stream.reduceByWindow(new IntegerSum(), - new IntegerDifference(), new Duration(2000), new Duration(1000)); - JavaTestUtils.attachTestOutputStream(reducedWindowed); - List> result = JavaTestUtils.runStreams(ssc, 4, 4); - - Assert.assertEquals(expected, result); - } - - @Test - public void testQueueStream() { - List> expected = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6), - Arrays.asList(7,8,9)); - - JavaSparkContext jsc = new JavaSparkContext(ssc.ssc().sc()); - JavaRDD rdd1 = ssc.sc().parallelize(Arrays.asList(1,2,3)); - JavaRDD rdd2 = ssc.sc().parallelize(Arrays.asList(4,5,6)); - JavaRDD rdd3 = ssc.sc().parallelize(Arrays.asList(7,8,9)); - - LinkedList> rdds = Lists.newLinkedList(); - rdds.add(rdd1); - rdds.add(rdd2); - rdds.add(rdd3); - - JavaDStream stream = ssc.queueStream(rdds); - JavaTestUtils.attachTestOutputStream(stream); - List> result = JavaTestUtils.runStreams(ssc, 3, 3); - Assert.assertEquals(expected, result); - } - - @Test - public void testTransform() { - List> inputData = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6), - Arrays.asList(7,8,9)); - - List> expected = Arrays.asList( - Arrays.asList(3,4,5), - Arrays.asList(6,7,8), - Arrays.asList(9,10,11)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream transformed = stream.transform(new Function, JavaRDD>() { - @Override - public JavaRDD call(JavaRDD in) throws Exception { - return in.map(new Function() { - @Override - public Integer call(Integer i) throws Exception { - return i + 2; - } - }); - }}); - JavaTestUtils.attachTestOutputStream(transformed); - List> result = JavaTestUtils.runStreams(ssc, 3, 3); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testFlatMap() { - List> inputData = Arrays.asList( - Arrays.asList("go", "giants"), - Arrays.asList("boo", "dodgers"), - Arrays.asList("athletics")); - - List> expected = Arrays.asList( - Arrays.asList("g","o","g","i","a","n","t","s"), - Arrays.asList("b", "o", "o", "d","o","d","g","e","r","s"), - Arrays.asList("a","t","h","l","e","t","i","c","s")); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream flatMapped = stream.flatMap(new FlatMapFunction() { - @Override - public Iterable call(String x) { - return Lists.newArrayList(x.split("(?!^)")); - } - }); - JavaTestUtils.attachTestOutputStream(flatMapped); - List> result = JavaTestUtils.runStreams(ssc, 3, 3); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testPairFlatMap() { - List> inputData = Arrays.asList( - Arrays.asList("giants"), - Arrays.asList("dodgers"), - Arrays.asList("athletics")); - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2(6, "g"), - new Tuple2(6, "i"), - new Tuple2(6, "a"), - new Tuple2(6, "n"), - new Tuple2(6, "t"), - new Tuple2(6, "s")), - Arrays.asList( - new Tuple2(7, "d"), - new Tuple2(7, "o"), - new Tuple2(7, "d"), - new Tuple2(7, "g"), - new Tuple2(7, "e"), - new Tuple2(7, "r"), - new Tuple2(7, "s")), - Arrays.asList( - new Tuple2(9, "a"), - new Tuple2(9, "t"), - new Tuple2(9, "h"), - new Tuple2(9, "l"), - new Tuple2(9, "e"), - new Tuple2(9, "t"), - new Tuple2(9, "i"), - new Tuple2(9, "c"), - new Tuple2(9, "s"))); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream flatMapped = stream.flatMap(new PairFlatMapFunction() { - @Override - public Iterable> call(String in) throws Exception { - List> out = Lists.newArrayList(); - for (String letter: in.split("(?!^)")) { - out.add(new Tuple2(in.length(), letter)); - } - return out; - } - }); - JavaTestUtils.attachTestOutputStream(flatMapped); - List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testUnion() { - List> inputData1 = Arrays.asList( - Arrays.asList(1,1), - Arrays.asList(2,2), - Arrays.asList(3,3)); - - List> inputData2 = Arrays.asList( - Arrays.asList(4,4), - Arrays.asList(5,5), - Arrays.asList(6,6)); - - List> expected = Arrays.asList( - Arrays.asList(1,1,4,4), - Arrays.asList(2,2,5,5), - Arrays.asList(3,3,6,6)); - - JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, inputData1, 2); - JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 2); - - JavaDStream unioned = stream1.union(stream2); - JavaTestUtils.attachTestOutputStream(unioned); - List> result = JavaTestUtils.runStreams(ssc, 3, 3); - - assertOrderInvariantEquals(expected, result); - } - - /* - * Performs an order-invariant comparison of lists representing two RDD streams. This allows - * us to account for ordering variation within individual RDD's which occurs during windowing. - */ - public static void assertOrderInvariantEquals( - List> expected, List> actual) { - for (List list: expected) { - Collections.sort(list); - } - for (List list: actual) { - Collections.sort(list); - } - Assert.assertEquals(expected, actual); - } - - - // PairDStream Functions - @Test - public void testPairFilter() { - List> inputData = Arrays.asList( - Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("giants", 6)), - Arrays.asList(new Tuple2("yankees", 7))); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = stream.map( - new PairFunction() { - @Override - public Tuple2 call(String in) throws Exception { - return new Tuple2(in, in.length()); - } - }); - - JavaPairDStream filtered = pairStream.filter( - new Function, Boolean>() { - @Override - public Boolean call(Tuple2 in) throws Exception { - return in._1().contains("a"); - } - }); - JavaTestUtils.attachTestOutputStream(filtered); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - List>> stringStringKVStream = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("california", "giants"), - new Tuple2("new york", "yankees"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("california", "ducks"), - new Tuple2("new york", "rangers"), - new Tuple2("new york", "islanders"))); - - List>> stringIntKVStream = Arrays.asList( - Arrays.asList( - new Tuple2("california", 1), - new Tuple2("california", 3), - new Tuple2("new york", 4), - new Tuple2("new york", 1)), - Arrays.asList( - new Tuple2("california", 5), - new Tuple2("california", 5), - new Tuple2("new york", 3), - new Tuple2("new york", 1))); - - @Test - public void testPairGroupByKey() { - List>> inputData = stringStringKVStream; - - List>>> expected = Arrays.asList( - Arrays.asList( - new Tuple2>("california", Arrays.asList("dodgers", "giants")), - new Tuple2>("new york", Arrays.asList("yankees", "mets"))), - Arrays.asList( - new Tuple2>("california", Arrays.asList("sharks", "ducks")), - new Tuple2>("new york", Arrays.asList("rangers", "islanders")))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream> grouped = pairStream.groupByKey(); - JavaTestUtils.attachTestOutputStream(grouped); - List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testPairReduceByKey() { - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList( - new Tuple2("california", 10), - new Tuple2("new york", 4))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream reduced = pairStream.reduceByKey(new IntegerSum()); - - JavaTestUtils.attachTestOutputStream(reduced); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testCombineByKey() { - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList( - new Tuple2("california", 10), - new Tuple2("new york", 4))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream combined = pairStream.combineByKey( - new Function() { - @Override - public Integer call(Integer i) throws Exception { - return i; - } - }, new IntegerSum(), new IntegerSum(), new HashPartitioner(2)); - - JavaTestUtils.attachTestOutputStream(combined); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testCountByKey() { - List>> inputData = stringStringKVStream; - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2("california", 2L), - new Tuple2("new york", 2L)), - Arrays.asList( - new Tuple2("california", 2L), - new Tuple2("new york", 2L))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream counted = pairStream.countByKey(); - JavaTestUtils.attachTestOutputStream(counted); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testGroupByKeyAndWindow() { - List>> inputData = stringStringKVStream; - - List>>> expected = Arrays.asList( - Arrays.asList(new Tuple2>("california", Arrays.asList("dodgers", "giants")), - new Tuple2>("new york", Arrays.asList("yankees", "mets"))), - Arrays.asList(new Tuple2>("california", - Arrays.asList("sharks", "ducks", "dodgers", "giants")), - new Tuple2>("new york", Arrays.asList("rangers", "islanders", "yankees", "mets"))), - Arrays.asList(new Tuple2>("california", Arrays.asList("sharks", "ducks")), - new Tuple2>("new york", Arrays.asList("rangers", "islanders")))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream> groupWindowed = - pairStream.groupByKeyAndWindow(new Duration(2000), new Duration(1000)); - JavaTestUtils.attachTestOutputStream(groupWindowed); - List>>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testReduceByKeyAndWindow() { - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 10), - new Tuple2("new york", 4))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream reduceWindowed = - pairStream.reduceByKeyAndWindow(new IntegerSum(), new Duration(2000), new Duration(1000)); - JavaTestUtils.attachTestOutputStream(reduceWindowed); - List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testUpdateStateByKey() { - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream updated = pairStream.updateStateByKey( - new Function2, Optional, Optional>(){ - @Override - public Optional call(List values, Optional state) { - int out = 0; - if (state.isPresent()) { - out = out + state.get(); - } - for (Integer v: values) { - out = out + v; - } - return Optional.of(out); - } - }); - JavaTestUtils.attachTestOutputStream(updated); - List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testReduceByKeyAndWindowWithInverse() { - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 10), - new Tuple2("new york", 4))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream reduceWindowed = - pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), new Duration(2000), new Duration(1000)); - JavaTestUtils.attachTestOutputStream(reduceWindowed); - List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testCountByKeyAndWindow() { - List>> inputData = stringStringKVStream; - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2("california", 2L), - new Tuple2("new york", 2L)), - Arrays.asList( - new Tuple2("california", 4L), - new Tuple2("new york", 4L)), - Arrays.asList( - new Tuple2("california", 2L), - new Tuple2("new york", 2L))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream counted = - pairStream.countByKeyAndWindow(new Duration(2000), new Duration(1000)); - JavaTestUtils.attachTestOutputStream(counted); - List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testMapValues() { - List>> inputData = stringStringKVStream; - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", "DODGERS"), - new Tuple2("california", "GIANTS"), - new Tuple2("new york", "YANKEES"), - new Tuple2("new york", "METS")), - Arrays.asList(new Tuple2("california", "SHARKS"), - new Tuple2("california", "DUCKS"), - new Tuple2("new york", "RANGERS"), - new Tuple2("new york", "ISLANDERS"))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream mapped = pairStream.mapValues(new Function() { - @Override - public String call(String s) throws Exception { - return s.toUpperCase(); - } - }); - - JavaTestUtils.attachTestOutputStream(mapped); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testFlatMapValues() { - List>> inputData = stringStringKVStream; - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers1"), - new Tuple2("california", "dodgers2"), - new Tuple2("california", "giants1"), - new Tuple2("california", "giants2"), - new Tuple2("new york", "yankees1"), - new Tuple2("new york", "yankees2"), - new Tuple2("new york", "mets1"), - new Tuple2("new york", "mets2")), - Arrays.asList(new Tuple2("california", "sharks1"), - new Tuple2("california", "sharks2"), - new Tuple2("california", "ducks1"), - new Tuple2("california", "ducks2"), - new Tuple2("new york", "rangers1"), - new Tuple2("new york", "rangers2"), - new Tuple2("new york", "islanders1"), - new Tuple2("new york", "islanders2"))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - - JavaPairDStream flatMapped = pairStream.flatMapValues( - new Function>() { - @Override - public Iterable call(String in) { - List out = new ArrayList(); - out.add(in + "1"); - out.add(in + "2"); - return out; - } - }); - - JavaTestUtils.attachTestOutputStream(flatMapped); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testCoGroup() { - List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); - - List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); - - - List, List>>>> expected = Arrays.asList( - Arrays.asList( - new Tuple2, List>>("california", - new Tuple2, List>(Arrays.asList("dodgers"), Arrays.asList("giants"))), - new Tuple2, List>>("new york", - new Tuple2, List>(Arrays.asList("yankees"), Arrays.asList("mets")))), - Arrays.asList( - new Tuple2, List>>("california", - new Tuple2, List>(Arrays.asList("sharks"), Arrays.asList("ducks"))), - new Tuple2, List>>("new york", - new Tuple2, List>(Arrays.asList("rangers"), Arrays.asList("islanders"))))); - - - JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( - ssc, stringStringKVStream1, 1); - JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream(stream1); - - JavaDStream> stream2 = JavaTestUtils.attachTestInputStream( - ssc, stringStringKVStream2, 1); - JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); - - JavaPairDStream, List>> grouped = pairStream1.cogroup(pairStream2); - JavaTestUtils.attachTestOutputStream(grouped); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testJoin() { - List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); - - List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); - - - List>>> expected = Arrays.asList( - Arrays.asList( - new Tuple2>("california", - new Tuple2("dodgers", "giants")), - new Tuple2>("new york", - new Tuple2("yankees", "mets"))), - Arrays.asList( - new Tuple2>("california", - new Tuple2("sharks", "ducks")), - new Tuple2>("new york", - new Tuple2("rangers", "islanders")))); - - - JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( - ssc, stringStringKVStream1, 1); - JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream(stream1); - - JavaDStream> stream2 = JavaTestUtils.attachTestInputStream( - ssc, stringStringKVStream2, 1); - JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); - - JavaPairDStream> joined = pairStream1.join(pairStream2); - JavaTestUtils.attachTestOutputStream(joined); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testCheckpointMasterRecovery() throws InterruptedException { - List> inputData = Arrays.asList( - Arrays.asList("this", "is"), - Arrays.asList("a", "test"), - Arrays.asList("counting", "letters")); - - List> expectedInitial = Arrays.asList( - Arrays.asList(4,2)); - List> expectedFinal = Arrays.asList( - Arrays.asList(1,4), - Arrays.asList(8,7)); - - - File tempDir = Files.createTempDir(); - ssc.checkpoint(tempDir.getAbsolutePath(), new Duration(1000)); - - JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream letterCount = stream.map(new Function() { - @Override - public Integer call(String s) throws Exception { - return s.length(); - } - }); - JavaCheckpointTestUtils.attachTestOutputStream(letterCount); - List> initialResult = JavaTestUtils.runStreams(ssc, 1, 1); - - assertOrderInvariantEquals(expectedInitial, initialResult); - Thread.sleep(1000); - - ssc.stop(); - ssc = new JavaStreamingContext(tempDir.getAbsolutePath()); - ssc.start(); - List> finalResult = JavaCheckpointTestUtils.runStreams(ssc, 2, 2); - assertOrderInvariantEquals(expectedFinal, finalResult); - } - - /** TEST DISABLED: Pending a discussion about checkpoint() semantics with TD - @Test - public void testCheckpointofIndividualStream() throws InterruptedException { - List> inputData = Arrays.asList( - Arrays.asList("this", "is"), - Arrays.asList("a", "test"), - Arrays.asList("counting", "letters")); - - List> expected = Arrays.asList( - Arrays.asList(4,2), - Arrays.asList(1,4), - Arrays.asList(8,7)); - - JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream letterCount = stream.map(new Function() { - @Override - public Integer call(String s) throws Exception { - return s.length(); - } - }); - JavaCheckpointTestUtils.attachTestOutputStream(letterCount); - - letterCount.checkpoint(new Duration(1000)); - - List> result1 = JavaCheckpointTestUtils.runStreams(ssc, 3, 3); - assertOrderInvariantEquals(expected, result1); - } - */ - - // Input stream tests. These mostly just test that we can instantiate a given InputStream with - // Java arguments and assign it to a JavaDStream without producing type errors. Testing of the - // InputStream functionality is deferred to the existing Scala tests. - @Test - public void testKafkaStream() { - HashMap topics = Maps.newHashMap(); - HashMap offsets = Maps.newHashMap(); - JavaDStream test1 = ssc.kafkaStream("localhost", 12345, "group", topics); - JavaDStream test2 = ssc.kafkaStream("localhost", 12345, "group", topics, offsets); - JavaDStream test3 = ssc.kafkaStream("localhost", 12345, "group", topics, offsets, - StorageLevel.MEMORY_AND_DISK()); - } - - @Test - public void testNetworkTextStream() { - JavaDStream test = ssc.networkTextStream("localhost", 12345); - } - - @Test - public void testNetworkString() { - class Converter extends Function> { - public Iterable call(InputStream in) { - BufferedReader reader = new BufferedReader(new InputStreamReader(in)); - List out = new ArrayList(); - try { - while (true) { - String line = reader.readLine(); - if (line == null) { break; } - out.add(line); - } - } catch (IOException e) { } - return out; - } - } - - JavaDStream test = ssc.networkStream( - "localhost", - 12345, - new Converter(), - StorageLevel.MEMORY_ONLY()); - } - - @Test - public void testTextFileStream() { - JavaDStream test = ssc.textFileStream("/tmp/foo"); - } - - @Test - public void testRawNetworkStream() { - JavaDStream test = ssc.rawNetworkStream("localhost", 12345); - } - - @Test - public void testFlumeStream() { - JavaDStream test = ssc.flumeStream("localhost", 12345); - } - - @Test - public void testFileStream() { - JavaPairDStream foo = - ssc.fileStream("/tmp/foo"); - } -} diff --git a/streaming/src/test/java/JavaTestUtils.scala b/streaming/src/test/java/JavaTestUtils.scala deleted file mode 100644 index 56349837e5..0000000000 --- a/streaming/src/test/java/JavaTestUtils.scala +++ /dev/null @@ -1,65 +0,0 @@ -package spark.streaming - -import collection.mutable.{SynchronizedBuffer, ArrayBuffer} -import java.util.{List => JList} -import spark.streaming.api.java.{JavaPairDStream, JavaDStreamLike, JavaDStream, JavaStreamingContext} -import spark.streaming._ -import java.util.ArrayList -import collection.JavaConversions._ - -/** Exposes streaming test functionality in a Java-friendly way. */ -trait JavaTestBase extends TestSuiteBase { - - /** - * Create a [[spark.streaming.TestInputStream]] and attach it to the supplied context. - * The stream will be derived from the supplied lists of Java objects. - **/ - def attachTestInputStream[T]( - ssc: JavaStreamingContext, - data: JList[JList[T]], - numPartitions: Int) = { - val seqData = data.map(Seq(_:_*)) - - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - val dstream = new TestInputStream[T](ssc.ssc, seqData, numPartitions) - ssc.ssc.registerInputStream(dstream) - new JavaDStream[T](dstream) - } - - /** - * Attach a provided stream to it's associated StreamingContext as a - * [[spark.streaming.TestOutputStream]]. - **/ - def attachTestOutputStream[T, This <: spark.streaming.api.java.JavaDStreamLike[T,This]]( - dstream: JavaDStreamLike[T, This]) = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - val ostream = new TestOutputStream(dstream.dstream, - new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]]) - dstream.dstream.ssc.registerOutputStream(ostream) - } - - /** - * Process all registered streams for a numBatches batches, failing if - * numExpectedOutput RDD's are not generated. Generated RDD's are collected - * and returned, represented as a list for each batch interval. - */ - def runStreams[V]( - ssc: JavaStreamingContext, numBatches: Int, numExpectedOutput: Int): JList[JList[V]] = { - implicit val cm: ClassManifest[V] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]] - val res = runStreams[V](ssc.ssc, numBatches, numExpectedOutput) - val out = new ArrayList[JList[V]]() - res.map(entry => out.append(new ArrayList[V](entry))) - out - } -} - -object JavaTestUtils extends JavaTestBase { - -} - -object JavaCheckpointTestUtils extends JavaTestBase { - override def actuallyWait = true -} \ No newline at end of file diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java new file mode 100644 index 0000000000..c84e7331c7 --- /dev/null +++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java @@ -0,0 +1,1029 @@ +package spark.streaming; + +import com.google.common.base.Optional; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.io.Files; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import scala.Tuple2; +import spark.HashPartitioner; +import spark.api.java.JavaRDD; +import spark.api.java.JavaSparkContext; +import spark.api.java.function.*; +import spark.storage.StorageLevel; +import spark.streaming.api.java.JavaDStream; +import spark.streaming.api.java.JavaPairDStream; +import spark.streaming.api.java.JavaStreamingContext; +import spark.streaming.JavaTestUtils; +import spark.streaming.JavaCheckpointTestUtils; +import spark.streaming.dstream.KafkaPartitionKey; + +import java.io.*; +import java.util.*; + +// The test suite itself is Serializable so that anonymous Function implementations can be +// serialized, as an alternative to converting these anonymous classes to static inner classes; +// see http://stackoverflow.com/questions/758570/. +public class JavaAPISuite implements Serializable { + private transient JavaStreamingContext ssc; + + @Before + public void setUp() { + ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + ssc.checkpoint("checkpoint", new Duration(1000)); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port"); + } + + @Test + public void testCount() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3,4), + Arrays.asList(3,4,5), + Arrays.asList(3)); + + List> expected = Arrays.asList( + Arrays.asList(4L), + Arrays.asList(3L), + Arrays.asList(1L)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream count = stream.count(); + JavaTestUtils.attachTestOutputStream(count); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testMap() { + List> inputData = Arrays.asList( + Arrays.asList("hello", "world"), + Arrays.asList("goodnight", "moon")); + + List> expected = Arrays.asList( + Arrays.asList(5,5), + Arrays.asList(9,4)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream letterCount = stream.map(new Function() { + @Override + public Integer call(String s) throws Exception { + return s.length(); + } + }); + JavaTestUtils.attachTestOutputStream(letterCount); + List> result = JavaTestUtils.runStreams(ssc, 2, 2); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testWindow() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9)); + + List> expected = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6,1,2,3), + Arrays.asList(7,8,9,4,5,6), + Arrays.asList(7,8,9)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream windowed = stream.window(new Duration(2000)); + JavaTestUtils.attachTestOutputStream(windowed); + List> result = JavaTestUtils.runStreams(ssc, 4, 4); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testWindowWithSlideDuration() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9), + Arrays.asList(10,11,12), + Arrays.asList(13,14,15), + Arrays.asList(16,17,18)); + + List> expected = Arrays.asList( + Arrays.asList(1,2,3,4,5,6), + Arrays.asList(1,2,3,4,5,6,7,8,9,10,11,12), + Arrays.asList(7,8,9,10,11,12,13,14,15,16,17,18), + Arrays.asList(13,14,15,16,17,18)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream windowed = stream.window(new Duration(4000), new Duration(2000)); + JavaTestUtils.attachTestOutputStream(windowed); + List> result = JavaTestUtils.runStreams(ssc, 8, 4); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testTumble() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9), + Arrays.asList(10,11,12), + Arrays.asList(13,14,15), + Arrays.asList(16,17,18)); + + List> expected = Arrays.asList( + Arrays.asList(1,2,3,4,5,6), + Arrays.asList(7,8,9,10,11,12), + Arrays.asList(13,14,15,16,17,18)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream windowed = stream.tumble(new Duration(2000)); + JavaTestUtils.attachTestOutputStream(windowed); + List> result = JavaTestUtils.runStreams(ssc, 6, 3); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testFilter() { + List> inputData = Arrays.asList( + Arrays.asList("giants", "dodgers"), + Arrays.asList("yankees", "red socks")); + + List> expected = Arrays.asList( + Arrays.asList("giants"), + Arrays.asList("yankees")); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream filtered = stream.filter(new Function() { + @Override + public Boolean call(String s) throws Exception { + return s.contains("a"); + } + }); + JavaTestUtils.attachTestOutputStream(filtered); + List> result = JavaTestUtils.runStreams(ssc, 2, 2); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testGlom() { + List> inputData = Arrays.asList( + Arrays.asList("giants", "dodgers"), + Arrays.asList("yankees", "red socks")); + + List>> expected = Arrays.asList( + Arrays.asList(Arrays.asList("giants", "dodgers")), + Arrays.asList(Arrays.asList("yankees", "red socks"))); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream glommed = stream.glom(); + JavaTestUtils.attachTestOutputStream(glommed); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testMapPartitions() { + List> inputData = Arrays.asList( + Arrays.asList("giants", "dodgers"), + Arrays.asList("yankees", "red socks")); + + List> expected = Arrays.asList( + Arrays.asList("GIANTSDODGERS"), + Arrays.asList("YANKEESRED SOCKS")); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream mapped = stream.mapPartitions(new FlatMapFunction, String>() { + @Override + public Iterable call(Iterator in) { + String out = ""; + while (in.hasNext()) { + out = out + in.next().toUpperCase(); + } + return Lists.newArrayList(out); + } + }); + JavaTestUtils.attachTestOutputStream(mapped); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + private class IntegerSum extends Function2 { + @Override + public Integer call(Integer i1, Integer i2) throws Exception { + return i1 + i2; + } + } + + private class IntegerDifference extends Function2 { + @Override + public Integer call(Integer i1, Integer i2) throws Exception { + return i1 - i2; + } + } + + @Test + public void testReduce() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9)); + + List> expected = Arrays.asList( + Arrays.asList(6), + Arrays.asList(15), + Arrays.asList(24)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream reduced = stream.reduce(new IntegerSum()); + JavaTestUtils.attachTestOutputStream(reduced); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testReduceByWindow() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9)); + + List> expected = Arrays.asList( + Arrays.asList(6), + Arrays.asList(21), + Arrays.asList(39), + Arrays.asList(24)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream reducedWindowed = stream.reduceByWindow(new IntegerSum(), + new IntegerDifference(), new Duration(2000), new Duration(1000)); + JavaTestUtils.attachTestOutputStream(reducedWindowed); + List> result = JavaTestUtils.runStreams(ssc, 4, 4); + + Assert.assertEquals(expected, result); + } + + @Test + public void testQueueStream() { + List> expected = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9)); + + JavaSparkContext jsc = new JavaSparkContext(ssc.ssc().sc()); + JavaRDD rdd1 = ssc.sc().parallelize(Arrays.asList(1,2,3)); + JavaRDD rdd2 = ssc.sc().parallelize(Arrays.asList(4,5,6)); + JavaRDD rdd3 = ssc.sc().parallelize(Arrays.asList(7,8,9)); + + LinkedList> rdds = Lists.newLinkedList(); + rdds.add(rdd1); + rdds.add(rdd2); + rdds.add(rdd3); + + JavaDStream stream = ssc.queueStream(rdds); + JavaTestUtils.attachTestOutputStream(stream); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + Assert.assertEquals(expected, result); + } + + @Test + public void testTransform() { + List> inputData = Arrays.asList( + Arrays.asList(1,2,3), + Arrays.asList(4,5,6), + Arrays.asList(7,8,9)); + + List> expected = Arrays.asList( + Arrays.asList(3,4,5), + Arrays.asList(6,7,8), + Arrays.asList(9,10,11)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream transformed = stream.transform(new Function, JavaRDD>() { + @Override + public JavaRDD call(JavaRDD in) throws Exception { + return in.map(new Function() { + @Override + public Integer call(Integer i) throws Exception { + return i + 2; + } + }); + }}); + JavaTestUtils.attachTestOutputStream(transformed); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testFlatMap() { + List> inputData = Arrays.asList( + Arrays.asList("go", "giants"), + Arrays.asList("boo", "dodgers"), + Arrays.asList("athletics")); + + List> expected = Arrays.asList( + Arrays.asList("g","o","g","i","a","n","t","s"), + Arrays.asList("b", "o", "o", "d","o","d","g","e","r","s"), + Arrays.asList("a","t","h","l","e","t","i","c","s")); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream flatMapped = stream.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Lists.newArrayList(x.split("(?!^)")); + } + }); + JavaTestUtils.attachTestOutputStream(flatMapped); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testPairFlatMap() { + List> inputData = Arrays.asList( + Arrays.asList("giants"), + Arrays.asList("dodgers"), + Arrays.asList("athletics")); + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2(6, "g"), + new Tuple2(6, "i"), + new Tuple2(6, "a"), + new Tuple2(6, "n"), + new Tuple2(6, "t"), + new Tuple2(6, "s")), + Arrays.asList( + new Tuple2(7, "d"), + new Tuple2(7, "o"), + new Tuple2(7, "d"), + new Tuple2(7, "g"), + new Tuple2(7, "e"), + new Tuple2(7, "r"), + new Tuple2(7, "s")), + Arrays.asList( + new Tuple2(9, "a"), + new Tuple2(9, "t"), + new Tuple2(9, "h"), + new Tuple2(9, "l"), + new Tuple2(9, "e"), + new Tuple2(9, "t"), + new Tuple2(9, "i"), + new Tuple2(9, "c"), + new Tuple2(9, "s"))); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream flatMapped = stream.flatMap(new PairFlatMapFunction() { + @Override + public Iterable> call(String in) throws Exception { + List> out = Lists.newArrayList(); + for (String letter: in.split("(?!^)")) { + out.add(new Tuple2(in.length(), letter)); + } + return out; + } + }); + JavaTestUtils.attachTestOutputStream(flatMapped); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testUnion() { + List> inputData1 = Arrays.asList( + Arrays.asList(1,1), + Arrays.asList(2,2), + Arrays.asList(3,3)); + + List> inputData2 = Arrays.asList( + Arrays.asList(4,4), + Arrays.asList(5,5), + Arrays.asList(6,6)); + + List> expected = Arrays.asList( + Arrays.asList(1,1,4,4), + Arrays.asList(2,2,5,5), + Arrays.asList(3,3,6,6)); + + JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, inputData1, 2); + JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 2); + + JavaDStream unioned = stream1.union(stream2); + JavaTestUtils.attachTestOutputStream(unioned); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + + assertOrderInvariantEquals(expected, result); + } + + /* + * Performs an order-invariant comparison of lists representing two RDD streams. This allows + * us to account for ordering variation within individual RDD's which occurs during windowing. + */ + public static void assertOrderInvariantEquals( + List> expected, List> actual) { + for (List list: expected) { + Collections.sort(list); + } + for (List list: actual) { + Collections.sort(list); + } + Assert.assertEquals(expected, actual); + } + + + // PairDStream Functions + @Test + public void testPairFilter() { + List> inputData = Arrays.asList( + Arrays.asList("giants", "dodgers"), + Arrays.asList("yankees", "red socks")); + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("giants", 6)), + Arrays.asList(new Tuple2("yankees", 7))); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = stream.map( + new PairFunction() { + @Override + public Tuple2 call(String in) throws Exception { + return new Tuple2(in, in.length()); + } + }); + + JavaPairDStream filtered = pairStream.filter( + new Function, Boolean>() { + @Override + public Boolean call(Tuple2 in) throws Exception { + return in._1().contains("a"); + } + }); + JavaTestUtils.attachTestOutputStream(filtered); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + List>> stringStringKVStream = Arrays.asList( + Arrays.asList(new Tuple2("california", "dodgers"), + new Tuple2("california", "giants"), + new Tuple2("new york", "yankees"), + new Tuple2("new york", "mets")), + Arrays.asList(new Tuple2("california", "sharks"), + new Tuple2("california", "ducks"), + new Tuple2("new york", "rangers"), + new Tuple2("new york", "islanders"))); + + List>> stringIntKVStream = Arrays.asList( + Arrays.asList( + new Tuple2("california", 1), + new Tuple2("california", 3), + new Tuple2("new york", 4), + new Tuple2("new york", 1)), + Arrays.asList( + new Tuple2("california", 5), + new Tuple2("california", 5), + new Tuple2("new york", 3), + new Tuple2("new york", 1))); + + @Test + public void testPairGroupByKey() { + List>> inputData = stringStringKVStream; + + List>>> expected = Arrays.asList( + Arrays.asList( + new Tuple2>("california", Arrays.asList("dodgers", "giants")), + new Tuple2>("new york", Arrays.asList("yankees", "mets"))), + Arrays.asList( + new Tuple2>("california", Arrays.asList("sharks", "ducks")), + new Tuple2>("new york", Arrays.asList("rangers", "islanders")))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream> grouped = pairStream.groupByKey(); + JavaTestUtils.attachTestOutputStream(grouped); + List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairReduceByKey() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList( + new Tuple2("california", 10), + new Tuple2("new york", 4))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream reduced = pairStream.reduceByKey(new IntegerSum()); + + JavaTestUtils.attachTestOutputStream(reduced); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testCombineByKey() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList( + new Tuple2("california", 10), + new Tuple2("new york", 4))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream combined = pairStream.combineByKey( + new Function() { + @Override + public Integer call(Integer i) throws Exception { + return i; + } + }, new IntegerSum(), new IntegerSum(), new HashPartitioner(2)); + + JavaTestUtils.attachTestOutputStream(combined); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testCountByKey() { + List>> inputData = stringStringKVStream; + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2("california", 2L), + new Tuple2("new york", 2L)), + Arrays.asList( + new Tuple2("california", 2L), + new Tuple2("new york", 2L))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream counted = pairStream.countByKey(); + JavaTestUtils.attachTestOutputStream(counted); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testGroupByKeyAndWindow() { + List>> inputData = stringStringKVStream; + + List>>> expected = Arrays.asList( + Arrays.asList(new Tuple2>("california", Arrays.asList("dodgers", "giants")), + new Tuple2>("new york", Arrays.asList("yankees", "mets"))), + Arrays.asList(new Tuple2>("california", + Arrays.asList("sharks", "ducks", "dodgers", "giants")), + new Tuple2>("new york", Arrays.asList("rangers", "islanders", "yankees", "mets"))), + Arrays.asList(new Tuple2>("california", Arrays.asList("sharks", "ducks")), + new Tuple2>("new york", Arrays.asList("rangers", "islanders")))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream> groupWindowed = + pairStream.groupByKeyAndWindow(new Duration(2000), new Duration(1000)); + JavaTestUtils.attachTestOutputStream(groupWindowed); + List>>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testReduceByKeyAndWindow() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList(new Tuple2("california", 14), + new Tuple2("new york", 9)), + Arrays.asList(new Tuple2("california", 10), + new Tuple2("new york", 4))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream reduceWindowed = + pairStream.reduceByKeyAndWindow(new IntegerSum(), new Duration(2000), new Duration(1000)); + JavaTestUtils.attachTestOutputStream(reduceWindowed); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testUpdateStateByKey() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList(new Tuple2("california", 14), + new Tuple2("new york", 9)), + Arrays.asList(new Tuple2("california", 14), + new Tuple2("new york", 9))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream updated = pairStream.updateStateByKey( + new Function2, Optional, Optional>(){ + @Override + public Optional call(List values, Optional state) { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v: values) { + out = out + v; + } + return Optional.of(out); + } + }); + JavaTestUtils.attachTestOutputStream(updated); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testReduceByKeyAndWindowWithInverse() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList(new Tuple2("california", 14), + new Tuple2("new york", 9)), + Arrays.asList(new Tuple2("california", 10), + new Tuple2("new york", 4))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream reduceWindowed = + pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), new Duration(2000), new Duration(1000)); + JavaTestUtils.attachTestOutputStream(reduceWindowed); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testCountByKeyAndWindow() { + List>> inputData = stringStringKVStream; + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2("california", 2L), + new Tuple2("new york", 2L)), + Arrays.asList( + new Tuple2("california", 4L), + new Tuple2("new york", 4L)), + Arrays.asList( + new Tuple2("california", 2L), + new Tuple2("new york", 2L))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream counted = + pairStream.countByKeyAndWindow(new Duration(2000), new Duration(1000)); + JavaTestUtils.attachTestOutputStream(counted); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testMapValues() { + List>> inputData = stringStringKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", "DODGERS"), + new Tuple2("california", "GIANTS"), + new Tuple2("new york", "YANKEES"), + new Tuple2("new york", "METS")), + Arrays.asList(new Tuple2("california", "SHARKS"), + new Tuple2("california", "DUCKS"), + new Tuple2("new york", "RANGERS"), + new Tuple2("new york", "ISLANDERS"))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream mapped = pairStream.mapValues(new Function() { + @Override + public String call(String s) throws Exception { + return s.toUpperCase(); + } + }); + + JavaTestUtils.attachTestOutputStream(mapped); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testFlatMapValues() { + List>> inputData = stringStringKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", "dodgers1"), + new Tuple2("california", "dodgers2"), + new Tuple2("california", "giants1"), + new Tuple2("california", "giants2"), + new Tuple2("new york", "yankees1"), + new Tuple2("new york", "yankees2"), + new Tuple2("new york", "mets1"), + new Tuple2("new york", "mets2")), + Arrays.asList(new Tuple2("california", "sharks1"), + new Tuple2("california", "sharks2"), + new Tuple2("california", "ducks1"), + new Tuple2("california", "ducks2"), + new Tuple2("new york", "rangers1"), + new Tuple2("new york", "rangers2"), + new Tuple2("new york", "islanders1"), + new Tuple2("new york", "islanders2"))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + + JavaPairDStream flatMapped = pairStream.flatMapValues( + new Function>() { + @Override + public Iterable call(String in) { + List out = new ArrayList(); + out.add(in + "1"); + out.add(in + "2"); + return out; + } + }); + + JavaTestUtils.attachTestOutputStream(flatMapped); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testCoGroup() { + List>> stringStringKVStream1 = Arrays.asList( + Arrays.asList(new Tuple2("california", "dodgers"), + new Tuple2("new york", "yankees")), + Arrays.asList(new Tuple2("california", "sharks"), + new Tuple2("new york", "rangers"))); + + List>> stringStringKVStream2 = Arrays.asList( + Arrays.asList(new Tuple2("california", "giants"), + new Tuple2("new york", "mets")), + Arrays.asList(new Tuple2("california", "ducks"), + new Tuple2("new york", "islanders"))); + + + List, List>>>> expected = Arrays.asList( + Arrays.asList( + new Tuple2, List>>("california", + new Tuple2, List>(Arrays.asList("dodgers"), Arrays.asList("giants"))), + new Tuple2, List>>("new york", + new Tuple2, List>(Arrays.asList("yankees"), Arrays.asList("mets")))), + Arrays.asList( + new Tuple2, List>>("california", + new Tuple2, List>(Arrays.asList("sharks"), Arrays.asList("ducks"))), + new Tuple2, List>>("new york", + new Tuple2, List>(Arrays.asList("rangers"), Arrays.asList("islanders"))))); + + + JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( + ssc, stringStringKVStream1, 1); + JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream(stream1); + + JavaDStream> stream2 = JavaTestUtils.attachTestInputStream( + ssc, stringStringKVStream2, 1); + JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); + + JavaPairDStream, List>> grouped = pairStream1.cogroup(pairStream2); + JavaTestUtils.attachTestOutputStream(grouped); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testJoin() { + List>> stringStringKVStream1 = Arrays.asList( + Arrays.asList(new Tuple2("california", "dodgers"), + new Tuple2("new york", "yankees")), + Arrays.asList(new Tuple2("california", "sharks"), + new Tuple2("new york", "rangers"))); + + List>> stringStringKVStream2 = Arrays.asList( + Arrays.asList(new Tuple2("california", "giants"), + new Tuple2("new york", "mets")), + Arrays.asList(new Tuple2("california", "ducks"), + new Tuple2("new york", "islanders"))); + + + List>>> expected = Arrays.asList( + Arrays.asList( + new Tuple2>("california", + new Tuple2("dodgers", "giants")), + new Tuple2>("new york", + new Tuple2("yankees", "mets"))), + Arrays.asList( + new Tuple2>("california", + new Tuple2("sharks", "ducks")), + new Tuple2>("new york", + new Tuple2("rangers", "islanders")))); + + + JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( + ssc, stringStringKVStream1, 1); + JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream(stream1); + + JavaDStream> stream2 = JavaTestUtils.attachTestInputStream( + ssc, stringStringKVStream2, 1); + JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); + + JavaPairDStream> joined = pairStream1.join(pairStream2); + JavaTestUtils.attachTestOutputStream(joined); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testCheckpointMasterRecovery() throws InterruptedException { + List> inputData = Arrays.asList( + Arrays.asList("this", "is"), + Arrays.asList("a", "test"), + Arrays.asList("counting", "letters")); + + List> expectedInitial = Arrays.asList( + Arrays.asList(4,2)); + List> expectedFinal = Arrays.asList( + Arrays.asList(1,4), + Arrays.asList(8,7)); + + + File tempDir = Files.createTempDir(); + ssc.checkpoint(tempDir.getAbsolutePath(), new Duration(1000)); + + JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream letterCount = stream.map(new Function() { + @Override + public Integer call(String s) throws Exception { + return s.length(); + } + }); + JavaCheckpointTestUtils.attachTestOutputStream(letterCount); + List> initialResult = JavaTestUtils.runStreams(ssc, 1, 1); + + assertOrderInvariantEquals(expectedInitial, initialResult); + Thread.sleep(1000); + + ssc.stop(); + ssc = new JavaStreamingContext(tempDir.getAbsolutePath()); + ssc.start(); + List> finalResult = JavaCheckpointTestUtils.runStreams(ssc, 2, 2); + assertOrderInvariantEquals(expectedFinal, finalResult); + } + + /** TEST DISABLED: Pending a discussion about checkpoint() semantics with TD + @Test + public void testCheckpointofIndividualStream() throws InterruptedException { + List> inputData = Arrays.asList( + Arrays.asList("this", "is"), + Arrays.asList("a", "test"), + Arrays.asList("counting", "letters")); + + List> expected = Arrays.asList( + Arrays.asList(4,2), + Arrays.asList(1,4), + Arrays.asList(8,7)); + + JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream letterCount = stream.map(new Function() { + @Override + public Integer call(String s) throws Exception { + return s.length(); + } + }); + JavaCheckpointTestUtils.attachTestOutputStream(letterCount); + + letterCount.checkpoint(new Duration(1000)); + + List> result1 = JavaCheckpointTestUtils.runStreams(ssc, 3, 3); + assertOrderInvariantEquals(expected, result1); + } + */ + + // Input stream tests. These mostly just test that we can instantiate a given InputStream with + // Java arguments and assign it to a JavaDStream without producing type errors. Testing of the + // InputStream functionality is deferred to the existing Scala tests. + @Test + public void testKafkaStream() { + HashMap topics = Maps.newHashMap(); + HashMap offsets = Maps.newHashMap(); + JavaDStream test1 = ssc.kafkaStream("localhost", 12345, "group", topics); + JavaDStream test2 = ssc.kafkaStream("localhost", 12345, "group", topics, offsets); + JavaDStream test3 = ssc.kafkaStream("localhost", 12345, "group", topics, offsets, + StorageLevel.MEMORY_AND_DISK()); + } + + @Test + public void testNetworkTextStream() { + JavaDStream test = ssc.networkTextStream("localhost", 12345); + } + + @Test + public void testNetworkString() { + class Converter extends Function> { + public Iterable call(InputStream in) { + BufferedReader reader = new BufferedReader(new InputStreamReader(in)); + List out = new ArrayList(); + try { + while (true) { + String line = reader.readLine(); + if (line == null) { break; } + out.add(line); + } + } catch (IOException e) { } + return out; + } + } + + JavaDStream test = ssc.networkStream( + "localhost", + 12345, + new Converter(), + StorageLevel.MEMORY_ONLY()); + } + + @Test + public void testTextFileStream() { + JavaDStream test = ssc.textFileStream("/tmp/foo"); + } + + @Test + public void testRawNetworkStream() { + JavaDStream test = ssc.rawNetworkStream("localhost", 12345); + } + + @Test + public void testFlumeStream() { + JavaDStream test = ssc.flumeStream("localhost", 12345); + } + + @Test + public void testFileStream() { + JavaPairDStream foo = + ssc.fileStream("/tmp/foo"); + } +} diff --git a/streaming/src/test/java/spark/streaming/JavaTestUtils.scala b/streaming/src/test/java/spark/streaming/JavaTestUtils.scala new file mode 100644 index 0000000000..56349837e5 --- /dev/null +++ b/streaming/src/test/java/spark/streaming/JavaTestUtils.scala @@ -0,0 +1,65 @@ +package spark.streaming + +import collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import java.util.{List => JList} +import spark.streaming.api.java.{JavaPairDStream, JavaDStreamLike, JavaDStream, JavaStreamingContext} +import spark.streaming._ +import java.util.ArrayList +import collection.JavaConversions._ + +/** Exposes streaming test functionality in a Java-friendly way. */ +trait JavaTestBase extends TestSuiteBase { + + /** + * Create a [[spark.streaming.TestInputStream]] and attach it to the supplied context. + * The stream will be derived from the supplied lists of Java objects. + **/ + def attachTestInputStream[T]( + ssc: JavaStreamingContext, + data: JList[JList[T]], + numPartitions: Int) = { + val seqData = data.map(Seq(_:_*)) + + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + val dstream = new TestInputStream[T](ssc.ssc, seqData, numPartitions) + ssc.ssc.registerInputStream(dstream) + new JavaDStream[T](dstream) + } + + /** + * Attach a provided stream to it's associated StreamingContext as a + * [[spark.streaming.TestOutputStream]]. + **/ + def attachTestOutputStream[T, This <: spark.streaming.api.java.JavaDStreamLike[T,This]]( + dstream: JavaDStreamLike[T, This]) = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + val ostream = new TestOutputStream(dstream.dstream, + new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]]) + dstream.dstream.ssc.registerOutputStream(ostream) + } + + /** + * Process all registered streams for a numBatches batches, failing if + * numExpectedOutput RDD's are not generated. Generated RDD's are collected + * and returned, represented as a list for each batch interval. + */ + def runStreams[V]( + ssc: JavaStreamingContext, numBatches: Int, numExpectedOutput: Int): JList[JList[V]] = { + implicit val cm: ClassManifest[V] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]] + val res = runStreams[V](ssc.ssc, numBatches, numExpectedOutput) + val out = new ArrayList[JList[V]]() + res.map(entry => out.append(new ArrayList[V](entry))) + out + } +} + +object JavaTestUtils extends JavaTestBase { + +} + +object JavaCheckpointTestUtils extends JavaTestBase { + override def actuallyWait = true +} \ No newline at end of file -- cgit v1.2.3 From ef711902c1f42db14c8ddd524195f0a9efb56e65 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 21 Jan 2013 16:42:24 -0800 Subject: Don't download files to master's working directory. This should avoid exceptions caused by existing files with different contents. I also removed some unused code. --- core/src/main/scala/spark/HttpFileServer.scala | 8 ++--- core/src/main/scala/spark/SparkContext.scala | 7 ++-- core/src/main/scala/spark/SparkEnv.scala | 20 +++++++---- core/src/main/scala/spark/SparkFiles.java | 25 ++++++++++++++ core/src/main/scala/spark/Utils.scala | 16 +-------- .../scala/spark/api/java/JavaSparkContext.scala | 5 +-- .../main/scala/spark/api/python/PythonRDD.scala | 2 ++ .../scala/spark/deploy/worker/ExecutorRunner.scala | 5 --- core/src/main/scala/spark/executor/Executor.scala | 6 ++-- .../spark/scheduler/local/LocalScheduler.scala | 6 ++-- core/src/test/scala/spark/FileServerSuite.scala | 9 +++-- python/pyspark/__init__.py | 5 ++- python/pyspark/context.py | 40 +++++++++++++++++++--- python/pyspark/files.py | 24 +++++++++++++ python/pyspark/worker.py | 3 ++ python/run-tests | 3 ++ 16 files changed, 133 insertions(+), 51 deletions(-) create mode 100644 core/src/main/scala/spark/SparkFiles.java create mode 100644 python/pyspark/files.py diff --git a/core/src/main/scala/spark/HttpFileServer.scala b/core/src/main/scala/spark/HttpFileServer.scala index 659d17718f..00901d95e2 100644 --- a/core/src/main/scala/spark/HttpFileServer.scala +++ b/core/src/main/scala/spark/HttpFileServer.scala @@ -1,9 +1,7 @@ package spark -import java.io.{File, PrintWriter} -import java.net.URL -import scala.collection.mutable.HashMap -import org.apache.hadoop.fs.FileUtil +import java.io.{File} +import com.google.common.io.Files private[spark] class HttpFileServer extends Logging { @@ -40,7 +38,7 @@ private[spark] class HttpFileServer extends Logging { } def addFileToDir(file: File, dir: File) : String = { - Utils.copyFile(file, new File(dir, file.getName)) + Files.copy(file, new File(dir, file.getName)) return dir + "/" + file.getName } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 8b6f4b3b7d..2eeca66ed6 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -439,9 +439,10 @@ class SparkContext( def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) /** - * Add a file to be downloaded into the working directory of this Spark job on every node. + * Add a file to be downloaded with this Spark job on every node. * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), or an HTTP, HTTPS or FTP URI. + * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, + * use `SparkFiles.get(path)` to find its download location. */ def addFile(path: String) { val uri = new URI(path) @@ -454,7 +455,7 @@ class SparkContext( // Fetch the file locally in case a job is executed locally. // Jobs that run through LocalScheduler will already fetch the required dependencies, // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here. - Utils.fetchFile(path, new File(".")) + Utils.fetchFile(path, new File(SparkFiles.getRootDirectory)) logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) } diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 41441720a7..6b44e29f4c 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -28,14 +28,10 @@ class SparkEnv ( val broadcastManager: BroadcastManager, val blockManager: BlockManager, val connectionManager: ConnectionManager, - val httpFileServer: HttpFileServer + val httpFileServer: HttpFileServer, + val sparkFilesDir: String ) { - /** No-parameter constructor for unit tests. */ - def this() = { - this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null) - } - def stop() { httpFileServer.stop() mapOutputTracker.stop() @@ -112,6 +108,15 @@ object SparkEnv extends Logging { httpFileServer.initialize() System.setProperty("spark.fileserver.uri", httpFileServer.serverUri) + // Set the sparkFiles directory, used when downloading dependencies. In local mode, + // this is a temporary directory; in distributed mode, this is the executor's current working + // directory. + val sparkFilesDir: String = if (isMaster) { + Utils.createTempDir().getAbsolutePath + } else { + "." + } + // Warn about deprecated spark.cache.class property if (System.getProperty("spark.cache.class") != null) { logWarning("The spark.cache.class property is no longer being used! Specify storage " + @@ -128,6 +133,7 @@ object SparkEnv extends Logging { broadcastManager, blockManager, connectionManager, - httpFileServer) + httpFileServer, + sparkFilesDir) } } diff --git a/core/src/main/scala/spark/SparkFiles.java b/core/src/main/scala/spark/SparkFiles.java new file mode 100644 index 0000000000..b59d8ce93f --- /dev/null +++ b/core/src/main/scala/spark/SparkFiles.java @@ -0,0 +1,25 @@ +package spark; + +import java.io.File; + +/** + * Resolves paths to files added through `addFile(). + */ +public class SparkFiles { + + private SparkFiles() {} + + /** + * Get the absolute path of a file added through `addFile()`. + */ + public static String get(String filename) { + return new File(getRootDirectory(), filename).getAbsolutePath(); + } + + /** + * Get the root directory that contains files added through `addFile()`. + */ + public static String getRootDirectory() { + return SparkEnv.get().sparkFilesDir(); + } +} \ No newline at end of file diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 692a3f4050..827c8bd81e 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -111,20 +111,6 @@ private object Utils extends Logging { } } - /** Copy a file on the local file system */ - def copyFile(source: File, dest: File) { - val in = new FileInputStream(source) - val out = new FileOutputStream(dest) - copyStream(in, out, true) - } - - /** Download a file from a given URL to the local filesystem */ - def downloadFile(url: URL, localPath: String) { - val in = url.openStream() - val out = new FileOutputStream(localPath) - Utils.copyStream(in, out, true) - } - /** * Download a file requested by the executor. Supports fetching the file in a variety of ways, * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. @@ -201,7 +187,7 @@ private object Utils extends Logging { Utils.execute(Seq("tar", "-xf", filename), targetDir) } // Make the file executable - That's necessary for scripts - FileUtil.chmod(filename, "a+x") + FileUtil.chmod(targetFile.getAbsolutePath, "a+x") } /** diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index 16c122c584..50b8970cd8 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -323,9 +323,10 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def getSparkHome(): Option[String] = sc.getSparkHome() /** - * Add a file to be downloaded into the working directory of this Spark job on every node. + * Add a file to be downloaded with this Spark job on every node. * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), or an HTTP, HTTPS or FTP URI. + * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, + * use `SparkFiles.get(path)` to find its download location. */ def addFile(path: String) { sc.addFile(path) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 5526406a20..f43a152ca7 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -67,6 +67,8 @@ private[spark] class PythonRDD[T: ClassManifest]( val dOut = new DataOutputStream(proc.getOutputStream) // Split index dOut.writeInt(split.index) + // sparkFilesDir + PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut) // Broadcast variables dOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index beceb55ecd..0d1fe2a6b4 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -106,11 +106,6 @@ private[spark] class ExecutorRunner( throw new IOException("Failed to create directory " + executorDir) } - // Download the files it depends on into it (disabled for now) - //for (url <- jobDesc.fileUrls) { - // fetchFile(url, executorDir) - //} - // Launch the process val command = buildCommandSeq() val builder = new ProcessBuilder(command: _*).directory(executorDir) diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 2552958d27..70629f6003 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -162,16 +162,16 @@ private[spark] class Executor extends Logging { // Fetch missing dependencies for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) currentFiles(name) = timestamp } for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) currentJars(name) = timestamp // Add it to our class loader val localName = name.split("/").last - val url = new File(".", localName).toURI.toURL + val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL if (!urlClassLoader.getURLs.contains(url)) { logInfo("Adding " + url + " to class loader") urlClassLoader.addURL(url) diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index dff550036d..4451d314e6 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -116,16 +116,16 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon // Fetch missing dependencies for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) currentFiles(name) = timestamp } for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) currentJars(name) = timestamp // Add it to our class loader val localName = name.split("/").last - val url = new File(".", localName).toURI.toURL + val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL if (!classLoader.getURLs.contains(url)) { logInfo("Adding " + url + " to class loader") classLoader.addURL(url) diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala index b4283d9604..528c6b8424 100644 --- a/core/src/test/scala/spark/FileServerSuite.scala +++ b/core/src/test/scala/spark/FileServerSuite.scala @@ -40,7 +40,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter { sc.addFile(tmpFile.toString) val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) val result = sc.parallelize(testData).reduceByKey { - val in = new BufferedReader(new FileReader("FileServerSuite.txt")) + val path = SparkFiles.get("FileServerSuite.txt") + val in = new BufferedReader(new FileReader(path)) val fileVal = in.readLine().toInt in.close() _ * fileVal + _ * fileVal @@ -54,7 +55,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter { sc.addFile((new File(tmpFile.toString)).toURL.toString) val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) val result = sc.parallelize(testData).reduceByKey { - val in = new BufferedReader(new FileReader("FileServerSuite.txt")) + val path = SparkFiles.get("FileServerSuite.txt") + val in = new BufferedReader(new FileReader(path)) val fileVal = in.readLine().toInt in.close() _ * fileVal + _ * fileVal @@ -83,7 +85,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter { sc.addFile(tmpFile.toString) val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) val result = sc.parallelize(testData).reduceByKey { - val in = new BufferedReader(new FileReader("FileServerSuite.txt")) + val path = SparkFiles.get("FileServerSuite.txt") + val in = new BufferedReader(new FileReader(path)) val fileVal = in.readLine().toInt in.close() _ * fileVal + _ * fileVal diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 00666bc0a3..3e8bca62f0 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -11,6 +11,8 @@ Public classes: A broadcast variable that gets reused across tasks. - L{Accumulator} An "add-only" shared variable that tasks can only add values to. + - L{SparkFiles} + Access files shipped with jobs. """ import sys import os @@ -19,6 +21,7 @@ sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.eg from pyspark.context import SparkContext from pyspark.rdd import RDD +from pyspark.files import SparkFiles -__all__ = ["SparkContext", "RDD"] +__all__ = ["SparkContext", "RDD", "SparkFiles"] diff --git a/python/pyspark/context.py b/python/pyspark/context.py index dcbed37270..ec0cc7c2f9 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -1,5 +1,7 @@ import os import atexit +import shutil +import tempfile from tempfile import NamedTemporaryFile from pyspark import accumulators @@ -173,10 +175,26 @@ class SparkContext(object): def addFile(self, path): """ - Add a file to be downloaded into the working directory of this Spark - job on every node. The C{path} passed can be either a local file, - a file in HDFS (or other Hadoop-supported filesystems), or an HTTP, - HTTPS or FTP URI. + Add a file to be downloaded with this Spark job on every node. + The C{path} passed can be either a local file, a file in HDFS + (or other Hadoop-supported filesystems), or an HTTP, HTTPS or + FTP URI. + + To access the file in Spark jobs, use + L{SparkFiles.get(path)} to find its + download location. + + >>> from pyspark import SparkFiles + >>> path = os.path.join(tempdir, "test.txt") + >>> with open(path, "w") as testFile: + ... testFile.write("100") + >>> sc.addFile(path) + >>> def func(iterator): + ... with open(SparkFiles.get("test.txt")) as testFile: + ... fileVal = int(testFile.readline()) + ... return [x * 100 for x in iterator] + >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect() + [100, 200, 300, 400] """ self._jsc.sc().addFile(path) @@ -211,3 +229,17 @@ class SparkContext(object): accidental overriding of checkpoint files in the existing directory. """ self._jsc.sc().setCheckpointDir(dirName, useExisting) + + +def _test(): + import doctest + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['tempdir'] = tempfile.mkdtemp() + atexit.register(lambda: shutil.rmtree(globs['tempdir'])) + doctest.testmod(globs=globs) + globs['sc'].stop() + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/files.py b/python/pyspark/files.py new file mode 100644 index 0000000000..de1334f046 --- /dev/null +++ b/python/pyspark/files.py @@ -0,0 +1,24 @@ +import os + + +class SparkFiles(object): + """ + Resolves paths to files added through + L{addFile()}. + + SparkFiles contains only classmethods; users should not create SparkFiles + instances. + """ + + _root_directory = None + + def __init__(self): + raise NotImplementedError("Do not construct SparkFiles objects") + + @classmethod + def get(cls, filename): + """ + Get the absolute path of a file added through C{addFile()}. + """ + path = os.path.join(SparkFiles._root_directory, filename) + return os.path.abspath(path) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index b2b9288089..e7bdb7682b 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -8,6 +8,7 @@ from base64 import standard_b64decode from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler +from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, read_with_length, write_int, \ read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file @@ -23,6 +24,8 @@ def load_obj(): def main(): split_index = read_int(sys.stdin) + spark_files_dir = load_pickle(read_with_length(sys.stdin)) + SparkFiles._root_directory = spark_files_dir num_broadcast_variables = read_int(sys.stdin) for _ in range(num_broadcast_variables): bid = read_long(sys.stdin) diff --git a/python/run-tests b/python/run-tests index ce214e98a8..a3a9ff5dcb 100755 --- a/python/run-tests +++ b/python/run-tests @@ -8,6 +8,9 @@ FAILED=0 $FWDIR/pyspark pyspark/rdd.py FAILED=$(($?||$FAILED)) +$FWDIR/pyspark pyspark/context.py +FAILED=$(($?||$FAILED)) + $FWDIR/pyspark -m doctest pyspark/broadcast.py FAILED=$(($?||$FAILED)) -- cgit v1.2.3 From 7b9e96c99206c0679d9925e0161fde738a5c7c3a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 21 Jan 2013 16:45:00 -0800 Subject: Add synchronization to Executor.updateDependencies() (SPARK-662) --- core/src/main/scala/spark/executor/Executor.scala | 34 ++++++++++++----------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 70629f6003..28d9d40d43 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -159,22 +159,24 @@ private[spark] class Executor extends Logging { * SparkContext. Also adds any new JARs we fetched to the class loader. */ private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { - // Fetch missing dependencies - for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) - currentFiles(name) = timestamp - } - for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) - currentJars(name) = timestamp - // Add it to our class loader - val localName = name.split("/").last - val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL - if (!urlClassLoader.getURLs.contains(url)) { - logInfo("Adding " + url + " to class loader") - urlClassLoader.addURL(url) + synchronized { + // Fetch missing dependencies + for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) + currentFiles(name) = timestamp + } + for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) + currentJars(name) = timestamp + // Add it to our class loader + val localName = name.split("/").last + val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL + if (!urlClassLoader.getURLs.contains(url)) { + logInfo("Adding " + url + " to class loader") + urlClassLoader.addURL(url) + } } } } -- cgit v1.2.3 From 2d8218b8717435a47d7cea399290b30bf5ef010b Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Mon, 21 Jan 2013 20:00:27 -0600 Subject: Remove unneeded/now-broken saveAsNewAPIHadoopFile overload. --- core/src/main/scala/spark/PairRDDFunctions.scala | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 656b820b8a..53b051f1c5 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -485,18 +485,6 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) } - /** - * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` - * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. - */ - def saveAsNewAPIHadoopFile( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) { - saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass) - } - /** * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. -- cgit v1.2.3 From a8baeb93272b03a98e44c7bf5c541611aec4a64b Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Mon, 21 Jan 2013 21:30:24 -0600 Subject: Further simplify getOrElse call. --- core/src/main/scala/spark/ParallelCollection.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index ad23e5bec8..10adcd53ec 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -44,7 +44,7 @@ private[spark] class ParallelCollection[T: ClassManifest]( s.asInstanceOf[ParallelCollectionSplit[T]].iterator override def getPreferredLocations(s: Split): Seq[String] = { - locationPrefs.get(s.index) getOrElse Nil + locationPrefs.getOrElse(s.index, Nil) } override def clearDependencies() { -- cgit v1.2.3 From 551a47a620c7dc207e3530e54d794a3c3aa8e45e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 21 Jan 2013 23:31:00 -0800 Subject: Refactor daemon thread pool creation. --- .../src/main/scala/spark/DaemonThreadFactory.scala | 18 ------------ core/src/main/scala/spark/Utils.scala | 33 +++++----------------- .../scala/spark/network/ConnectionManager.scala | 5 ++-- .../spark/scheduler/local/LocalScheduler.scala | 2 +- .../spark/streaming/dstream/RawInputDStream.scala | 5 ++-- 5 files changed, 13 insertions(+), 50 deletions(-) delete mode 100644 core/src/main/scala/spark/DaemonThreadFactory.scala diff --git a/core/src/main/scala/spark/DaemonThreadFactory.scala b/core/src/main/scala/spark/DaemonThreadFactory.scala deleted file mode 100644 index 56e59adeb7..0000000000 --- a/core/src/main/scala/spark/DaemonThreadFactory.scala +++ /dev/null @@ -1,18 +0,0 @@ -package spark - -import java.util.concurrent.ThreadFactory - -/** - * A ThreadFactory that creates daemon threads - */ -private object DaemonThreadFactory extends ThreadFactory { - override def newThread(r: Runnable): Thread = new DaemonThread(r) -} - -private class DaemonThread(r: Runnable = null) extends Thread { - override def run() { - if (r != null) { - r.run() - } - } -} \ No newline at end of file diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 692a3f4050..9b8636f6c8 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -10,6 +10,7 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ import scala.io.Source import com.google.common.io.Files +import com.google.common.util.concurrent.ThreadFactoryBuilder /** * Various utility methods used by Spark. @@ -287,29 +288,14 @@ private object Utils extends Logging { customHostname.getOrElse(InetAddress.getLocalHost.getHostName) } - /** - * Returns a standard ThreadFactory except all threads are daemons. - */ - private def newDaemonThreadFactory: ThreadFactory = { - new ThreadFactory { - def newThread(r: Runnable): Thread = { - var t = Executors.defaultThreadFactory.newThread (r) - t.setDaemon (true) - return t - } - } - } + private[spark] val daemonThreadFactory: ThreadFactory = + new ThreadFactoryBuilder().setDaemon(true).build() /** * Wrapper over newCachedThreadPool. */ - def newDaemonCachedThreadPool(): ThreadPoolExecutor = { - var threadPool = Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor] - - threadPool.setThreadFactory (newDaemonThreadFactory) - - return threadPool - } + def newDaemonCachedThreadPool(): ThreadPoolExecutor = + Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] /** * Return the string to tell how long has passed in seconds. The passing parameter should be in @@ -322,13 +308,8 @@ private object Utils extends Logging { /** * Wrapper over newFixedThreadPool. */ - def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = { - var threadPool = Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor] - - threadPool.setThreadFactory(newDaemonThreadFactory) - - return threadPool - } + def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = + Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] /** * Delete a file or directory and its contents recursively. diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 36c01ad629..2ecd14f536 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -52,9 +52,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging { val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] val sendMessageRequests = new Queue[(Message, SendingConnection)] - implicit val futureExecContext = ExecutionContext.fromExecutor( - Executors.newCachedThreadPool(DaemonThreadFactory)) - + implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool()) + var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null serverChannel.configureBlocking(false) diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index dff550036d..87f8474ea0 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -20,7 +20,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon with Logging { var attemptId = new AtomicInteger(0) - var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory) + var threadPool = Utils.newDaemonFixedThreadPool(threads) val env = SparkEnv.get var listener: TaskSchedulerListener = null diff --git a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala index 290fab1ce0..04e6b69b7b 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala @@ -1,6 +1,6 @@ package spark.streaming.dstream -import spark.{DaemonThread, Logging} +import spark.Logging import spark.storage.StorageLevel import spark.streaming.StreamingContext @@ -48,7 +48,8 @@ class RawNetworkReceiver(host: String, port: Int, storageLevel: StorageLevel) val queue = new ArrayBlockingQueue[ByteBuffer](2) - blockPushingThread = new DaemonThread { + blockPushingThread = new Thread { + setDaemon(true) override def run() { var nextBlockNumber = 0 while (true) { -- cgit v1.2.3 From e353886a8ca6179f25b4176d7a62b5d04ce79276 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 22 Jan 2013 00:23:31 -0800 Subject: Use generation numbers for fetch failure tracking --- .../main/scala/spark/scheduler/DAGScheduler.scala | 27 ++++++++++++++++------ 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 59f2099e91..39a1e6d6c6 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -72,8 +72,12 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val cacheTracker = env.cacheTracker val mapOutputTracker = env.mapOutputTracker - val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back; - // that's not going to be a realistic assumption in general + // For tracking failed nodes, we use the MapOutputTracker's generation number, which is + // sent with every task. When we detect a node failing, we note the current generation number + // and failed host, increment it for new tasks, and use this to ignore stray ShuffleMapTask + // results. + // TODO: Garbage collect information about failure generations when new stages start. + val failedGeneration = new HashMap[String, Long] val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done val running = new HashSet[Stage] // Stages we are running right now @@ -429,7 +433,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val status = event.result.asInstanceOf[MapStatus] val host = status.address.ip logInfo("ShuffleMapTask finished with host " + host) - if (!deadHosts.contains(host)) { // TODO: Make sure hostnames are consistent with Mesos + if (failedGeneration.contains(host) && smt.generation <= failedGeneration(host)) { + logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + host) + } else { stage.addOutputLoc(smt.partition, status) } if (running.contains(stage) && pendingTasks(stage).isEmpty) { @@ -495,7 +501,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock // TODO: mark the host as failed only if there were lots of fetch failures on it if (bmAddress != null) { - handleHostLost(bmAddress.ip) + handleHostLost(bmAddress.ip, Some(task.generation)) } case other => @@ -507,11 +513,15 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with /** * Responds to a host being lost. This is called inside the event loop so it assumes that it can * modify the scheduler's internal state. Use hostLost() to post a host lost event from outside. + * + * Optionally the generation during which the failure was caught can be passed to avoid allowing + * stray fetch failures from possibly retriggering the detection of a node as lost. */ - def handleHostLost(host: String) { - if (!deadHosts.contains(host)) { + def handleHostLost(host: String, maybeGeneration: Option[Long] = None) { + val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration) + if (!failedGeneration.contains(host) || failedGeneration(host) < currentGeneration) { + failedGeneration(host) = currentGeneration logInfo("Host lost: " + host) - deadHosts += host env.blockManager.master.notifyADeadHost(host) // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { @@ -519,6 +529,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray mapOutputTracker.registerMapOutputs(shuffleId, locs, true) } + if (shuffleToMapStage.isEmpty) { + mapOutputTracker.incrementGeneration() + } cacheTracker.cacheLost(host) updateCacheLocs() } -- cgit v1.2.3 From 7e9ee2e8335f085062d3fdeecd0b49ec63e92117 Mon Sep 17 00:00:00 2001 From: Leemoonsoo Date: Tue, 22 Jan 2013 23:08:34 +0900 Subject: Fix for hanging spark.HttpFileServer with kind of virtual network --- core/src/main/scala/spark/HttpServer.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/HttpServer.scala b/core/src/main/scala/spark/HttpServer.scala index 0196595ba1..4e0507c080 100644 --- a/core/src/main/scala/spark/HttpServer.scala +++ b/core/src/main/scala/spark/HttpServer.scala @@ -4,6 +4,7 @@ import java.io.File import java.net.InetAddress import org.eclipse.jetty.server.Server +import org.eclipse.jetty.server.bio.SocketConnector import org.eclipse.jetty.server.handler.DefaultHandler import org.eclipse.jetty.server.handler.HandlerList import org.eclipse.jetty.server.handler.ResourceHandler @@ -27,7 +28,13 @@ private[spark] class HttpServer(resourceBase: File) extends Logging { if (server != null) { throw new ServerStateException("Server is already started") } else { - server = new Server(0) + server = new Server() + val connector = new SocketConnector + connector.setMaxIdleTime(60*1000) + connector.setSoLingerTime(-1) + connector.setPort(0) + server.addConnector(connector) + val threadPool = new QueuedThreadPool threadPool.setDaemon(true) server.setThreadPool(threadPool) -- cgit v1.2.3 From 588b24197a85c4b46a38595007293abef9a41f2c Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 10:19:30 -0600 Subject: Use default arguments instead of constructor overloads. --- core/src/main/scala/spark/SparkContext.scala | 22 +++------------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 8b6f4b3b7d..495d1b6c78 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -58,27 +58,11 @@ import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend class SparkContext( val master: String, val jobName: String, - val sparkHome: String, - val jars: Seq[String], - environment: Map[String, String]) + val sparkHome: String = null, + val jars: Seq[String] = Nil, + environment: Map[String, String] = Map()) extends Logging { - /** - * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - * @param jobName A name for your job, to display on the cluster web UI - * @param sparkHome Location where Spark is installed on cluster nodes. - * @param jars Collection of JARs to send to the cluster. These can be paths on the local file - * system or HDFS, HTTP, HTTPS, or FTP URLs. - */ - def this(master: String, jobName: String, sparkHome: String, jars: Seq[String]) = - this(master, jobName, sparkHome, jars, Map()) - - /** - * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - * @param jobName A name for your job, to display on the cluster web UI - */ - def this(master: String, jobName: String) = this(master, jobName, null, Nil, Map()) - // Ensure logging is initialized before we spawn any threads initLogging() -- cgit v1.2.3 From 27b3f3f0a980f86bac14a14516b5d52a32aa8cbb Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 15:30:42 -0600 Subject: Handle slaveLost before slaveIdToHost knows about it. --- .../spark/scheduler/cluster/ClusterScheduler.scala | 31 +++++++++++++--------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 20f6e65020..a639b72795 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -252,19 +252,24 @@ private[spark] class ClusterScheduler(val sc: SparkContext) def slaveLost(slaveId: String, reason: ExecutorLossReason) { var failedHost: Option[String] = None synchronized { - val host = slaveIdToHost(slaveId) - if (hostsAlive.contains(host)) { - logError("Lost an executor on " + host + ": " + reason) - slaveIdsWithExecutors -= slaveId - hostsAlive -= host - activeTaskSetsQueue.foreach(_.hostLost(host)) - failedHost = Some(host) - } else { - // We may get multiple slaveLost() calls with different loss reasons. For example, one - // may be triggered by a dropped connection from the slave while another may be a report - // of executor termination from Mesos. We produce log messages for both so we eventually - // report the termination reason. - logError("Lost an executor on " + host + " (already removed): " + reason) + slaveIdToHost.get(slaveId) match { + case Some(host) => + if (hostsAlive.contains(host)) { + logError("Lost an executor on " + host + ": " + reason) + slaveIdsWithExecutors -= slaveId + hostsAlive -= host + activeTaskSetsQueue.foreach(_.hostLost(host)) + failedHost = Some(host) + } else { + // We may get multiple slaveLost() calls with different loss reasons. For example, one + // may be triggered by a dropped connection from the slave while another may be a report + // of executor termination from Mesos. We produce log messages for both so we eventually + // report the termination reason. + logError("Lost an executor on " + host + " (already removed): " + reason) + } + case None => + // We were told about a slave being lost before we could even allocate work to it + logError("Lost slave " + slaveId + " (no work assigned yet)") } } if (failedHost != None) { -- cgit v1.2.3 From 35168d9c89904f0dc0bb470c1799f5ca3b04221f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 22 Jan 2013 17:54:11 -0800 Subject: Fix sys.path bug in PySpark SparkContext.addPyFile --- python/pyspark/context.py | 2 -- python/pyspark/tests.py | 38 +++++++++++++++++++++++++++++++++----- python/pyspark/worker.py | 1 + python/test_support/userlibrary.py | 7 +++++++ 4 files changed, 41 insertions(+), 7 deletions(-) create mode 100755 python/test_support/userlibrary.py diff --git a/python/pyspark/context.py b/python/pyspark/context.py index ec0cc7c2f9..b8d7dc05af 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -215,8 +215,6 @@ class SparkContext(object): """ self.addFile(path) filename = path.split("/")[-1] - os.environ["PYTHONPATH"] = \ - "%s:%s" % (filename, os.environ["PYTHONPATH"]) def setCheckpointDir(self, dirName, useExisting=False): """ diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index b0a403b580..4d70ee4f12 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -9,21 +9,32 @@ import time import unittest from pyspark.context import SparkContext +from pyspark.java_gateway import SPARK_HOME -class TestCheckpoint(unittest.TestCase): +class PySparkTestCase(unittest.TestCase): def setUp(self): - self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2) - self.checkpointDir = NamedTemporaryFile(delete=False) - os.unlink(self.checkpointDir.name) - self.sc.setCheckpointDir(self.checkpointDir.name) + class_name = self.__class__.__name__ + self.sc = SparkContext('local[4]', class_name , batchSize=2) def tearDown(self): self.sc.stop() # To avoid Akka rebinding to the same port, since it doesn't unbind # immediately on shutdown self.sc.jvm.System.clearProperty("spark.master.port") + + +class TestCheckpoint(PySparkTestCase): + + def setUp(self): + PySparkTestCase.setUp(self) + self.checkpointDir = NamedTemporaryFile(delete=False) + os.unlink(self.checkpointDir.name) + self.sc.setCheckpointDir(self.checkpointDir.name) + + def tearDown(self): + PySparkTestCase.tearDown(self) shutil.rmtree(self.checkpointDir.name) def test_basic_checkpointing(self): @@ -57,5 +68,22 @@ class TestCheckpoint(unittest.TestCase): self.assertEquals([1, 2, 3, 4], recovered.collect()) +class TestAddFile(PySparkTestCase): + + def test_add_py_file(self): + # To ensure that we're actually testing addPyFile's effects, check that + # this job fails due to `userlibrary` not being on the Python path: + def func(x): + from userlibrary import UserClass + return UserClass().hello() + self.assertRaises(Exception, + self.sc.parallelize(range(2)).map(func).first) + # Add the file, so the job should now succeed: + path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") + self.sc.addPyFile(path) + res = self.sc.parallelize(range(2)).map(func).first() + self.assertEqual("Hello World!", res) + + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index e7bdb7682b..4bf643da66 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -26,6 +26,7 @@ def main(): split_index = read_int(sys.stdin) spark_files_dir = load_pickle(read_with_length(sys.stdin)) SparkFiles._root_directory = spark_files_dir + sys.path.append(spark_files_dir) num_broadcast_variables = read_int(sys.stdin) for _ in range(num_broadcast_variables): bid = read_long(sys.stdin) diff --git a/python/test_support/userlibrary.py b/python/test_support/userlibrary.py new file mode 100755 index 0000000000..5bb6f5009f --- /dev/null +++ b/python/test_support/userlibrary.py @@ -0,0 +1,7 @@ +""" +Used to test shipping of code depenencies with SparkContext.addPyFile(). +""" + +class UserClass(object): + def hello(self): + return "Hello World!" -- cgit v1.2.3 From 325297e5c31418f32deeb2a3cc52755094a11cea Mon Sep 17 00:00:00 2001 From: Mikhail Bautin Date: Tue, 22 Jan 2013 17:31:11 -0800 Subject: Add an Avro dependency to REPL to make it compile with Hadoop 2 --- pom.xml | 11 +++++++++++ repl/pom.xml | 10 ++++++++++ 2 files changed, 21 insertions(+) diff --git a/pom.xml b/pom.xml index 483b0f9595..3ea989a082 100644 --- a/pom.xml +++ b/pom.xml @@ -542,6 +542,17 @@ hadoop-client 2.0.0-mr1-cdh${cdh.version}
    + + + org.apache.avro + avro + 1.7.1.cloudera.2 + + + org.apache.avro + avro-ipc + 1.7.1.cloudera.2 + diff --git a/repl/pom.xml b/repl/pom.xml index 2fc9692969..2dc96beaf5 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -175,6 +175,16 @@ hadoop-client provided + + org.apache.avro + avro + provided + + + org.apache.avro + avro-ipc + provided + -- cgit v1.2.3 From 284993100022cc4bd43bf84a0be4dd91cf7a4ac0 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 22 Jan 2013 22:19:30 -0800 Subject: Eliminate CacheTracker. Replaces DAGScheduler's queries of CacheTracker with BlockManagerMaster queries. Adds CacheManager to locally coordinate computation of cached RDDs. --- core/src/main/scala/spark/CacheTracker.scala | 240 --------------------- core/src/main/scala/spark/RDD.scala | 2 +- core/src/main/scala/spark/SparkEnv.scala | 8 +- .../main/scala/spark/scheduler/DAGScheduler.scala | 24 ++- .../main/scala/spark/storage/BlockManager.scala | 24 +-- core/src/test/scala/spark/CacheTrackerSuite.scala | 131 ----------- 6 files changed, 18 insertions(+), 411 deletions(-) delete mode 100644 core/src/main/scala/spark/CacheTracker.scala delete mode 100644 core/src/test/scala/spark/CacheTrackerSuite.scala diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala deleted file mode 100644 index 86ad737583..0000000000 --- a/core/src/main/scala/spark/CacheTracker.scala +++ /dev/null @@ -1,240 +0,0 @@ -package spark - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet - -import akka.actor._ -import akka.dispatch._ -import akka.pattern.ask -import akka.remote._ -import akka.util.Duration -import akka.util.Timeout -import akka.util.duration._ - -import spark.storage.BlockManager -import spark.storage.StorageLevel -import util.{TimeStampedHashSet, MetadataCleaner, TimeStampedHashMap} - -private[spark] sealed trait CacheTrackerMessage - -private[spark] case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L) - extends CacheTrackerMessage -private[spark] case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L) - extends CacheTrackerMessage -private[spark] case class MemoryCacheLost(host: String) extends CacheTrackerMessage -private[spark] case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage -private[spark] case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage -private[spark] case object GetCacheStatus extends CacheTrackerMessage -private[spark] case object GetCacheLocations extends CacheTrackerMessage -private[spark] case object StopCacheTracker extends CacheTrackerMessage - -private[spark] class CacheTrackerActor extends Actor with Logging { - // TODO: Should probably store (String, CacheType) tuples - private val locs = new TimeStampedHashMap[Int, Array[List[String]]] - - /** - * A map from the slave's host name to its cache size. - */ - private val slaveCapacity = new HashMap[String, Long] - private val slaveUsage = new HashMap[String, Long] - - private val metadataCleaner = new MetadataCleaner("CacheTrackerActor", locs.clearOldValues) - - private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L) - private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L) - private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host) - - def receive = { - case SlaveCacheStarted(host: String, size: Long) => - slaveCapacity.put(host, size) - slaveUsage.put(host, 0) - sender ! true - - case RegisterRDD(rddId: Int, numPartitions: Int) => - logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions") - locs(rddId) = Array.fill[List[String]](numPartitions)(Nil) - sender ! true - - case AddedToCache(rddId, partition, host, size) => - slaveUsage.put(host, getCacheUsage(host) + size) - locs(rddId)(partition) = host :: locs(rddId)(partition) - sender ! true - - case DroppedFromCache(rddId, partition, host, size) => - slaveUsage.put(host, getCacheUsage(host) - size) - // Do a sanity check to make sure usage is greater than 0. - locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host) - sender ! true - - case MemoryCacheLost(host) => - logInfo("Memory cache lost on " + host) - for ((id, locations) <- locs) { - for (i <- 0 until locations.length) { - locations(i) = locations(i).filterNot(_ == host) - } - } - sender ! true - - case GetCacheLocations => - logInfo("Asked for current cache locations") - sender ! locs.map{case (rrdId, array) => (rrdId -> array.clone())} - - case GetCacheStatus => - val status = slaveCapacity.map { case (host, capacity) => - (host, capacity, getCacheUsage(host)) - }.toSeq - sender ! status - - case StopCacheTracker => - logInfo("Stopping CacheTrackerActor") - sender ! true - metadataCleaner.cancel() - context.stop(self) - } -} - -private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager) - extends Logging { - - // Tracker actor on the master, or remote reference to it on workers - val ip: String = System.getProperty("spark.master.host", "localhost") - val port: Int = System.getProperty("spark.master.port", "7077").toInt - val actorName: String = "CacheTracker" - - val timeout = 10.seconds - - var trackerActor: ActorRef = if (isMaster) { - val actor = actorSystem.actorOf(Props[CacheTrackerActor], name = actorName) - logInfo("Registered CacheTrackerActor actor") - actor - } else { - val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) - actorSystem.actorFor(url) - } - - // TODO: Consider removing this HashSet completely as locs CacheTrackerActor already - // keeps track of registered RDDs - val registeredRddIds = new TimeStampedHashSet[Int] - - // Remembers which splits are currently being loaded (on worker nodes) - val loading = new HashSet[String] - - val metadataCleaner = new MetadataCleaner("CacheTracker", registeredRddIds.clearOldValues) - - // Send a message to the trackerActor and get its result within a default timeout, or - // throw a SparkException if this fails. - def askTracker(message: Any): Any = { - try { - val future = trackerActor.ask(message)(timeout) - return Await.result(future, timeout) - } catch { - case e: Exception => - throw new SparkException("Error communicating with CacheTracker", e) - } - } - - // Send a one-way message to the trackerActor, to which we expect it to reply with true. - def communicate(message: Any) { - if (askTracker(message) != true) { - throw new SparkException("Error reply received from CacheTracker") - } - } - - // Registers an RDD (on master only) - def registerRDD(rddId: Int, numPartitions: Int) { - registeredRddIds.synchronized { - if (!registeredRddIds.contains(rddId)) { - logInfo("Registering RDD ID " + rddId + " with cache") - registeredRddIds += rddId - communicate(RegisterRDD(rddId, numPartitions)) - } - } - } - - // For BlockManager.scala only - def cacheLost(host: String) { - communicate(MemoryCacheLost(host)) - logInfo("CacheTracker successfully removed entries on " + host) - } - - // Get the usage status of slave caches. Each tuple in the returned sequence - // is in the form of (host name, capacity, usage). - def getCacheStatus(): Seq[(String, Long, Long)] = { - askTracker(GetCacheStatus).asInstanceOf[Seq[(String, Long, Long)]] - } - - // For BlockManager.scala only - def notifyFromBlockManager(t: AddedToCache) { - communicate(t) - } - - // Get a snapshot of the currently known locations - def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = { - askTracker(GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]] - } - - // Gets or computes an RDD split - def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel) - : Iterator[T] = { - val key = "rdd_%d_%d".format(rdd.id, split.index) - logInfo("Cache key is " + key) - blockManager.get(key) match { - case Some(cachedValues) => - // Split is in cache, so just return its values - logInfo("Found partition in cache!") - return cachedValues.asInstanceOf[Iterator[T]] - - case None => - // Mark the split as loading (unless someone else marks it first) - loading.synchronized { - if (loading.contains(key)) { - logInfo("Loading contains " + key + ", waiting...") - while (loading.contains(key)) { - try {loading.wait()} catch {case _ =>} - } - logInfo("Loading no longer contains " + key + ", so returning cached result") - // See whether someone else has successfully loaded it. The main way this would fail - // is for the RDD-level cache eviction policy if someone else has loaded the same RDD - // partition but we didn't want to make space for it. However, that case is unlikely - // because it's unlikely that two threads would work on the same RDD partition. One - // downside of the current code is that threads wait serially if this does happen. - blockManager.get(key) match { - case Some(values) => - return values.asInstanceOf[Iterator[T]] - case None => - logInfo("Whoever was loading " + key + " failed; we'll try it ourselves") - loading.add(key) - } - } else { - loading.add(key) - } - } - try { - // If we got here, we have to load the split - val elements = new ArrayBuffer[Any] - logInfo("Computing partition " + split) - elements ++= rdd.compute(split, context) - // Try to put this block in the blockManager - blockManager.put(key, elements, storageLevel, true) - return elements.iterator.asInstanceOf[Iterator[T]] - } finally { - loading.synchronized { - loading.remove(key) - loading.notifyAll() - } - } - } - } - - // Called by the Cache to report that an entry has been dropped from it - def dropEntry(rddId: Int, partition: Int) { - communicate(DroppedFromCache(rddId, partition, Utils.localHostName())) - } - - def stop() { - communicate(StopCacheTracker) - registeredRddIds.clear() - trackerActor = null - } -} diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index e0d2eabb1d..c79f34342f 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -176,7 +176,7 @@ abstract class RDD[T: ClassManifest]( if (isCheckpointed) { checkpointData.get.iterator(split, context) } else if (storageLevel != StorageLevel.NONE) { - SparkEnv.get.cacheTracker.getOrCompute[T](this, split, context, storageLevel) + SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel) } else { compute(split, context) } diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 41441720a7..a080194980 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -22,7 +22,7 @@ class SparkEnv ( val actorSystem: ActorSystem, val serializer: Serializer, val closureSerializer: Serializer, - val cacheTracker: CacheTracker, + val cacheManager: CacheManager, val mapOutputTracker: MapOutputTracker, val shuffleFetcher: ShuffleFetcher, val broadcastManager: BroadcastManager, @@ -39,7 +39,6 @@ class SparkEnv ( def stop() { httpFileServer.stop() mapOutputTracker.stop() - cacheTracker.stop() shuffleFetcher.stop() broadcastManager.stop() blockManager.stop() @@ -100,8 +99,7 @@ object SparkEnv extends Logging { val closureSerializer = instantiateClass[Serializer]( "spark.closure.serializer", "spark.JavaSerializer") - val cacheTracker = new CacheTracker(actorSystem, isMaster, blockManager) - blockManager.cacheTracker = cacheTracker + val cacheManager = new CacheManager(blockManager) val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster) @@ -122,7 +120,7 @@ object SparkEnv extends Logging { actorSystem, serializer, closureSerializer, - cacheTracker, + cacheManager, mapOutputTracker, shuffleFetcher, broadcastManager, diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 59f2099e91..03d173ac3b 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -69,8 +69,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with var cacheLocs = new HashMap[Int, Array[List[String]]] val env = SparkEnv.get - val cacheTracker = env.cacheTracker val mapOutputTracker = env.mapOutputTracker + val blockManagerMaster = env.blockManager.master val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back; // that's not going to be a realistic assumption in general @@ -95,11 +95,17 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with }.start() def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { + if (!cacheLocs.contains(rdd.id)) { + val blockIds = rdd.splits.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray + cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map { + locations => locations.map(_.ip).toList + }.toArray + } cacheLocs(rdd.id) } - def updateCacheLocs() { - cacheLocs = cacheTracker.getLocationsSnapshot() + def clearCacheLocs() { + cacheLocs.clear } /** @@ -126,7 +132,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of splits is unknown logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") - cacheTracker.registerRDD(rdd.id, rdd.splits.size) if (shuffleDep != None) { mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size) } @@ -148,8 +153,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with visited += r // Kind of ugly: need to register RDDs with the cache here since // we can't do it in its constructor because # of splits is unknown - logInfo("Registering parent RDD " + r.id + " (" + r.origin + ")") - cacheTracker.registerRDD(r.id, r.splits.size) for (dep <- r.dependencies) { dep match { case shufDep: ShuffleDependency[_,_] => @@ -250,7 +253,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val runId = nextRunId.getAndIncrement() val finalStage = newStage(finalRDD, None, runId) val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener) - updateCacheLocs() + clearCacheLocs() logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length + " output partitions") logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")") @@ -293,7 +296,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // on the failed node. if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) { logInfo("Resubmitting failed stages") - updateCacheLocs() + clearCacheLocs() val failed2 = failed.toArray failed.clear() for (stage <- failed2.sortBy(_.priority)) { @@ -443,7 +446,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with stage.shuffleDep.get.shuffleId, stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray) } - updateCacheLocs() + clearCacheLocs() if (stage.outputLocs.count(_ == Nil) != 0) { // Some tasks had failed; let's resubmit this stage // TODO: Lower-level scheduler should also deal with this @@ -519,8 +522,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray mapOutputTracker.registerMapOutputs(shuffleId, locs, true) } - cacheTracker.cacheLost(host) - updateCacheLocs() + clearCacheLocs() } } diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 7a8ac10cdd..e049565f48 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -16,7 +16,7 @@ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream -import spark.{CacheTracker, Logging, SizeEstimator, SparkEnv, SparkException, Utils} +import spark.{Logging, SizeEstimator, SparkEnv, SparkException, Utils} import spark.network._ import spark.serializer.Serializer import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap} @@ -71,9 +71,6 @@ class BlockManager( val connectionManagerId = connectionManager.id val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port) - // TODO: This will be removed after cacheTracker is removed from the code base. - var cacheTracker: CacheTracker = null - // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory // for receiving shuffle outputs) val maxBytesInFlight = @@ -662,10 +659,6 @@ class BlockManager( BlockManager.dispose(bytesAfterPut) - // TODO: This code will be removed when CacheTracker is gone. - if (blockId.startsWith("rdd")) { - notifyCacheTracker(blockId) - } logDebug("Put block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)) return size @@ -733,11 +726,6 @@ class BlockManager( } } - // TODO: This code will be removed when CacheTracker is gone. - if (blockId.startsWith("rdd")) { - notifyCacheTracker(blockId) - } - // If replication had started, then wait for it to finish if (level.replication > 1) { if (replicationFuture == null) { @@ -780,16 +768,6 @@ class BlockManager( } } - // TODO: This code will be removed when CacheTracker is gone. - private def notifyCacheTracker(key: String) { - if (cacheTracker != null) { - val rddInfo = key.split("_") - val rddId: Int = rddInfo(1).toInt - val partition: Int = rddInfo(2).toInt - cacheTracker.notifyFromBlockManager(spark.AddedToCache(rddId, partition, host)) - } - } - /** * Read a block consisting of a single object. */ diff --git a/core/src/test/scala/spark/CacheTrackerSuite.scala b/core/src/test/scala/spark/CacheTrackerSuite.scala deleted file mode 100644 index 467605981b..0000000000 --- a/core/src/test/scala/spark/CacheTrackerSuite.scala +++ /dev/null @@ -1,131 +0,0 @@ -package spark - -import org.scalatest.FunSuite - -import scala.collection.mutable.HashMap - -import akka.actor._ -import akka.dispatch._ -import akka.pattern.ask -import akka.remote._ -import akka.util.Duration -import akka.util.Timeout -import akka.util.duration._ - -class CacheTrackerSuite extends FunSuite { - // Send a message to an actor and wait for a reply, in a blocking manner - private def ask(actor: ActorRef, message: Any): Any = { - try { - val timeout = 10.seconds - val future = actor.ask(message)(timeout) - return Await.result(future, timeout) - } catch { - case e: Exception => - throw new SparkException("Error communicating with actor", e) - } - } - - test("CacheTrackerActor slave initialization & cache status") { - //System.setProperty("spark.master.port", "1345") - val initialSize = 2L << 20 - - val actorSystem = ActorSystem("test") - val tracker = actorSystem.actorOf(Props[CacheTrackerActor]) - - assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true) - - assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 0L))) - - assert(ask(tracker, StopCacheTracker) === true) - - actorSystem.shutdown() - actorSystem.awaitTermination() - } - - test("RegisterRDD") { - //System.setProperty("spark.master.port", "1345") - val initialSize = 2L << 20 - - val actorSystem = ActorSystem("test") - val tracker = actorSystem.actorOf(Props[CacheTrackerActor]) - - assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true) - - assert(ask(tracker, RegisterRDD(1, 3)) === true) - assert(ask(tracker, RegisterRDD(2, 1)) === true) - - assert(getCacheLocations(tracker) === Map(1 -> List(Nil, Nil, Nil), 2 -> List(Nil))) - - assert(ask(tracker, StopCacheTracker) === true) - - actorSystem.shutdown() - actorSystem.awaitTermination() - } - - test("AddedToCache") { - //System.setProperty("spark.master.port", "1345") - val initialSize = 2L << 20 - - val actorSystem = ActorSystem("test") - val tracker = actorSystem.actorOf(Props[CacheTrackerActor]) - - assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true) - - assert(ask(tracker, RegisterRDD(1, 2)) === true) - assert(ask(tracker, RegisterRDD(2, 1)) === true) - - assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true) - assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true) - assert(ask(tracker, AddedToCache(2, 0, "host001", 3L << 10)) === true) - - assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L))) - - assert(getCacheLocations(tracker) === - Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001")))) - - assert(ask(tracker, StopCacheTracker) === true) - - actorSystem.shutdown() - actorSystem.awaitTermination() - } - - test("DroppedFromCache") { - //System.setProperty("spark.master.port", "1345") - val initialSize = 2L << 20 - - val actorSystem = ActorSystem("test") - val tracker = actorSystem.actorOf(Props[CacheTrackerActor]) - - assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true) - - assert(ask(tracker, RegisterRDD(1, 2)) === true) - assert(ask(tracker, RegisterRDD(2, 1)) === true) - - assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true) - assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true) - assert(ask(tracker, AddedToCache(2, 0, "host001", 3L << 10)) === true) - - assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L))) - assert(getCacheLocations(tracker) === - Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001")))) - - assert(ask(tracker, DroppedFromCache(1, 1, "host001", 2L << 11)) === true) - - assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 68608L))) - assert(getCacheLocations(tracker) === - Map(1 -> List(List("host001"),List()), 2 -> List(List("host001")))) - - assert(ask(tracker, StopCacheTracker) === true) - - actorSystem.shutdown() - actorSystem.awaitTermination() - } - - /** - * Helper function to get cacheLocations from CacheTracker - */ - def getCacheLocations(tracker: ActorRef): HashMap[Int, List[List[String]]] = { - val answer = ask(tracker, GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]] - answer.map { case (i, arr) => (i, arr.toList) } - } -} -- cgit v1.2.3 From 43e9ff959645e533bcfa0a5c31e62e32c7e9d0a6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 22 Jan 2013 22:47:26 -0800 Subject: Add test for driver hanging on exit (SPARK-530). --- core/src/test/scala/spark/DriverSuite.scala | 31 +++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 core/src/test/scala/spark/DriverSuite.scala diff --git a/core/src/test/scala/spark/DriverSuite.scala b/core/src/test/scala/spark/DriverSuite.scala new file mode 100644 index 0000000000..70a7c8bc2f --- /dev/null +++ b/core/src/test/scala/spark/DriverSuite.scala @@ -0,0 +1,31 @@ +package spark + +import java.io.File + +import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts +import org.scalatest.prop.TableDrivenPropertyChecks._ +import org.scalatest.time.SpanSugar._ + +class DriverSuite extends FunSuite with Timeouts { + test("driver should exit after finishing") { + // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" + val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) + forAll(masters) { (master: String) => + failAfter(10 seconds) { + Utils.execute(Seq("./run", "spark.DriverWithoutCleanup", master), new File(System.getenv("SPARK_HOME"))) + } + } + } +} + +/** + * Program that creates a Spark driver but doesn't call SparkContext.stop() or + * Sys.exit() after finishing. + */ +object DriverWithoutCleanup { + def main(args: Array[String]) { + val sc = new SparkContext(args(0), "DriverWithoutCleanup") + sc.parallelize(1 to 100, 4).count() + } +} \ No newline at end of file -- cgit v1.2.3 From bacade6caf7527737dc6f02b1c2ca9114e02d8bc Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Jan 2013 22:55:26 -0800 Subject: Modified BlockManagerId API to ensure zero duplicate objects. Fixed BlockManagerId testcase in BlockManagerTestSuite. --- .../src/main/scala/spark/scheduler/MapStatus.scala | 2 +- .../main/scala/spark/storage/BlockManager.scala | 2 +- .../main/scala/spark/storage/BlockManagerId.scala | 33 ++++++++++++++++++---- .../scala/spark/storage/BlockManagerMessages.scala | 3 +- .../test/scala/spark/MapOutputTrackerSuite.scala | 22 +++++++-------- .../scala/spark/storage/BlockManagerSuite.scala | 18 ++++++------ 6 files changed, 51 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/MapStatus.scala b/core/src/main/scala/spark/scheduler/MapStatus.scala index 4532d9497f..fae643f3a8 100644 --- a/core/src/main/scala/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/spark/scheduler/MapStatus.scala @@ -20,7 +20,7 @@ private[spark] class MapStatus(var address: BlockManagerId, var compressedSizes: } def readExternal(in: ObjectInput) { - address = new BlockManagerId(in) + address = BlockManagerId(in) compressedSizes = new Array[Byte](in.readInt()) in.readFully(compressedSizes) } diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 7a8ac10cdd..596a69c583 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -69,7 +69,7 @@ class BlockManager( implicit val futureExecContext = connectionManager.futureExecContext val connectionManagerId = connectionManager.id - val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port) + val blockManagerId = BlockManagerId(connectionManagerId.host, connectionManagerId.port) // TODO: This will be removed after cacheTracker is removed from the code base. var cacheTracker: CacheTracker = null diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala index 488679f049..26c98f2ac8 100644 --- a/core/src/main/scala/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -3,20 +3,35 @@ package spark.storage import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.util.concurrent.ConcurrentHashMap +/** + * This class represent an unique identifier for a BlockManager. + * The first 2 constructors of this class is made private to ensure that + * BlockManagerId objects can be created only using the factory method in + * [[spark.storage.BlockManager$]]. This allows de-duplication of id objects. + * Also, constructor parameters are private to ensure that parameters cannot + * be modified from outside this class. + */ +private[spark] class BlockManagerId private ( + private var ip_ : String, + private var port_ : Int + ) extends Externalizable { + + private def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) -private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable { def this() = this(null, 0) // For deserialization only - def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) + def ip = ip_ + + def port = port_ override def writeExternal(out: ObjectOutput) { - out.writeUTF(ip) - out.writeInt(port) + out.writeUTF(ip_) + out.writeInt(port_) } override def readExternal(in: ObjectInput) { - ip = in.readUTF() - port = in.readInt() + ip_ = in.readUTF() + port_ = in.readInt() } @throws(classOf[IOException]) @@ -35,6 +50,12 @@ private[spark] class BlockManagerId(var ip: String, var port: Int) extends Exter private[spark] object BlockManagerId { + def apply(ip: String, port: Int) = + getCachedBlockManagerId(new BlockManagerId(ip, port)) + + def apply(in: ObjectInput) = + getCachedBlockManagerId(new BlockManagerId(in)) + val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala index d73a9b790f..7437fc63eb 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala @@ -54,8 +54,7 @@ class UpdateBlockInfo( } override def readExternal(in: ObjectInput) { - blockManagerId = new BlockManagerId() - blockManagerId.readExternal(in) + blockManagerId = BlockManagerId(in) blockId = in.readUTF() storageLevel = new StorageLevel() storageLevel.readExternal(in) diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index d3dd3a8fa4..095f415978 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -47,13 +47,13 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000), + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("hostA", 1000), Array(compressedSize1000, compressedSize10000))) - tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000), + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("hostB", 1000), Array(compressedSize10000, compressedSize1000))) val statuses = tracker.getServerStatuses(10, 0) - assert(statuses.toSeq === Seq((new BlockManagerId("hostA", 1000), size1000), - (new BlockManagerId("hostB", 1000), size10000))) + assert(statuses.toSeq === Seq((BlockManagerId("hostA", 1000), size1000), + (BlockManagerId("hostB", 1000), size10000))) tracker.stop() } @@ -65,14 +65,14 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000), + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("hostA", 1000), Array(compressedSize1000, compressedSize1000, compressedSize1000))) - tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000), + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("hostB", 1000), Array(compressedSize10000, compressedSize1000, compressedSize1000))) // As if we had two simulatenous fetch failures - tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000)) - tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000)) + tracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000)) + tracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000)) // The remaining reduce task might try to grab the output dispite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the @@ -95,13 +95,13 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { val compressedSize1000 = MapOutputTracker.compressSize(1000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) masterTracker.registerMapOutput(10, 0, new MapStatus( - new BlockManagerId("hostA", 1000), Array(compressedSize1000))) + BlockManagerId("hostA", 1000), Array(compressedSize1000))) masterTracker.incrementGeneration() slaveTracker.updateGeneration(masterTracker.getGeneration) assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((new BlockManagerId("hostA", 1000), size1000))) + Seq((BlockManagerId("hostA", 1000), size1000))) - masterTracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000)) + masterTracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000)) masterTracker.incrementGeneration() slaveTracker.updateGeneration(masterTracker.getGeneration) intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index 8f86e3170e..a33d3324ba 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -82,16 +82,18 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("BlockManagerId object caching") { - val id1 = new StorageLevel(false, false, false, 3) - val id2 = new StorageLevel(false, false, false, 3) + val id1 = BlockManagerId("XXX", 1) + val id2 = BlockManagerId("XXX", 1) // this should return the same object as id1 + assert(id2 === id1, "id2 is not same as id1") + assert(id2.eq(id1), "id2 is not the same object as id1") val bytes1 = spark.Utils.serialize(id1) - val id1_ = spark.Utils.deserialize[StorageLevel](bytes1) + val id1_ = spark.Utils.deserialize[BlockManagerId](bytes1) val bytes2 = spark.Utils.serialize(id2) - val id2_ = spark.Utils.deserialize[StorageLevel](bytes2) - assert(id1_ === id1, "Deserialized id1 not same as original id1") - assert(id2_ === id2, "Deserialized id2 not same as original id1") - assert(id1_ === id2_, "Deserialized id1 not same as deserialized id2") - assert(id2_.eq(id1_), "Deserialized id2 not the same object as deserialized level1") + val id2_ = spark.Utils.deserialize[BlockManagerId](bytes2) + assert(id1_ === id1, "Deserialized id1 is not same as original id1") + assert(id1_.eq(id1), "Deserialized id1 is not the same object as original id1") + assert(id2_ === id2, "Deserialized id2 is not same as original id2") + assert(id2_.eq(id1), "Deserialized id2 is not the same object as original id1") } test("master + 1 manager interaction") { -- cgit v1.2.3 From 5e11f1e51f17113abb8d3a5bc261af5ba5ffce94 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Jan 2013 23:42:53 -0800 Subject: Modified StorageLevel API to ensure zero duplicate objects. --- .../main/scala/spark/storage/BlockManager.scala | 5 +-- .../main/scala/spark/storage/BlockMessage.scala | 2 +- .../main/scala/spark/storage/StorageLevel.scala | 47 ++++++++++++++-------- .../scala/spark/storage/BlockManagerSuite.scala | 16 +++++--- 4 files changed, 44 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 596a69c583..ca7eb13ec8 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -191,7 +191,7 @@ class BlockManager( case level => val inMem = level.useMemory && memoryStore.contains(blockId) val onDisk = level.useDisk && diskStore.contains(blockId) - val storageLevel = new StorageLevel(onDisk, inMem, level.deserialized, level.replication) + val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication) val memSize = if (inMem) memoryStore.getSize(blockId) else 0L val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L (storageLevel, memSize, diskSize, info.tellMaster) @@ -760,8 +760,7 @@ class BlockManager( */ var cachedPeers: Seq[BlockManagerId] = null private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) { - val tLevel: StorageLevel = - new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) + val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) if (cachedPeers == null) { cachedPeers = master.getPeers(blockManagerId, level.replication - 1) } diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala index 3f234df654..30d7500e01 100644 --- a/core/src/main/scala/spark/storage/BlockMessage.scala +++ b/core/src/main/scala/spark/storage/BlockMessage.scala @@ -64,7 +64,7 @@ private[spark] class BlockMessage() { val booleanInt = buffer.getInt() val replication = buffer.getInt() - level = new StorageLevel(booleanInt, replication) + level = StorageLevel(booleanInt, replication) val dataLength = buffer.getInt() data = ByteBuffer.allocate(dataLength) diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index e3544e5aae..f2535ae5ae 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -7,25 +7,30 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} * whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory * in a serialized format, and whether to replicate the RDD partitions on multiple nodes. * The [[spark.storage.StorageLevel$]] singleton object contains some static constants for - * commonly useful storage levels. + * commonly useful storage levels. The recommended method to create your own storage level + * object is to use `StorageLevel.apply(...)` from the singleton object. */ class StorageLevel( - var useDisk: Boolean, - var useMemory: Boolean, - var deserialized: Boolean, - var replication: Int = 1) + private var useDisk_ : Boolean, + private var useMemory_ : Boolean, + private var deserialized_ : Boolean, + private var replication_ : Int = 1) extends Externalizable { // TODO: Also add fields for caching priority, dataset ID, and flushing. - - assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes") - - def this(flags: Int, replication: Int) { + private def this(flags: Int, replication: Int) { this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) } def this() = this(false, true, false) // For deserialization + def useDisk = useDisk_ + def useMemory = useMemory_ + def deserialized = deserialized_ + def replication = replication_ + + assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes") + override def clone(): StorageLevel = new StorageLevel( this.useDisk, this.useMemory, this.deserialized, this.replication) @@ -43,13 +48,13 @@ class StorageLevel( def toInt: Int = { var ret = 0 - if (useDisk) { + if (useDisk_) { ret |= 4 } - if (useMemory) { + if (useMemory_) { ret |= 2 } - if (deserialized) { + if (deserialized_) { ret |= 1 } return ret @@ -57,15 +62,15 @@ class StorageLevel( override def writeExternal(out: ObjectOutput) { out.writeByte(toInt) - out.writeByte(replication) + out.writeByte(replication_) } override def readExternal(in: ObjectInput) { val flags = in.readByte() - useDisk = (flags & 4) != 0 - useMemory = (flags & 2) != 0 - deserialized = (flags & 1) != 0 - replication = in.readByte() + useDisk_ = (flags & 4) != 0 + useMemory_ = (flags & 2) != 0 + deserialized_ = (flags & 1) != 0 + replication_ = in.readByte() } @throws(classOf[IOException]) @@ -91,6 +96,14 @@ object StorageLevel { val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false) val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2) + /** Create a new StorageLevel object */ + def apply(useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, replication: Int = 1) = + getCachedStorageLevel(new StorageLevel(useDisk, useMemory, deserialized, replication)) + + /** Create a new StorageLevel object from its integer representation */ + def apply(flags: Int, replication: Int) = + getCachedStorageLevel(new StorageLevel(flags, replication)) + private[spark] val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index a33d3324ba..a1aeb12f25 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -69,23 +69,29 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("StorageLevel object caching") { - val level1 = new StorageLevel(false, false, false, 3) - val level2 = new StorageLevel(false, false, false, 3) + val level1 = StorageLevel(false, false, false, 3) + val level2 = StorageLevel(false, false, false, 3) // this should return the same object as level1 + val level3 = StorageLevel(false, false, false, 2) // this should return a different object + assert(level2 === level1, "level2 is not same as level1") + assert(level2.eq(level1), "level2 is not the same object as level1") + assert(level3 != level1, "level3 is same as level1") val bytes1 = spark.Utils.serialize(level1) val level1_ = spark.Utils.deserialize[StorageLevel](bytes1) val bytes2 = spark.Utils.serialize(level2) val level2_ = spark.Utils.deserialize[StorageLevel](bytes2) assert(level1_ === level1, "Deserialized level1 not same as original level1") - assert(level2_ === level2, "Deserialized level2 not same as original level1") - assert(level1_ === level2_, "Deserialized level1 not same as deserialized level2") - assert(level2_.eq(level1_), "Deserialized level2 not the same object as deserialized level1") + assert(level1_.eq(level1), "Deserialized level1 not the same object as original level2") + assert(level2_ === level2, "Deserialized level2 not same as original level2") + assert(level2_.eq(level1), "Deserialized level2 not the same object as original level1") } test("BlockManagerId object caching") { val id1 = BlockManagerId("XXX", 1) val id2 = BlockManagerId("XXX", 1) // this should return the same object as id1 + val id3 = BlockManagerId("XXX", 2) // this should return a different object assert(id2 === id1, "id2 is not same as id1") assert(id2.eq(id1), "id2 is not the same object as id1") + assert(id3 != id1, "id3 is same as id1") val bytes1 = spark.Utils.serialize(id1) val id1_ = spark.Utils.deserialize[BlockManagerId](bytes1) val bytes2 = spark.Utils.serialize(id2) -- cgit v1.2.3 From 155f31398dc83ecb88b4b3e07849a2a8a0a6592f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 23 Jan 2013 01:10:26 -0800 Subject: Made StorageLevel constructor private, and added StorageLevels.create() to the Java API. Updates scala and java programming guides. --- core/src/main/scala/spark/api/java/StorageLevels.java | 11 +++++++++++ core/src/main/scala/spark/storage/StorageLevel.scala | 6 +++--- docs/java-programming-guide.md | 3 ++- docs/scala-programming-guide.md | 3 ++- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/spark/api/java/StorageLevels.java b/core/src/main/scala/spark/api/java/StorageLevels.java index 722af3c06c..5e5845ac3a 100644 --- a/core/src/main/scala/spark/api/java/StorageLevels.java +++ b/core/src/main/scala/spark/api/java/StorageLevels.java @@ -17,4 +17,15 @@ public class StorageLevels { public static final StorageLevel MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2); public static final StorageLevel MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, 1); public static final StorageLevel MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2); + + /** + * Create a new StorageLevel object. + * @param useDisk saved to disk, if true + * @param useMemory saved to memory, if true + * @param deserialized saved as deserialized objects, if true + * @param replication replication factor + */ + public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized, int replication) { + return StorageLevel.apply(useDisk, useMemory, deserialized, replication); + } } diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index f2535ae5ae..45d6ea2656 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -7,10 +7,10 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} * whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory * in a serialized format, and whether to replicate the RDD partitions on multiple nodes. * The [[spark.storage.StorageLevel$]] singleton object contains some static constants for - * commonly useful storage levels. The recommended method to create your own storage level - * object is to use `StorageLevel.apply(...)` from the singleton object. + * commonly useful storage levels. To create your own storage level object, use the factor method + * of the singleton object (`StorageLevel(...)`). */ -class StorageLevel( +class StorageLevel private( private var useDisk_ : Boolean, private var useMemory_ : Boolean, private var deserialized_ : Boolean, diff --git a/docs/java-programming-guide.md b/docs/java-programming-guide.md index 188ca4995e..37a906ea1c 100644 --- a/docs/java-programming-guide.md +++ b/docs/java-programming-guide.md @@ -75,7 +75,8 @@ class has a single abstract method, `call()`, that must be implemented. ## Storage Levels RDD [storage level](scala-programming-guide.html#rdd-persistence) constants, such as `MEMORY_AND_DISK`, are -declared in the [spark.api.java.StorageLevels](api/core/index.html#spark.api.java.StorageLevels) class. +declared in the [spark.api.java.StorageLevels](api/core/index.html#spark.api.java.StorageLevels) class. To +define your own storage level, you can use StorageLevels.create(...). # Other Features diff --git a/docs/scala-programming-guide.md b/docs/scala-programming-guide.md index 7350eca837..301b330a79 100644 --- a/docs/scala-programming-guide.md +++ b/docs/scala-programming-guide.md @@ -301,7 +301,8 @@ We recommend going through the following process to select one: * Use the replicated storage levels if you want fast fault recovery (e.g. if using Spark to serve requests from a web application). *All* the storage levels provide full fault tolerance by recomputing lost data, but the replicated ones let you continue running tasks on the RDD without waiting to recompute a lost partition. - + +If you want to define your own storage level (say, with replication factor of 3 instead of 2), then use the function factor method `apply()` of the [`StorageLevel`](api/core/index.html#spark.storage.StorageLevel$) singleton object. # Shared Variables -- cgit v1.2.3 From 9a27062260490336a3bfa97c6efd39b1e7e81573 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 01:34:44 -0800 Subject: Force generation increment after shuffle map stage --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 39a1e6d6c6..d8a9049e81 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -445,9 +445,16 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with logInfo("waiting: " + waiting) logInfo("failed: " + failed) if (stage.shuffleDep != None) { + // We supply true to increment the generation number here in case this is a + // recomputation of the map outputs. In that case, some nodes may have cached + // locations with holes (from when we detected the error) and will need the + // generation incremented to refetch them. + // TODO: Only increment the generation number if this is not the first time + // we registered these map outputs. mapOutputTracker.registerMapOutputs( stage.shuffleDep.get.shuffleId, - stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray) + stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, + true) } updateCacheLocs() if (stage.outputLocs.count(_ == Nil) != 0) { -- cgit v1.2.3 From d209b6b7641059610f734414ea05e0494b5510b0 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 01:35:14 -0800 Subject: Extra debugging from hostLost() --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index d8a9049e81..740aec2e61 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -528,7 +528,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration) if (!failedGeneration.contains(host) || failedGeneration(host) < currentGeneration) { failedGeneration(host) = currentGeneration - logInfo("Host lost: " + host) + logInfo("Host lost: " + host + " (generation " + currentGeneration + ")") env.blockManager.master.notifyADeadHost(host) // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { @@ -541,6 +541,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } cacheTracker.cacheLost(host) updateCacheLocs() + } else { + logDebug("Additional host lost message for " + host + + "(generation " + currentGeneration + ")") } } -- cgit v1.2.3 From 0b506dd2ecec909cd514143389d0846db2d194ed Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 01:37:51 -0800 Subject: Add tests of various node failure scenarios. --- core/src/test/scala/spark/DistributedSuite.scala | 72 ++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index cacc2796b6..0d6b265e54 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -188,4 +188,76 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter val values = sc.parallelize(1 to 2, 2).map(x => System.getenv("TEST_VAR")).collect() assert(values.toSeq === Seq("TEST_VALUE", "TEST_VALUE")) } + + test("recover from node failures") { + import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + DistributedSuite.amMaster = true + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(Seq(true, true), 2) + val singleton = sc.parallelize(Seq(true), 1) + assert(data.count === 2) // force executors to start + val masterId = SparkEnv.get.blockManager.blockManagerId + assert(data.map(markNodeIfIdentity).collect.size === 2) + assert(data.map(failOnMarkedIdentity).collect.size === 2) + } + + test("recover from repeated node failures during shuffle-map") { + import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + DistributedSuite.amMaster = true + sc = new SparkContext(clusterUrl, "test") + for (i <- 1 to 3) { + val data = sc.parallelize(Seq(true, false), 2) + val singleton = sc.parallelize(Seq(false), 1) + assert(data.count === 2) + assert(data.map(markNodeIfIdentity).collect.size === 2) + assert(data.map(failOnMarkedIdentity).map(x => x -> x).groupByKey.count === 2) + } + } + + test("recover from repeated node failures during shuffle-reduce") { + import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + DistributedSuite.amMaster = true + sc = new SparkContext(clusterUrl, "test") + for (i <- 1 to 3) { + val data = sc.parallelize(Seq(true, true), 2) + val singleton = sc.parallelize(Seq(false), 1) + assert(data.count === 2) + assert(data.map(markNodeIfIdentity).collect.size === 2) + // This relies on mergeCombiners being used to perform the actual reduce for this + // test to actually be testing what it claims. + val grouped = data.map(x => x -> x).combineByKey( + x => x, + (x: Boolean, y: Boolean) => x, + (x: Boolean, y: Boolean) => failOnMarkedIdentity(x) + ) + assert(grouped.collect.size === 1) + } + } +} + +object DistributedSuite { + // Indicates whether this JVM is marked for failure. + var mark = false + + // Set by test to remember if we are in the driver program so we can assert + // that we are not. + var amMaster = false + + // Act like an identity function, but if the argument is true, set mark to true. + def markNodeIfIdentity(item: Boolean): Boolean = { + if (item) { + assert(!amMaster) + mark = true + } + item + } + + // Act like an identity function, but if mark was set to true previously, fail, + // crashing the entire JVM. + def failOnMarkedIdentity(item: Boolean): Boolean = { + if (mark) { + System.exit(42) + } + item + } } -- cgit v1.2.3 From 79d55700ce2559051ac61cc2fb72a67fd7035926 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 23 Jan 2013 01:57:09 -0800 Subject: One more fix. Made even default constructor of BlockManagerId private to prevent such problems in the future. --- core/src/main/scala/spark/storage/BlockManagerId.scala | 11 ++++++----- core/src/main/scala/spark/storage/BlockManagerMessages.scala | 3 +-- core/src/main/scala/spark/storage/StorageLevel.scala | 7 +++++++ 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala index 26c98f2ac8..abb8b45a1f 100644 --- a/core/src/main/scala/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -16,9 +16,7 @@ private[spark] class BlockManagerId private ( private var port_ : Int ) extends Externalizable { - private def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) - - def this() = this(null, 0) // For deserialization only + private def this() = this(null, 0) // For deserialization only def ip = ip_ @@ -53,8 +51,11 @@ private[spark] object BlockManagerId { def apply(ip: String, port: Int) = getCachedBlockManagerId(new BlockManagerId(ip, port)) - def apply(in: ObjectInput) = - getCachedBlockManagerId(new BlockManagerId(in)) + def apply(in: ObjectInput) = { + val obj = new BlockManagerId() + obj.readExternal(in) + getCachedBlockManagerId(obj) + } val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala index 7437fc63eb..30483b0b37 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala @@ -56,8 +56,7 @@ class UpdateBlockInfo( override def readExternal(in: ObjectInput) { blockManagerId = BlockManagerId(in) blockId = in.readUTF() - storageLevel = new StorageLevel() - storageLevel.readExternal(in) + storageLevel = StorageLevel(in) memSize = in.readInt() diskSize = in.readInt() } diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index 45d6ea2656..d1d1c61c1c 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -104,6 +104,13 @@ object StorageLevel { def apply(flags: Int, replication: Int) = getCachedStorageLevel(new StorageLevel(flags, replication)) + /** Read StorageLevel object from ObjectInput stream */ + def apply(in: ObjectInput) = { + val obj = new StorageLevel() + obj.readExternal(in) + getCachedStorageLevel(obj) + } + private[spark] val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() -- cgit v1.2.3 From ae2ed2947d43860c74a8d40767e289ca78073977 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 23 Jan 2013 10:36:18 -0800 Subject: Allow PySpark's SparkFiles to be used from driver Fix minor documentation formatting issues. --- core/src/main/scala/spark/SparkFiles.java | 8 ++++---- python/pyspark/context.py | 27 +++++++++++++++++++++------ python/pyspark/files.py | 20 +++++++++++++++++--- python/pyspark/tests.py | 23 +++++++++++++++++++++++ python/pyspark/worker.py | 1 + python/test_support/hello.txt | 1 + 6 files changed, 67 insertions(+), 13 deletions(-) create mode 100755 python/test_support/hello.txt diff --git a/core/src/main/scala/spark/SparkFiles.java b/core/src/main/scala/spark/SparkFiles.java index b59d8ce93f..566aec622c 100644 --- a/core/src/main/scala/spark/SparkFiles.java +++ b/core/src/main/scala/spark/SparkFiles.java @@ -3,23 +3,23 @@ package spark; import java.io.File; /** - * Resolves paths to files added through `addFile(). + * Resolves paths to files added through `SparkContext.addFile()`. */ public class SparkFiles { private SparkFiles() {} /** - * Get the absolute path of a file added through `addFile()`. + * Get the absolute path of a file added through `SparkContext.addFile()`. */ public static String get(String filename) { return new File(getRootDirectory(), filename).getAbsolutePath(); } /** - * Get the root directory that contains files added through `addFile()`. + * Get the root directory that contains files added through `SparkContext.addFile()`. */ public static String getRootDirectory() { return SparkEnv.get().sparkFilesDir(); } -} \ No newline at end of file +} diff --git a/python/pyspark/context.py b/python/pyspark/context.py index b8d7dc05af..3e33776af0 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -1,12 +1,15 @@ import os import atexit import shutil +import sys import tempfile +from threading import Lock from tempfile import NamedTemporaryFile from pyspark import accumulators from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast +from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import dump_pickle, write_with_length, batched from pyspark.rdd import RDD @@ -27,6 +30,8 @@ class SparkContext(object): _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile _takePartition = jvm.PythonRDD.takePartition _next_accum_id = 0 + _active_spark_context = None + _lock = Lock() def __init__(self, master, jobName, sparkHome=None, pyFiles=None, environment=None, batchSize=1024): @@ -46,6 +51,11 @@ class SparkContext(object): Java object. Set 1 to disable batching or -1 to use an unlimited batch size. """ + with SparkContext._lock: + if SparkContext._active_spark_context: + raise ValueError("Cannot run multiple SparkContexts at once") + else: + SparkContext._active_spark_context = self self.master = master self.jobName = jobName self.sparkHome = sparkHome or None # None becomes null in Py4J @@ -75,6 +85,8 @@ class SparkContext(object): # Deploy any code dependencies specified in the constructor for path in (pyFiles or []): self.addPyFile(path) + SparkFiles._sc = self + sys.path.append(SparkFiles.getRootDirectory()) @property def defaultParallelism(self): @@ -85,17 +97,20 @@ class SparkContext(object): return self._jsc.sc().defaultParallelism() def __del__(self): - if self._jsc: - self._jsc.stop() - if self._accumulatorServer: - self._accumulatorServer.shutdown() + self.stop() def stop(self): """ Shut down the SparkContext. """ - self._jsc.stop() - self._jsc = None + if self._jsc: + self._jsc.stop() + self._jsc = None + if self._accumulatorServer: + self._accumulatorServer.shutdown() + self._accumulatorServer = None + with SparkContext._lock: + SparkContext._active_spark_context = None def parallelize(self, c, numSlices=None): """ diff --git a/python/pyspark/files.py b/python/pyspark/files.py index de1334f046..98f6a399cc 100644 --- a/python/pyspark/files.py +++ b/python/pyspark/files.py @@ -4,13 +4,15 @@ import os class SparkFiles(object): """ Resolves paths to files added through - L{addFile()}. + L{SparkContext.addFile()}. SparkFiles contains only classmethods; users should not create SparkFiles instances. """ _root_directory = None + _is_running_on_worker = False + _sc = None def __init__(self): raise NotImplementedError("Do not construct SparkFiles objects") @@ -18,7 +20,19 @@ class SparkFiles(object): @classmethod def get(cls, filename): """ - Get the absolute path of a file added through C{addFile()}. + Get the absolute path of a file added through C{SparkContext.addFile()}. """ - path = os.path.join(SparkFiles._root_directory, filename) + path = os.path.join(SparkFiles.getRootDirectory(), filename) return os.path.abspath(path) + + @classmethod + def getRootDirectory(cls): + """ + Get the root directory that contains files added through + C{SparkContext.addFile()}. + """ + if cls._is_running_on_worker: + return cls._root_directory + else: + # This will have to change if we support multiple SparkContexts: + return cls._sc.jvm.spark.SparkFiles.getRootDirectory() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 4d70ee4f12..46ab34f063 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -4,22 +4,26 @@ individual modules. """ import os import shutil +import sys from tempfile import NamedTemporaryFile import time import unittest from pyspark.context import SparkContext +from pyspark.files import SparkFiles from pyspark.java_gateway import SPARK_HOME class PySparkTestCase(unittest.TestCase): def setUp(self): + self._old_sys_path = list(sys.path) class_name = self.__class__.__name__ self.sc = SparkContext('local[4]', class_name , batchSize=2) def tearDown(self): self.sc.stop() + sys.path = self._old_sys_path # To avoid Akka rebinding to the same port, since it doesn't unbind # immediately on shutdown self.sc.jvm.System.clearProperty("spark.master.port") @@ -84,6 +88,25 @@ class TestAddFile(PySparkTestCase): res = self.sc.parallelize(range(2)).map(func).first() self.assertEqual("Hello World!", res) + def test_add_file_locally(self): + path = os.path.join(SPARK_HOME, "python/test_support/hello.txt") + self.sc.addFile(path) + download_path = SparkFiles.get("hello.txt") + self.assertNotEqual(path, download_path) + with open(download_path) as test_file: + self.assertEquals("Hello World!\n", test_file.readline()) + + def test_add_py_file_locally(self): + # To ensure that we're actually testing addPyFile's effects, check that + # this fails due to `userlibrary` not being on the Python path: + def func(): + from userlibrary import UserClass + self.assertRaises(ImportError, func) + path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") + self.sc.addFile(path) + from userlibrary import UserClass + self.assertEqual("Hello World!", UserClass().hello()) + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 4bf643da66..d33d6dd15f 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -26,6 +26,7 @@ def main(): split_index = read_int(sys.stdin) spark_files_dir = load_pickle(read_with_length(sys.stdin)) SparkFiles._root_directory = spark_files_dir + SparkFiles._is_running_on_worker = True sys.path.append(spark_files_dir) num_broadcast_variables = read_int(sys.stdin) for _ in range(num_broadcast_variables): diff --git a/python/test_support/hello.txt b/python/test_support/hello.txt new file mode 100755 index 0000000000..980a0d5f19 --- /dev/null +++ b/python/test_support/hello.txt @@ -0,0 +1 @@ +Hello World! -- cgit v1.2.3 From e1027ca6398fd5b1a99a2203df840911c4dccb27 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 12:22:11 -0800 Subject: Actually add CacheManager. --- core/src/main/scala/spark/CacheManager.scala | 65 ++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 core/src/main/scala/spark/CacheManager.scala diff --git a/core/src/main/scala/spark/CacheManager.scala b/core/src/main/scala/spark/CacheManager.scala new file mode 100644 index 0000000000..a0b53fd9d6 --- /dev/null +++ b/core/src/main/scala/spark/CacheManager.scala @@ -0,0 +1,65 @@ +package spark + +import scala.collection.mutable.{ArrayBuffer, HashSet} +import spark.storage.{BlockManager, StorageLevel} + + +/** Spark class responsible for passing RDDs split contents to the BlockManager and making + sure a node doesn't load two copies of an RDD at once. + */ +private[spark] class CacheManager(blockManager: BlockManager) extends Logging { + private val loading = new HashSet[String] + + /** Gets or computes an RDD split. Used by RDD.iterator() when a RDD is cached. */ + def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel) + : Iterator[T] = { + val key = "rdd_%d_%d".format(rdd.id, split.index) + logInfo("Cache key is " + key) + blockManager.get(key) match { + case Some(cachedValues) => + // Split is in cache, so just return its values + logInfo("Found partition in cache!") + return cachedValues.asInstanceOf[Iterator[T]] + + case None => + // Mark the split as loading (unless someone else marks it first) + loading.synchronized { + if (loading.contains(key)) { + logInfo("Loading contains " + key + ", waiting...") + while (loading.contains(key)) { + try {loading.wait()} catch {case _ =>} + } + logInfo("Loading no longer contains " + key + ", so returning cached result") + // See whether someone else has successfully loaded it. The main way this would fail + // is for the RDD-level cache eviction policy if someone else has loaded the same RDD + // partition but we didn't want to make space for it. However, that case is unlikely + // because it's unlikely that two threads would work on the same RDD partition. One + // downside of the current code is that threads wait serially if this does happen. + blockManager.get(key) match { + case Some(values) => + return values.asInstanceOf[Iterator[T]] + case None => + logInfo("Whoever was loading " + key + " failed; we'll try it ourselves") + loading.add(key) + } + } else { + loading.add(key) + } + } + try { + // If we got here, we have to load the split + val elements = new ArrayBuffer[Any] + logInfo("Computing partition " + split) + elements ++= rdd.compute(split, context) + // Try to put this block in the blockManager + blockManager.put(key, elements, storageLevel, true) + return elements.iterator.asInstanceOf[Iterator[T]] + } finally { + loading.synchronized { + loading.remove(key) + loading.notifyAll() + } + } + } + } +} -- cgit v1.2.3 From 88b9d240fda7ca34c08752dfa66797eecb6db872 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 12:40:38 -0800 Subject: Remove dead code in test. --- core/src/test/scala/spark/DistributedSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index 0d6b265e54..af66d33aa3 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -194,7 +194,6 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter DistributedSuite.amMaster = true sc = new SparkContext(clusterUrl, "test") val data = sc.parallelize(Seq(true, true), 2) - val singleton = sc.parallelize(Seq(true), 1) assert(data.count === 2) // force executors to start val masterId = SparkEnv.get.blockManager.blockManagerId assert(data.map(markNodeIfIdentity).collect.size === 2) @@ -207,7 +206,6 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter sc = new SparkContext(clusterUrl, "test") for (i <- 1 to 3) { val data = sc.parallelize(Seq(true, false), 2) - val singleton = sc.parallelize(Seq(false), 1) assert(data.count === 2) assert(data.map(markNodeIfIdentity).collect.size === 2) assert(data.map(failOnMarkedIdentity).map(x => x -> x).groupByKey.count === 2) -- cgit v1.2.3 From be4a115a7ec7fb6ec0d34f1a1a1bb2c9bbe7600e Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 12:48:45 -0800 Subject: Clarify TODO. --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 740aec2e61..14a3ef8ad7 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -76,7 +76,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // sent with every task. When we detect a node failing, we note the current generation number // and failed host, increment it for new tasks, and use this to ignore stray ShuffleMapTask // results. - // TODO: Garbage collect information about failure generations when new stages start. + // TODO: Garbage collect information about failure generations when we know there are no more + // stray messages to detect. val failedGeneration = new HashMap[String, Long] val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done -- cgit v1.2.3 From e1985bfa04ad4583ac1f0f421cbe0182ce7c53df Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 21 Jan 2013 16:21:14 -0800 Subject: be sure to set class loader of kryo instances --- core/src/main/scala/spark/KryoSerializer.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 93d7327324..56919544e8 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -206,5 +206,8 @@ class KryoSerializer extends spark.serializer.Serializer with Logging { kryo } - def newInstance(): SerializerInstance = new KryoSerializerInstance(this) + def newInstance(): SerializerInstance = { + this.kryo.setClassLoader(Thread.currentThread().getContextClassLoader) + new KryoSerializerInstance(this) + } } -- cgit v1.2.3 From 5c7422292ecace947f78e5ebe97e83a355531af7 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 12:59:51 -0800 Subject: Remove more dead code from test. --- core/src/test/scala/spark/DistributedSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index af66d33aa3..0487e06d12 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -218,7 +218,6 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter sc = new SparkContext(clusterUrl, "test") for (i <- 1 to 3) { val data = sc.parallelize(Seq(true, true), 2) - val singleton = sc.parallelize(Seq(false), 1) assert(data.count === 2) assert(data.map(markNodeIfIdentity).collect.size === 2) // This relies on mergeCombiners being used to perform the actual reduce for this -- cgit v1.2.3 From 1dd82743e09789f8fdae2f5628545c0cb9f79245 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Wed, 23 Jan 2013 13:07:27 -0800 Subject: Fix compile error due to cherry-pick --- core/src/main/scala/spark/KryoSerializer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 56919544e8..0bd73e936b 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -207,7 +207,7 @@ class KryoSerializer extends spark.serializer.Serializer with Logging { } def newInstance(): SerializerInstance = { - this.kryo.setClassLoader(Thread.currentThread().getContextClassLoader) + this.kryo.get().setClassLoader(Thread.currentThread().getContextClassLoader) new KryoSerializerInstance(this) } } -- cgit v1.2.3