-rw-r--r--  .gitignore | 2
-rw-r--r--  bagel/src/test/resources/log4j.properties | 4
-rwxr-xr-x  conf/streaming-env.sh.template | 22
-rw-r--r--  core/src/main/scala/spark/CacheTracker.scala | 20
-rw-r--r--  core/src/main/scala/spark/KryoSerializer.scala | 205
-rw-r--r--  core/src/main/scala/spark/MapOutputTracker.scala | 4
-rw-r--r--  core/src/main/scala/spark/PairRDDFunctions.scala | 42
-rw-r--r--  core/src/main/scala/spark/ParallelCollection.scala | 16
-rw-r--r--  core/src/main/scala/spark/Partitioner.scala | 4
-rw-r--r--  core/src/main/scala/spark/RDD.scala | 250
-rw-r--r--  core/src/main/scala/spark/RDDCheckpointData.scala | 105
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala | 52
-rw-r--r--  core/src/main/scala/spark/Utils.scala | 60
-rw-r--r--  core/src/main/scala/spark/api/java/JavaSparkContext.scala | 34
-rw-r--r--  core/src/main/scala/spark/rdd/BlockRDD.scala | 17
-rw-r--r--  core/src/main/scala/spark/rdd/CartesianRDD.scala | 43
-rw-r--r--  core/src/main/scala/spark/rdd/CheckpointRDD.scala | 128
-rw-r--r--  core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 40
-rw-r--r--  core/src/main/scala/spark/rdd/CoalescedRDD.scala | 33
-rw-r--r--  core/src/main/scala/spark/rdd/FilteredRDD.scala | 14
-rw-r--r--  core/src/main/scala/spark/rdd/FlatMappedRDD.scala | 9
-rw-r--r--  core/src/main/scala/spark/rdd/GlommedRDD.scala | 11
-rw-r--r--  core/src/main/scala/spark/rdd/HadoopRDD.scala | 13
-rw-r--r--  core/src/main/scala/spark/rdd/MapPartitionsRDD.scala | 16
-rw-r--r--  core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala | 14
-rw-r--r--  core/src/main/scala/spark/rdd/MappedRDD.scala | 10
-rw-r--r--  core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 6
-rw-r--r--  core/src/main/scala/spark/rdd/PipedRDD.scala | 11
-rw-r--r--  core/src/main/scala/spark/rdd/SampledRDD.scala | 17
-rw-r--r--  core/src/main/scala/spark/rdd/ShuffledRDD.scala | 17
-rw-r--r--  core/src/main/scala/spark/rdd/UnionRDD.scala | 43
-rw-r--r--  core/src/main/scala/spark/rdd/ZippedRDD.scala | 57
-rw-r--r--  core/src/main/scala/spark/scheduler/DAGScheduler.scala | 6
-rw-r--r--  core/src/main/scala/spark/scheduler/ResultTask.scala | 94
-rw-r--r--  core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 19
-rw-r--r--  core/src/main/scala/spark/scheduler/local/LocalScheduler.scala | 42
-rw-r--r--  core/src/main/scala/spark/util/TimeStampedHashMap.scala | 7
-rw-r--r--  core/src/main/scala/spark/util/TimeStampedHashSet.scala | 5
-rw-r--r--  core/src/test/resources/log4j.properties | 4
-rw-r--r--  core/src/test/scala/spark/CheckpointSuite.scala | 398
-rw-r--r--  core/src/test/scala/spark/ClosureCleanerSuite.scala | 2
-rw-r--r--  core/src/test/scala/spark/JavaAPISuite.java | 11
-rw-r--r--  core/src/test/scala/spark/PartitioningSuite.scala | 21
-rw-r--r--  core/src/test/scala/spark/RDDSuite.scala | 33
-rwxr-xr-x  docs/_layouts/global.html | 12
-rw-r--r--  docs/_plugins/copy_api_dirs.rb | 4
-rw-r--r--  docs/api.md | 5
-rw-r--r--  docs/configuration.md | 11
-rw-r--r--  docs/streaming-programming-guide.md | 193
-rw-r--r--  docs/tuning.md | 30
-rw-r--r--  examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala (renamed from streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala) | 2
-rw-r--r--  examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala | 36
-rw-r--r--  examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala (renamed from streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala) | 0
-rw-r--r--  examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala (renamed from streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala) | 17
-rw-r--r--  examples/src/main/scala/spark/streaming/examples/QueueStream.scala (renamed from streaming/src/main/scala/spark/streaming/examples/QueueStream.scala) | 0
-rw-r--r--  examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala | 46
-rw-r--r--  examples/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala (renamed from streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala) | 0
-rw-r--r--  examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala (renamed from streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala) | 2
-rw-r--r--  pom.xml | 2
-rw-r--r--  project/SparkBuild.scala | 4
-rw-r--r--  repl/src/test/resources/log4j.properties | 4
-rwxr-xr-x  run | 4
-rw-r--r--  sentences.txt | 3
-rw-r--r--  streaming/src/main/scala/spark/streaming/Checkpoint.scala | 5
-rw-r--r--  streaming/src/main/scala/spark/streaming/DStream.scala | 517
-rw-r--r--  streaming/src/main/scala/spark/streaming/DStreamGraph.scala | 1
-rw-r--r--  streaming/src/main/scala/spark/streaming/DataHandler.scala | 83
-rw-r--r--  streaming/src/main/scala/spark/streaming/Interval.scala | 1
-rw-r--r--  streaming/src/main/scala/spark/streaming/Job.scala | 2
-rw-r--r--  streaming/src/main/scala/spark/streaming/JobManager.scala | 3
-rw-r--r--  streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala | 10
-rw-r--r--  streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala | 11
-rw-r--r--  streaming/src/main/scala/spark/streaming/Scheduler.scala | 12
-rw-r--r--  streaming/src/main/scala/spark/streaming/StreamingContext.scala | 200
-rw-r--r--  streaming/src/main/scala/spark/streaming/Time.scala | 37
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala (renamed from streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala) | 4
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala (renamed from streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala) | 3
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala (renamed from streaming/src/main/scala/spark/streaming/FileInputDStream.scala) | 21
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala | 21
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala | 20
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala | 20
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala (renamed from streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala) | 38
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala | 28
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala | 17
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala | 19
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala (renamed from streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala) | 33
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala | 21
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala | 21
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala | 20
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala (renamed from streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala) | 90
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala (renamed from streaming/src/main/scala/spark/streaming/QueueInputDStream.scala) | 3
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala (renamed from streaming/src/main/scala/spark/streaming/RawInputDStream.scala) | 11
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala (renamed from streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala) | 6
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala | 27
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala (renamed from streaming/src/main/scala/spark/streaming/SocketInputDStream.scala) | 21
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala (renamed from streaming/src/main/scala/spark/streaming/StateDStream.scala) | 6
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala | 19
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala | 40
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala (renamed from streaming/src/main/scala/spark/streaming/WindowedDStream.scala) | 5
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/FileStream.scala | 46
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala | 75
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala | 33
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala | 49
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala | 25
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala | 43
-rw-r--r--  streaming/src/main/scala/spark/streaming/util/Clock.scala | 8
-rw-r--r--  streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala | 2
-rw-r--r--  streaming/src/test/resources/log4j.properties | 13
-rw-r--r--  streaming/src/test/scala/spark/streaming/CheckpointSuite.scala | 2
-rw-r--r--  streaming/src/test/scala/spark/streaming/FailureSuite.scala | 2
-rw-r--r--  streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala | 3
-rw-r--r--  streaming/src/test/scala/spark/streaming/TestSuiteBase.scala | 50
-rw-r--r--  streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala | 12
113 files changed, 2422 insertions, 1673 deletions
diff --git a/.gitignore b/.gitignore
index c207409e3c..88d7b56181 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,6 +12,7 @@ third_party/libmesos.so
third_party/libmesos.dylib
conf/java-opts
conf/spark-env.sh
+conf/streaming-env.sh
conf/log4j.properties
docs/_site
docs/api
@@ -31,4 +32,5 @@ project/plugins/src_managed/
logs/
log/
spark-tests.log
+streaming-tests.log
dependency-reduced-pom.xml
diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties
index 4c99e450bc..83d05cab2f 100644
--- a/bagel/src/test/resources/log4j.properties
+++ b/bagel/src/test/resources/log4j.properties
@@ -1,8 +1,8 @@
-# Set everything to be logged to the console
+# Set everything to be logged to the file bagel/target/unit-tests.log
log4j.rootCategory=INFO, file
log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
-log4j.appender.file.file=spark-tests.log
+log4j.appender.file.file=bagel/target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
diff --git a/conf/streaming-env.sh.template b/conf/streaming-env.sh.template
deleted file mode 100755
index 1ea9ba5541..0000000000
--- a/conf/streaming-env.sh.template
+++ /dev/null
@@ -1,22 +0,0 @@
-#!/usr/bin/env bash
-
-# This file contains a few additional setting that are useful for
-# running streaming jobs in Spark. Copy this file as streaming-env.sh .
-# Note that this shell script will be read after spark-env.sh, so settings
-# in this file may override similar settings (if present) in spark-env.sh .
-
-
-# Using concurrent GC is strongly recommended as it can significantly
-# reduce GC related pauses.
-
-SPARK_JAVA_OPTS+=" -XX:+UseConcMarkSweepGC"
-
-# Using Kryo serialization can improve serialization performance
-# and therefore the throughput of the Spark Streaming programs. However,
-# using Kryo serialization with custom classes may required you to
-# register the classes with Kryo. Refer to the Spark documentation
-# for more details.
-
-# SPARK_JAVA_OPTS+=" -Dspark.serializer=spark.KryoSerializer"
-
-export SPARK_JAVA_OPTS
diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala
index 05428aae1f..86ad737583 100644
--- a/core/src/main/scala/spark/CacheTracker.scala
+++ b/core/src/main/scala/spark/CacheTracker.scala
@@ -39,7 +39,7 @@ private[spark] class CacheTrackerActor extends Actor with Logging {
private val slaveCapacity = new HashMap[String, Long]
private val slaveUsage = new HashMap[String, Long]
- private val metadataCleaner = new MetadataCleaner("CacheTrackerActor", locs.cleanup)
+ private val metadataCleaner = new MetadataCleaner("CacheTrackerActor", locs.clearOldValues)
private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L)
private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L)
@@ -120,7 +120,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b
// Remembers which splits are currently being loaded (on worker nodes)
val loading = new HashSet[String]
- val metadataCleaner = new MetadataCleaner("CacheTracker", registeredRddIds.cleanup)
+ val metadataCleaner = new MetadataCleaner("CacheTracker", registeredRddIds.clearOldValues)
// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
@@ -210,26 +210,20 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b
loading.add(key)
}
}
- // If we got here, we have to load the split
- // Tell the master that we're doing so
- //val host = System.getProperty("spark.hostname", Utils.localHostName)
- //val future = trackerActor !! AddedToCache(rdd.id, split.index, host)
- // TODO: fetch any remote copy of the split that may be available
- // TODO: also register a listener for when it unloads
- logInfo("Computing partition " + split)
- val elements = new ArrayBuffer[Any]
- elements ++= rdd.compute(split, context)
try {
+ // If we got here, we have to load the split
+ val elements = new ArrayBuffer[Any]
+ logInfo("Computing partition " + split)
+ elements ++= rdd.compute(split, context)
// Try to put this block in the blockManager
blockManager.put(key, elements, storageLevel, true)
- //future.apply() // Wait for the reply from the cache tracker
+ return elements.iterator.asInstanceOf[Iterator[T]]
} finally {
loading.synchronized {
loading.remove(key)
loading.notifyAll()
}
}
- return elements.iterator.asInstanceOf[Iterator[T]]
}
}
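
The CacheTracker change above moves the partition computation inside try/finally, so a failed compute() still removes the key from the loading set and wakes any waiting threads. A simplified sketch of that guard pattern (illustrative only, not the actual CacheTracker code; the real code re-checks the block store after waiting):

// Serialize concurrent computations of the same key and guarantee cleanup on failure.
import scala.collection.mutable.HashSet

object ComputeOnceGuard {
  private val loading = new HashSet[String]

  def withLoadingGuard[T](key: String)(compute: => T): T = {
    loading.synchronized {
      while (loading.contains(key)) { loading.wait() }  // another thread is computing this key
      loading.add(key)
    }
    try {
      compute  // may throw; the finally block below still runs
    } finally {
      loading.synchronized {
        loading.remove(key)
        loading.notifyAll()  // wake threads waiting on this key
      }
    }
  }
}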
diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala
index 44b630e478..93d7327324 100644
--- a/core/src/main/scala/spark/KryoSerializer.scala
+++ b/core/src/main/scala/spark/KryoSerializer.scala
@@ -9,153 +9,80 @@ import scala.collection.mutable
import com.esotericsoftware.kryo._
import com.esotericsoftware.kryo.{Serializer => KSerializer}
-import com.esotericsoftware.kryo.serialize.ClassSerializer
-import com.esotericsoftware.kryo.serialize.SerializableSerializer
+import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
+import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer}
import de.javakaffee.kryoserializers.KryoReflectionFactorySupport
import serializer.{SerializerInstance, DeserializationStream, SerializationStream}
import spark.broadcast._
import spark.storage._
-/**
- * Zig-zag encoder used to write object sizes to serialization streams.
- * Based on Kryo's integer encoder.
- */
-private[spark] object ZigZag {
- def writeInt(n: Int, out: OutputStream) {
- var value = n
- if ((value & ~0x7F) == 0) {
- out.write(value)
- return
- }
- out.write(((value & 0x7F) | 0x80))
- value >>>= 7
- if ((value & ~0x7F) == 0) {
- out.write(value)
- return
- }
- out.write(((value & 0x7F) | 0x80))
- value >>>= 7
- if ((value & ~0x7F) == 0) {
- out.write(value)
- return
- }
- out.write(((value & 0x7F) | 0x80))
- value >>>= 7
- if ((value & ~0x7F) == 0) {
- out.write(value)
- return
- }
- out.write(((value & 0x7F) | 0x80))
- value >>>= 7
- out.write(value)
- }
+private[spark]
+class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream {
- def readInt(in: InputStream): Int = {
- var offset = 0
- var result = 0
- while (offset < 32) {
- val b = in.read()
- if (b == -1) {
- throw new EOFException("End of stream")
- }
- result |= ((b & 0x7F) << offset)
- if ((b & 0x80) == 0) {
- return result
- }
- offset += 7
- }
- throw new SparkException("Malformed zigzag-encoded integer")
- }
-}
-
-private[spark]
-class KryoSerializationStream(kryo: Kryo, threadBuffer: ByteBuffer, out: OutputStream)
-extends SerializationStream {
- val channel = Channels.newChannel(out)
+ val output = new KryoOutput(outStream)
def writeObject[T](t: T): SerializationStream = {
- kryo.writeClassAndObject(threadBuffer, t)
- ZigZag.writeInt(threadBuffer.position(), out)
- threadBuffer.flip()
- channel.write(threadBuffer)
- threadBuffer.clear()
+ kryo.writeClassAndObject(output, t)
this
}
- def flush() { out.flush() }
- def close() { out.close() }
+ def flush() { output.flush() }
+ def close() { output.close() }
}
-private[spark]
-class KryoDeserializationStream(objectBuffer: ObjectBuffer, in: InputStream)
-extends DeserializationStream {
+private[spark]
+class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream {
+
+ val input = new KryoInput(inStream)
+
def readObject[T](): T = {
- val len = ZigZag.readInt(in)
- objectBuffer.readClassAndObject(in, len).asInstanceOf[T]
+ try {
+ kryo.readClassAndObject(input).asInstanceOf[T]
+ } catch {
+ // DeserializationStream uses the EOF exception to indicate stopping condition.
+ case e: com.esotericsoftware.kryo.KryoException => throw new java.io.EOFException
+ }
}
- def close() { in.close() }
+ def close() {
+ // Kryo's Input automatically closes the input stream it is using.
+ input.close()
+ }
}
private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
- val kryo = ks.kryo
- val threadBuffer = ks.threadBuffer.get()
- val objectBuffer = ks.objectBuffer.get()
+
+ val kryo = ks.kryo.get()
+ val output = ks.output.get()
+ val input = ks.input.get()
def serialize[T](t: T): ByteBuffer = {
- // Write it to our thread-local scratch buffer first to figure out the size, then return a new
- // ByteBuffer of the appropriate size
- threadBuffer.clear()
- kryo.writeClassAndObject(threadBuffer, t)
- val newBuf = ByteBuffer.allocate(threadBuffer.position)
- threadBuffer.flip()
- newBuf.put(threadBuffer)
- newBuf.flip()
- newBuf
+ output.clear()
+ kryo.writeClassAndObject(output, t)
+ ByteBuffer.wrap(output.toBytes)
}
def deserialize[T](bytes: ByteBuffer): T = {
- kryo.readClassAndObject(bytes).asInstanceOf[T]
+ input.setBuffer(bytes.array)
+ kryo.readClassAndObject(input).asInstanceOf[T]
}
def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
val oldClassLoader = kryo.getClassLoader
kryo.setClassLoader(loader)
- val obj = kryo.readClassAndObject(bytes).asInstanceOf[T]
+ input.setBuffer(bytes.array)
+ val obj = kryo.readClassAndObject(input).asInstanceOf[T]
kryo.setClassLoader(oldClassLoader)
obj
}
def serializeStream(s: OutputStream): SerializationStream = {
- threadBuffer.clear()
- new KryoSerializationStream(kryo, threadBuffer, s)
+ new KryoSerializationStream(kryo, s)
}
def deserializeStream(s: InputStream): DeserializationStream = {
- new KryoDeserializationStream(objectBuffer, s)
- }
-
- override def serializeMany[T](iterator: Iterator[T]): ByteBuffer = {
- threadBuffer.clear()
- while (iterator.hasNext) {
- val element = iterator.next()
- // TODO: Do we also want to write the object's size? Doesn't seem necessary.
- kryo.writeClassAndObject(threadBuffer, element)
- }
- val newBuf = ByteBuffer.allocate(threadBuffer.position)
- threadBuffer.flip()
- newBuf.put(threadBuffer)
- newBuf.flip()
- newBuf
- }
-
- override def deserializeMany(buffer: ByteBuffer): Iterator[Any] = {
- buffer.rewind()
- new Iterator[Any] {
- override def hasNext: Boolean = buffer.remaining > 0
- override def next(): Any = kryo.readClassAndObject(buffer)
- }
+ new KryoDeserializationStream(kryo, s)
}
}
@@ -171,18 +98,19 @@ trait KryoRegistrator {
* A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]].
*/
class KryoSerializer extends spark.serializer.Serializer with Logging {
- // Make this lazy so that it only gets called once we receive our first task on each executor,
- // so we can pull out any custom Kryo registrator from the user's JARs.
- lazy val kryo = createKryo()
- val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "32").toInt * 1024 * 1024
+ val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
- val objectBuffer = new ThreadLocal[ObjectBuffer] {
- override def initialValue = new ObjectBuffer(kryo, bufferSize)
+ val kryo = new ThreadLocal[Kryo] {
+ override def initialValue = createKryo()
}
- val threadBuffer = new ThreadLocal[ByteBuffer] {
- override def initialValue = ByteBuffer.allocate(bufferSize)
+ val output = new ThreadLocal[KryoOutput] {
+ override def initialValue = new KryoOutput(bufferSize)
+ }
+
+ val input = new ThreadLocal[KryoInput] {
+ override def initialValue = new KryoInput(bufferSize)
}
def createKryo(): Kryo = {
@@ -213,41 +141,44 @@ class KryoSerializer extends spark.serializer.Serializer with Logging {
kryo.register(obj.getClass)
}
- // Register the following classes for passing closures.
- kryo.register(classOf[Class[_]], new ClassSerializer(kryo))
- kryo.setRegistrationOptional(true)
-
// Allow sending SerializableWritable
- kryo.register(classOf[SerializableWritable[_]], new SerializableSerializer())
- kryo.register(classOf[HttpBroadcast[_]], new SerializableSerializer())
+ kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer())
+ kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer())
// Register some commonly used Scala singleton objects. Because these
// are singletons, we must return the exact same local object when we
// deserialize rather than returning a clone as FieldSerializer would.
- class SingletonSerializer(obj: AnyRef) extends KSerializer {
- override def writeObjectData(buf: ByteBuffer, obj: AnyRef) {}
- override def readObjectData[T](buf: ByteBuffer, cls: Class[T]): T = obj.asInstanceOf[T]
+ class SingletonSerializer[T](obj: T) extends KSerializer[T] {
+ override def write(kryo: Kryo, output: KryoOutput, obj: T) {}
+ override def read(kryo: Kryo, input: KryoInput, cls: java.lang.Class[T]): T = obj
}
- kryo.register(None.getClass, new SingletonSerializer(None))
- kryo.register(Nil.getClass, new SingletonSerializer(Nil))
+ kryo.register(None.getClass, new SingletonSerializer[AnyRef](None))
+ kryo.register(Nil.getClass, new SingletonSerializer[AnyRef](Nil))
// Register maps with a special serializer since they have complex internal structure
class ScalaMapSerializer(buildMap: Array[(Any, Any)] => scala.collection.Map[Any, Any])
- extends KSerializer {
- override def writeObjectData(buf: ByteBuffer, obj: AnyRef) {
+ extends KSerializer[Array[(Any, Any)] => scala.collection.Map[Any, Any]] {
+ override def write(
+ kryo: Kryo,
+ output: KryoOutput,
+ obj: Array[(Any, Any)] => scala.collection.Map[Any, Any]) {
val map = obj.asInstanceOf[scala.collection.Map[Any, Any]]
- kryo.writeObject(buf, map.size.asInstanceOf[java.lang.Integer])
+ kryo.writeObject(output, map.size.asInstanceOf[java.lang.Integer])
for ((k, v) <- map) {
- kryo.writeClassAndObject(buf, k)
- kryo.writeClassAndObject(buf, v)
+ kryo.writeClassAndObject(output, k)
+ kryo.writeClassAndObject(output, v)
}
}
- override def readObjectData[T](buf: ByteBuffer, cls: Class[T]): T = {
- val size = kryo.readObject(buf, classOf[java.lang.Integer]).intValue
+ override def read (
+ kryo: Kryo,
+ input: KryoInput,
+ cls: Class[Array[(Any, Any)] => scala.collection.Map[Any, Any]])
+ : Array[(Any, Any)] => scala.collection.Map[Any, Any] = {
+ val size = kryo.readObject(input, classOf[java.lang.Integer]).intValue
val elems = new Array[(Any, Any)](size)
for (i <- 0 until size)
- elems(i) = (kryo.readClassAndObject(buf), kryo.readClassAndObject(buf))
- buildMap(elems).asInstanceOf[T]
+ elems(i) = (kryo.readClassAndObject(input), kryo.readClassAndObject(input))
+ buildMap(elems).asInstanceOf[Array[(Any, Any)] => scala.collection.Map[Any, Any]]
}
}
kryo.register(mutable.HashMap().getClass, new ScalaMapSerializer(mutable.HashMap() ++ _))
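
The KryoSerializer changes above move to the Kryo 2.x Input/Output API, thread-local Kryo instances, and a smaller default buffer (spark.kryoserializer.buffer.mb defaults to 2). A minimal sketch of plugging it in from a job, assuming the registerClasses(kryo) hook on spark.KryoRegistrator and the spark.kryo.registrator property (hypothetical user code, not part of this patch):

// Hypothetical user code: enable Kryo serialization and register application classes.
import com.esotericsoftware.kryo.Kryo

class MyRegistrator extends spark.KryoRegistrator {
  override def registerClasses(kryo: Kryo) {
    kryo.register(classOf[Array[Double]])  // register classes that are shipped frequently
  }
}

object KryoSetup {
  def main(args: Array[String]) {
    System.setProperty("spark.serializer", "spark.KryoSerializer")
    System.setProperty("spark.kryo.registrator", "MyRegistrator")
    System.setProperty("spark.kryoserializer.buffer.mb", "2")  // per-thread buffer size in MB
    // ... then create the SparkContext and run jobs as usual ...
  }
}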
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index 5ebdba0fc8..a2fa2d1ea7 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -178,8 +178,8 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
}
def cleanup(cleanupTime: Long) {
- mapStatuses.cleanup(cleanupTime)
- cachedSerializedStatuses.cleanup(cleanupTime)
+ mapStatuses.clearOldValues(cleanupTime)
+ cachedSerializedStatuses.clearOldValues(cleanupTime)
}
def stop() {
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index ec6b209932..07ae2d647c 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -23,7 +23,6 @@ import spark.partial.BoundedDouble
import spark.partial.PartialResult
import spark.rdd._
import spark.SparkContext._
-import java.lang.ref.WeakReference
/**
* Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
@@ -53,6 +52,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
mergeCombiners: (C, C) => C,
partitioner: Partitioner,
mapSideCombine: Boolean = true): RDD[(K, C)] = {
+ if (getKeyClass().isArray) {
+ if (mapSideCombine) {
+ throw new SparkException("Cannot use map-side combining with array keys.")
+ }
+ if (partitioner.isInstanceOf[HashPartitioner]) {
+ throw new SparkException("Default partitioner cannot partition array keys.")
+ }
+ }
val aggregator =
new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (mapSideCombine) {
@@ -93,6 +100,11 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* before sending results to a reducer, similarly to a "combiner" in MapReduce.
*/
def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = {
+
+ if (getKeyClass().isArray) {
+ throw new SparkException("reduceByKeyLocally() does not support array keys")
+ }
+
def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = {
val map = new JHashMap[K, V]
for ((k, v) <- iter) {
@@ -166,6 +178,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* be set to true.
*/
def partitionBy(partitioner: Partitioner, mapSideCombine: Boolean = false): RDD[(K, V)] = {
+ if (getKeyClass().isArray) {
+ if (mapSideCombine) {
+ throw new SparkException("Cannot use map-side combining with array keys.")
+ }
+ if (partitioner.isInstanceOf[HashPartitioner]) {
+ throw new SparkException("Default partitioner cannot partition array keys.")
+ }
+ }
if (mapSideCombine) {
def createCombiner(v: V) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
@@ -337,6 +357,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* list of values for that key in `this` as well as `other`.
*/
def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (Seq[V], Seq[W]))] = {
+ if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) {
+ throw new SparkException("Default partitioner cannot partition array keys.")
+ }
val cg = new CoGroupedRDD[K](
Seq(self.asInstanceOf[RDD[(_, _)]], other.asInstanceOf[RDD[(_, _)]]),
partitioner)
@@ -353,6 +376,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
*/
def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner)
: RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
+ if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) {
+ throw new SparkException("Default partitioner cannot partition array keys.")
+ }
val cg = new CoGroupedRDD[K](
Seq(self.asInstanceOf[RDD[(_, _)]],
other1.asInstanceOf[RDD[(_, _)]],
@@ -439,7 +465,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
val res = self.context.runJob(self, process _, Array(index), false)
res(0)
case None =>
- throw new UnsupportedOperationException("lookup() called on an RDD without a partitioner")
+ self.filter(_._1 == key).map(_._2).collect
}
}
@@ -625,20 +651,20 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
}
private[spark]
-class MappedValuesRDD[K, V, U](prev: WeakReference[RDD[(K, V)]], f: V => U)
- extends RDD[(K, U)](prev.get) {
+class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U)
+ extends RDD[(K, U)](prev) {
- override def splits = firstParent[(K, V)].splits
+ override def getSplits = firstParent[(K, V)].splits
override val partitioner = firstParent[(K, V)].partitioner
override def compute(split: Split, context: TaskContext) =
firstParent[(K, V)].iterator(split, context).map{ case (k, v) => (k, f(v)) }
}
private[spark]
-class FlatMappedValuesRDD[K, V, U](prev: WeakReference[RDD[(K, V)]], f: V => TraversableOnce[U])
- extends RDD[(K, U)](prev.get) {
+class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U])
+ extends RDD[(K, U)](prev) {
- override def splits = firstParent[(K, V)].splits
+ override def getSplits = firstParent[(K, V)].splits
override val partitioner = firstParent[(K, V)].partitioner
override def compute(split: Split, context: TaskContext) = {
firstParent[(K, V)].iterator(split, context).flatMap { case (k, v) => f(v).map(x => (k, x)) }
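
One behavioral change above: lookup() on an RDD without a partitioner no longer throws UnsupportedOperationException and instead falls back to filtering the whole RDD. A short sketch of what a caller sees (hypothetical example; assumes a local-mode SparkContext):

// Hypothetical usage of lookup(): with a partitioner only the owning partition is scanned;
// without one, the filter-based fallback added above scans the whole RDD.
import spark.SparkContext
import spark.SparkContext._

object LookupExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "LookupExample")
    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))
    println(pairs.lookup("a"))  // the values 1 and 3, found via filter(_._1 == key)
    sc.stop()
  }
}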
diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala
index 68416a78d0..ede933c9e9 100644
--- a/core/src/main/scala/spark/ParallelCollection.scala
+++ b/core/src/main/scala/spark/ParallelCollection.scala
@@ -30,26 +30,30 @@ private[spark] class ParallelCollection[T: ClassManifest](
extends RDD[T](sc, Nil) {
// TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
// cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
- // instead. UPDATE: With the new changes to enable checkpointing, this an be done.
+ // instead.
+ // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal.
@transient
- val splits_ = {
+ var splits_ : Array[Split] = {
val slices = ParallelCollection.slice(data, numSlices).toArray
slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray
}
- override def splits = splits_.asInstanceOf[Array[Split]]
-
+ override def getSplits = splits_.asInstanceOf[Array[Split]]
override def compute(s: Split, context: TaskContext) =
s.asInstanceOf[ParallelCollectionSplit[T]].iterator
- override def preferredLocations(s: Split): Seq[String] = {
- locationPrefs.get(splits_.indexOf(s)) match {
+ override def getPreferredLocations(s: Split): Seq[String] = {
+ locationPrefs.get(s.index) match {
case Some(s) => s
case _ => Nil
}
}
+
+ override def clearDependencies() {
+ splits_ = null
+ }
}
diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala
index b71021a082..9d5b966e1e 100644
--- a/core/src/main/scala/spark/Partitioner.scala
+++ b/core/src/main/scala/spark/Partitioner.scala
@@ -11,6 +11,10 @@ abstract class Partitioner extends Serializable {
/**
* A [[spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`.
+ *
+ * Java arrays have hashCodes that are based on the arrays' identities rather than their contents,
+ * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will
+ * produce an unexpected or incorrect result.
*/
class HashPartitioner(partitions: Int) extends Partitioner {
def numPartitions = partitions
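
The caveat documented above follows from plain JVM semantics rather than anything Spark-specific; a standalone illustration (not part of the patch):

// Arrays hash by identity, not by contents, so equal array keys would land in different
// partitions under HashPartitioner, which is why PairRDDFunctions now rejects array keys.
object ArrayKeyHashing {
  def main(args: Array[String]) {
    val a = Array(1, 2, 3)
    val b = Array(1, 2, 3)
    println(a.sameElements(b))         // true: identical contents
    println(a.hashCode == b.hashCode)  // almost always false: identity-based hash codes
  }
}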
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index cf3ed06773..0de6f04d50 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -81,78 +81,33 @@ abstract class RDD[T: ClassManifest](
def this(@transient oneParent: RDD[_]) =
this(oneParent.context , List(new OneToOneDependency(oneParent)))
- // Methods that must be implemented by subclasses:
-
- /** Set of partitions in this RDD. */
- def splits: Array[Split]
+ // =======================================================================
+ // Methods that should be implemented by subclasses of RDD
+ // =======================================================================
/** Function for computing a given partition. */
def compute(split: Split, context: TaskContext): Iterator[T]
+ /** Set of partitions in this RDD. */
+ protected def getSplits(): Array[Split]
+
/** How this RDD depends on any parent RDDs. */
- def dependencies: List[Dependency[_]] = dependencies_
+ protected def getDependencies(): List[Dependency[_]] = dependencies_
- /** Record user function generating this RDD. */
- private[spark] val origin = Utils.getSparkCallSite
+ /** Optionally overridden by subclasses to specify placement preferences. */
+ protected def getPreferredLocations(split: Split): Seq[String] = Nil
/** Optionally overridden by subclasses to specify how they are partitioned. */
val partitioner: Option[Partitioner] = None
- /** Optionally overridden by subclasses to specify placement preferences. */
- def preferredLocations(split: Split): Seq[String] = {
- if (isCheckpointed) {
- checkpointRDD.preferredLocations(split)
- } else {
- Nil
- }
- }
- /** The [[spark.SparkContext]] that this RDD was created on. */
- def context = sc
-
- private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
+ // =======================================================================
+ // Methods and fields available on all RDDs
+ // =======================================================================
/** A unique ID for this RDD (within its SparkContext). */
val id = sc.newRddId()
- // Variables relating to persistence
- private var storageLevel: StorageLevel = StorageLevel.NONE
-
- /** Returns the first parent RDD */
- protected[spark] def firstParent[U: ClassManifest] = dependencies.head.rdd.asInstanceOf[RDD[U]]
-
- /** Returns the `i` th parent RDD */
- protected[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]]
-
- //////////////////////////////////////////////////////////////////////////////
- // Checkpointing related variables
- //////////////////////////////////////////////////////////////////////////////
-
- // override to set this to false to avoid checkpointing an RDD
- protected val isCheckpointable = true
-
- // set to true when an RDD is marked for checkpointing
- protected var shouldCheckpoint = false
-
- // set to true when checkpointing is in progress
- protected var isCheckpointInProgress = false
-
- // set to true after checkpointing is completed
- protected[spark] var isCheckpointed = false
-
- // set to the checkpoint file after checkpointing is completed
- protected[spark] var checkpointFile: String = null
-
- // set to the HadoopRDD of the checkpoint file
- protected var checkpointRDD: RDD[T] = null
-
- // set to the splits of the Hadoop RDD
- protected var checkpointRDDSplits: Seq[Split] = null
-
- //////////////////////////////////////////////////////////////////////////////
- // Methods available on all RDDs
- //////////////////////////////////////////////////////////////////////////////
-
/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
@@ -177,81 +132,39 @@ abstract class RDD[T: ClassManifest](
def getStorageLevel = storageLevel
/**
- * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir`
- * (set using setCheckpointDir()) and all references to its parent RDDs will be removed.
- * This is used to truncate very long lineages. In the current implementation, Spark will save
- * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done.
- * Hence, it is strongly recommended to use checkpoint() on RDDs when
- * (i) Checkpoint() is called before the any job has been executed on this RDD.
- * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will
- * require recomputation.
+ * Get the preferred location of a split, taking into account whether the
+ * RDD is checkpointed or not.
*/
- protected[spark] def checkpoint() {
- synchronized {
- if (isCheckpointed || shouldCheckpoint || isCheckpointInProgress) {
- // do nothing
- } else if (isCheckpointable) {
- if (sc.checkpointDir == null) {
- throw new Exception("Checkpoint directory has not been set in the SparkContext.")
- }
- shouldCheckpoint = true
- } else {
- throw new Exception(this + " cannot be checkpointed")
- }
- }
- }
-
- def getCheckpointData(): Any = {
- synchronized {
- checkpointFile
+ final def preferredLocations(split: Split): Seq[String] = {
+ if (isCheckpointed) {
+ checkpointData.get.getPreferredLocations(split)
+ } else {
+ getPreferredLocations(split)
}
}
/**
- * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler after
- * a job using this RDD has completed (therefore the RDD has been materialized and potentially
- * stored in memory). In case this RDD is not marked for checkpointing, doCheckpoint() is called
- * recursively on the parent RDDs.
+ * Get the array of splits of this RDD, taking into account whether the
+ * RDD is checkpointed or not.
*/
- private[spark] def doCheckpoint() {
- val startCheckpoint = synchronized {
- if (isCheckpointable && shouldCheckpoint && !isCheckpointInProgress) {
- isCheckpointInProgress = true
- true
- } else {
- false
- }
- }
-
- if (startCheckpoint) {
- val rdd = this
- rdd.checkpointFile = new Path(context.checkpointDir, "rdd-" + id).toString
- rdd.saveAsObjectFile(checkpointFile)
- rdd.synchronized {
- rdd.checkpointRDD = context.objectFile[T](checkpointFile, rdd.splits.size)
- rdd.checkpointRDDSplits = rdd.checkpointRDD.splits
- rdd.changeDependencies(rdd.checkpointRDD)
- rdd.shouldCheckpoint = false
- rdd.isCheckpointInProgress = false
- rdd.isCheckpointed = true
- logInfo("Done checkpointing RDD " + rdd.id + ", " + rdd + ", created RDD " +
- rdd.checkpointRDD.id + ", " + rdd.checkpointRDD)
- }
+ final def splits: Array[Split] = {
+ if (isCheckpointed) {
+ checkpointData.get.getSplits
} else {
- // Recursively call doCheckpoint() to perform checkpointing on parent RDD if they are marked
- dependencies.foreach(_.rdd.doCheckpoint())
+ getSplits
}
}
/**
- * Changes the dependencies of this RDD from its original parents to the new
- * [[spark.rdd.HadoopRDD]] (`newRDD`) created from the checkpoint file. This method must ensure
- * that all references to the original parent RDDs must be removed to enable the parent RDDs to
- * be garbage collected. Subclasses of RDD may override this method for implementing their own
- * changing logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea.
+ * Get the list of dependencies of this RDD, taking into account whether the
+ * RDD is checkpointed or not.
*/
- protected def changeDependencies(newRDD: RDD[_]) {
- dependencies_ = List(new OneToOneDependency(newRDD))
+ final def dependencies: List[Dependency[_]] = {
+ if (isCheckpointed) {
+ dependencies_
+ } else {
+ getDependencies
+ }
}
/**
@@ -261,8 +174,7 @@ abstract class RDD[T: ClassManifest](
*/
final def iterator(split: Split, context: TaskContext): Iterator[T] = {
if (isCheckpointed) {
- // ASSUMPTION: Checkpoint Hadoop RDD will have same number of splits as original
- checkpointRDD.iterator(checkpointRDDSplits(split.index), context)
+ checkpointData.get.iterator(split, context)
} else if (storageLevel != StorageLevel.NONE) {
SparkEnv.get.cacheTracker.getOrCompute[T](this, split, context, storageLevel)
} else {
@@ -292,9 +204,11 @@ abstract class RDD[T: ClassManifest](
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
- def distinct(numSplits: Int = splits.size): RDD[T] =
+ def distinct(numSplits: Int): RDD[T] =
map(x => (x, null)).reduceByKey((x, y) => x, numSplits).map(_._1)
+ def distinct(): RDD[T] = distinct(splits.size)
+
/**
* Return a sampled subset of this RDD.
*/
@@ -522,6 +436,9 @@ abstract class RDD[T: ClassManifest](
* combine step happens locally on the master, equivalent to running a single reduce task.
*/
def countByValue(): Map[T, Long] = {
+ if (elementClassManifest.erasure.isArray) {
+ throw new SparkException("countByValue() does not support arrays")
+ }
// TODO: This should perhaps be distributed by default.
def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = {
val map = new OLMap[T]
@@ -550,6 +467,9 @@ abstract class RDD[T: ClassManifest](
timeout: Long,
confidence: Double = 0.95
): PartialResult[Map[T, BoundedDouble]] = {
+ if (elementClassManifest.erasure.isArray) {
+ throw new SparkException("countByValueApprox() does not support arrays")
+ }
val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) =>
val map = new OLMap[T]
while (iter.hasNext) {
@@ -614,18 +534,84 @@ abstract class RDD[T: ClassManifest](
sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
}
- @throws(classOf[IOException])
- private def writeObject(oos: ObjectOutputStream) {
- synchronized {
- oos.defaultWriteObject()
+ /**
+ * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir`
+ * (set using setCheckpointDir()) and all references to its parent RDDs will be removed.
+ * This is used to truncate very long lineages. In the current implementation, Spark will save
+ * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done.
+ * Hence, it is strongly recommended to use checkpoint() on RDDs when
+ * (i) checkpoint() is called before any job has been executed on this RDD.
+ * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will
+ * require recomputation.
+ */
+ def checkpoint() {
+ if (checkpointData.isEmpty) {
+ checkpointData = Some(new RDDCheckpointData(this))
+ checkpointData.get.markForCheckpoint()
}
}
- @throws(classOf[IOException])
- private def readObject(ois: ObjectInputStream) {
- synchronized {
- ois.defaultReadObject()
- }
+ /**
+ * Return whether this RDD has been checkpointed or not
+ */
+ def isCheckpointed(): Boolean = {
+ if (checkpointData.isDefined) checkpointData.get.isCheckpointed() else false
+ }
+
+ /**
+ * Gets the name of the file to which this RDD was checkpointed
+ */
+ def getCheckpointFile(): Option[String] = {
+ if (checkpointData.isDefined) checkpointData.get.getCheckpointFile() else None
}
+ // =======================================================================
+ // Other internal methods and fields
+ // =======================================================================
+
+ private var storageLevel: StorageLevel = StorageLevel.NONE
+
+ /** Record user function generating this RDD. */
+ private[spark] val origin = Utils.getSparkCallSite
+
+ private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
+
+ private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None
+
+ /** Returns the first parent RDD */
+ protected[spark] def firstParent[U: ClassManifest] = {
+ dependencies.head.rdd.asInstanceOf[RDD[U]]
+ }
+
+ /** The [[spark.SparkContext]] that this RDD was created on. */
+ def context = sc
+
+ /**
+ * Performs the checkpointing of this RDD by saving it to a file. It is called by the DAGScheduler
+ * after a job using this RDD has completed (therefore the RDD has been materialized and
+ * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs.
+ */
+ protected[spark] def doCheckpoint() {
+ if (checkpointData.isDefined) checkpointData.get.doCheckpoint()
+ dependencies.foreach(_.rdd.doCheckpoint())
+ }
+
+ /**
+ * Changes the dependencies of this RDD from its original parents to the new RDD
+ * (`newRDD`) created from the checkpoint file.
+ */
+ protected[spark] def changeDependencies(newRDD: RDD[_]) {
+ clearDependencies()
+ dependencies_ = List(new OneToOneDependency(newRDD))
+ }
+
+ /**
+ * Clears the dependencies of this RDD. This method must ensure that all references
+ * to the original parent RDDs are removed to enable the parent RDDs to be garbage
+ * collected. Subclasses of RDD may override this method for implementing their own cleaning
+ * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea.
+ */
+ protected[spark] def clearDependencies() {
+ dependencies_ = null
+ }
}
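
The checkpoint()/isCheckpointed()/getCheckpointFile() API introduced above is backed by RDDCheckpointData (next file). A minimal driver sketch, assuming a local-mode SparkContext and a scratch checkpoint directory (both hypothetical):

// Hypothetical driver program exercising the checkpoint API described above.
object CheckpointExample {
  def main(args: Array[String]) {
    val sc = new spark.SparkContext("local", "CheckpointExample")
    sc.setCheckpointDir("/tmp/spark-checkpoints")  // assumed scratch directory
    val doubled = sc.parallelize(1 to 1000).map(_ * 2).cache()
    doubled.checkpoint()                  // mark before any job runs, as the docs recommend
    doubled.count()                       // first job; the scheduler then calls doCheckpoint()
    println(doubled.isCheckpointed())     // true once the save completes
    println(doubled.getCheckpointFile())  // Some(path under the checkpoint directory)
    sc.stop()
  }
}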
diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala
new file mode 100644
index 0000000000..d845a522e4
--- /dev/null
+++ b/core/src/main/scala/spark/RDDCheckpointData.scala
@@ -0,0 +1,105 @@
+package spark
+
+import org.apache.hadoop.fs.Path
+import rdd.{CheckpointRDD, CoalescedRDD}
+import scheduler.{ResultTask, ShuffleMapTask}
+
+/**
+ * Enumeration to manage state transitions of an RDD through checkpointing
+ * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ]
+ */
+private[spark] object CheckpointState extends Enumeration {
+ type CheckpointState = Value
+ val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value
+}
+
+/**
+ * This class contains all the information related to RDD checkpointing. Each instance of this class
+ * is associated with an RDD. It manages the checkpointing process of the associated RDD, and
+ * manages the post-checkpoint state by providing the updated splits, iterator and preferred locations
+ * of the checkpointed RDD.
+ */
+private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
+extends Logging with Serializable {
+
+ import CheckpointState._
+
+ // The checkpoint state of the associated RDD.
+ var cpState = Initialized
+
+ // The file to which the associated RDD has been checkpointed
+ @transient var cpFile: Option[String] = None
+
+ // The CheckpointRDD created from the checkpoint file, that is, the new parent of the associated RDD.
+ @transient var cpRDD: Option[RDD[T]] = None
+
+ // Mark the RDD for checkpointing
+ def markForCheckpoint() {
+ RDDCheckpointData.synchronized {
+ if (cpState == Initialized) cpState = MarkedForCheckpoint
+ }
+ }
+
+ // Is the RDD already checkpointed
+ def isCheckpointed(): Boolean = {
+ RDDCheckpointData.synchronized { cpState == Checkpointed }
+ }
+
+ // Get the file to which this RDD was checkpointed, as an Option
+ def getCheckpointFile(): Option[String] = {
+ RDDCheckpointData.synchronized { cpFile }
+ }
+
+ // Do the checkpointing of the RDD. Called after the first job using that RDD is over.
+ def doCheckpoint() {
+ // If it is marked for checkpointing AND checkpointing is not already in progress,
+ // then set it to be in progress, else return
+ RDDCheckpointData.synchronized {
+ if (cpState == MarkedForCheckpoint) {
+ cpState = CheckpointingInProgress
+ } else {
+ return
+ }
+ }
+
+ // Save to file, and reload it as an RDD
+ val path = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString
+ rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _)
+ val newRDD = new CheckpointRDD[T](rdd.context, path)
+
+ // Change the dependencies and splits of the RDD
+ RDDCheckpointData.synchronized {
+ cpFile = Some(path)
+ cpRDD = Some(newRDD)
+ rdd.changeDependencies(newRDD)
+ cpState = Checkpointed
+ RDDCheckpointData.clearTaskCaches()
+ logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id)
+ }
+ }
+
+ // Get preferred location of a split after checkpointing
+ def getPreferredLocations(split: Split) = {
+ RDDCheckpointData.synchronized {
+ cpRDD.get.preferredLocations(split)
+ }
+ }
+
+ def getSplits: Array[Split] = {
+ RDDCheckpointData.synchronized {
+ cpRDD.get.splits
+ }
+ }
+
+ // Get iterator. This is called at the worker nodes.
+ def iterator(split: Split, context: TaskContext): Iterator[T] = {
+ rdd.firstParent[T].iterator(split, context)
+ }
+}
+
+private[spark] object RDDCheckpointData {
+ def clearTaskCaches() {
+ ShuffleMapTask.clearCache()
+ ResultTask.clearCache()
+ }
+}
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 70257193cf..7c5cbc8870 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -37,12 +37,8 @@ import spark.broadcast._
import spark.deploy.LocalSparkCluster
import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
-import spark.rdd.HadoopRDD
-import spark.rdd.NewHadoopRDD
-import spark.rdd.UnionRDD
-import spark.scheduler.ShuffleMapTask
-import spark.scheduler.DAGScheduler
-import spark.scheduler.TaskScheduler
+import rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD}
+import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler}
import spark.scheduler.local.LocalScheduler
import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
@@ -376,6 +372,13 @@ class SparkContext(
.flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes))
}
+
+ protected[spark] def checkpointFile[T: ClassManifest](
+ path: String
+ ): RDD[T] = {
+ new CheckpointRDD[T](this, path)
+ }
+
/** Build the union of a list of RDDs. */
def union[T: ClassManifest](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds)
@@ -430,8 +433,9 @@ class SparkContext(
}
addedFiles(key) = System.currentTimeMillis
- // Fetch the file locally in case the task is executed locally
- val filename = new File(path.split("/").last)
+ // Fetch the file locally in case a job is executed locally.
+ // Jobs that run through LocalScheduler will already fetch the required dependencies,
+ // but jobs run in DAGScheduler.runLocally() will not, so we must fetch the files here.
Utils.fetchFile(path, new File("."))
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
@@ -448,11 +452,10 @@ class SparkContext(
}
/**
- * Clear the job's list of files added by `addFile` so that they do not get donwloaded to
+ * Clear the job's list of files added by `addFile` so that they do not get downloaded to
* any new nodes.
*/
def clearFiles() {
- addedFiles.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
addedFiles.clear()
}
@@ -476,7 +479,6 @@ class SparkContext(
* any new nodes.
*/
def clearJars() {
- addedJars.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
addedJars.clear()
}
@@ -485,16 +487,19 @@ class SparkContext(
if (dagScheduler != null) {
dagScheduler.stop()
dagScheduler = null
+ taskScheduler = null
+ // TODO: Cache.stop()?
+ env.stop()
+ // Clean up locally linked files
+ clearFiles()
+ clearJars()
+ SparkEnv.set(null)
+ ShuffleMapTask.clearCache()
+ ResultTask.clearCache()
+ logInfo("Successfully stopped SparkContext")
+ } else {
+ logInfo("SparkContext already stopped")
}
- taskScheduler = null
- // TODO: Cache.stop()?
- env.stop()
- // Clean up locally linked files
- clearFiles()
- clearJars()
- SparkEnv.set(null)
- ShuffleMapTask.clearCache()
- logInfo("Successfully stopped SparkContext")
}
/**
@@ -629,10 +634,6 @@ class SparkContext(
*/
object SparkContext {
- // TODO: temporary hack for using HDFS as input in streaing
- var inputFile: String = null
- var idealPartitions: Int = 1
-
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
def zero(initialValue: Double) = 0.0
@@ -728,9 +729,6 @@ object SparkContext {
/** Find the JAR that contains the class of a particular object */
def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass)
-
- implicit def rddToWeakRefRDD[T: ClassManifest](rdd: RDD[T]) = new WeakReference(rdd)
-
}
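
Note that clearFiles() and clearJars() above no longer delete the locally fetched copies; they only stop the files from being shipped to new nodes, and the stop() cleanup now runs only if the context was not already stopped. A hedged sketch of the file-distribution API (paths are hypothetical):

// Hypothetical driver code using the addFile/addJar API touched above.
object AddFileExample {
  def main(args: Array[String]) {
    val sc = new spark.SparkContext("local", "AddFileExample")
    sc.addFile("hdfs://namenode:9000/data/lookup.txt")  // assumed HDFS path, fetched on every node
    sc.addJar("target/extra-deps.jar")                  // assumed local jar path
    // ... run jobs that read lookup.txt from the task working directory ...
    sc.clearFiles()  // stop shipping to new nodes; local copies are left in place
    sc.clearJars()
    sc.stop()
  }
}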
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 63e17ee4e3..d08921b25f 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -9,6 +9,7 @@ import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
import scala.io.Source
+import com.google.common.io.Files
/**
* Various utility methods used by Spark.
@@ -127,31 +128,53 @@ private object Utils extends Logging {
/**
* Download a file requested by the executor. Supports fetching the file in a variety of ways,
* including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
+ *
+ * Throws SparkException if the target file already exists and has different contents than
+ * the requested file.
*/
def fetchFile(url: String, targetDir: File) {
val filename = url.split("/").last
+ val tempDir = System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))
+ val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir))
val targetFile = new File(targetDir, filename)
val uri = new URI(url)
uri.getScheme match {
case "http" | "https" | "ftp" =>
- logInfo("Fetching " + url + " to " + targetFile)
+ logInfo("Fetching " + url + " to " + tempFile)
val in = new URL(url).openStream()
- val out = new FileOutputStream(targetFile)
+ val out = new FileOutputStream(tempFile)
Utils.copyStream(in, out, true)
+ if (targetFile.exists && !Files.equal(tempFile, targetFile)) {
+ tempFile.delete()
+ throw new SparkException("File " + targetFile + " exists and does not match contents of" +
+ " " + url)
+ } else {
+ Files.move(tempFile, targetFile)
+ }
case "file" | null =>
- // Remove the file if it already exists
- targetFile.delete()
- // Symlink the file locally.
- if (uri.isAbsolute) {
- // url is absolute, i.e. it starts with "file:///". Extract the source
- // file's absolute path from the url.
- val sourceFile = new File(uri)
- logInfo("Symlinking " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath)
- FileUtil.symLink(sourceFile.getAbsolutePath, targetFile.getAbsolutePath)
+ val sourceFile = if (uri.isAbsolute) {
+ new File(uri)
+ } else {
+ new File(url)
+ }
+ if (targetFile.exists && !Files.equal(sourceFile, targetFile)) {
+ throw new SparkException("File " + targetFile + " exists and does not match contents of" +
+ " " + url)
} else {
- // url is not absolute, i.e. itself is the path to the source file.
- logInfo("Symlinking " + url + " to " + targetFile.getAbsolutePath)
- FileUtil.symLink(url, targetFile.getAbsolutePath)
+ // Remove the file if it already exists
+ targetFile.delete()
+ // Symlink the file locally.
+ if (uri.isAbsolute) {
+ // url is absolute, i.e. it starts with "file:///". Extract the source
+ // file's absolute path from the url.
+ val sourceFile = new File(uri)
+ logInfo("Symlinking " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath)
+ FileUtil.symLink(sourceFile.getAbsolutePath, targetFile.getAbsolutePath)
+ } else {
+ // url is not absolute, i.e. itself is the path to the source file.
+ logInfo("Symlinking " + url + " to " + targetFile.getAbsolutePath)
+ FileUtil.symLink(url, targetFile.getAbsolutePath)
+ }
}
case _ =>
// Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
@@ -159,8 +182,15 @@ private object Utils extends Logging {
val conf = new Configuration()
val fs = FileSystem.get(uri, conf)
val in = fs.open(new Path(uri))
- val out = new FileOutputStream(targetFile)
+ val out = new FileOutputStream(tempFile)
Utils.copyStream(in, out, true)
+ if (targetFile.exists && !Files.equal(tempFile, targetFile)) {
+ tempFile.delete()
+ throw new SparkException("File " + targetFile + " exists and does not match contents of" +
+ " " + url)
+ } else {
+ Files.move(tempFile, targetFile)
+ }
}
// Decompress the file if it's a .tar or .tar.gz
if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
index edbb187b1b..b7725313c4 100644
--- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
@@ -301,6 +301,40 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* (in that order of preference). If neither of these is set, return None.
*/
def getSparkHome(): Option[String] = sc.getSparkHome()
+
+ /**
+ * Add a file to be downloaded into the working directory of this Spark job on every node.
+ * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
+ * filesystems), or an HTTP, HTTPS or FTP URI.
+ */
+ def addFile(path: String) {
+ sc.addFile(path)
+ }
+
+ /**
+ * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
+ * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
+ * filesystems), or an HTTP, HTTPS or FTP URI.
+ */
+ def addJar(path: String) {
+ sc.addJar(path)
+ }
+
+ /**
+ * Clear the job's list of JARs added by `addJar` so that they do not get downloaded to
+ * any new nodes.
+ */
+ def clearJars() {
+ sc.clearJars()
+ }
+
+ /**
+ * Clear the job's list of files added by `addFile` so that they do not get downloaded to
+ * any new nodes.
+ */
+ def clearFiles() {
+ sc.clearFiles()
+ }
}
object JavaSparkContext {
diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala
index 61bc5c90ba..b1095a52b4 100644
--- a/core/src/main/scala/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/spark/rdd/BlockRDD.scala
@@ -1,10 +1,8 @@
package spark.rdd
import scala.collection.mutable.HashMap
-
import spark.{RDD, SparkContext, SparkEnv, Split, TaskContext}
-
private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split {
val index = idx
}
@@ -14,7 +12,7 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St
extends RDD[T](sc, Nil) {
@transient
- val splits_ = (0 until blockIds.size).map(i => {
+ var splits_ : Array[Split] = (0 until blockIds.size).map(i => {
new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split]
}).toArray
@@ -26,7 +24,7 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St
HashMap(blockIds.zip(locations):_*)
}
- override def splits = splits_
+ override def getSplits = splits_
override def compute(split: Split, context: TaskContext): Iterator[T] = {
val blockManager = SparkEnv.get.blockManager
@@ -38,12 +36,11 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St
}
}
- override def preferredLocations(split: Split) = {
- if (isCheckpointed) {
- checkpointRDD.preferredLocations(split)
- } else {
- locations_(split.asInstanceOf[BlockRDDSplit].blockId)
- }
+ override def getPreferredLocations(split: Split) =
+ locations_(split.asInstanceOf[BlockRDDSplit].blockId)
+
+ override def clearDependencies() {
+ splits_ = null
}
}
diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala
index bc11b60e05..79e7c24e7c 100644
--- a/core/src/main/scala/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala
@@ -1,13 +1,28 @@
package spark.rdd
-import java.lang.ref.WeakReference
-
+import java.io.{ObjectOutputStream, IOException}
import spark.{OneToOneDependency, NarrowDependency, RDD, SparkContext, Split, TaskContext}
private[spark]
-class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable {
+class CartesianSplit(
+ idx: Int,
+ @transient rdd1: RDD[_],
+ @transient rdd2: RDD[_],
+ s1Index: Int,
+ s2Index: Int
+ ) extends Split {
+ var s1 = rdd1.splits(s1Index)
+ var s2 = rdd2.splits(s2Index)
override val index: Int = idx
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ s1 = rdd1.splits(s1Index)
+ s2 = rdd2.splits(s2Index)
+ oos.defaultWriteObject()
+ }
}
private[spark]
@@ -26,20 +41,16 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
val array = new Array[Split](rdd1.splits.size * rdd2.splits.size)
for (s1 <- rdd1.splits; s2 <- rdd2.splits) {
val idx = s1.index * numSplitsInRdd2 + s2.index
- array(idx) = new CartesianSplit(idx, s1, s2)
+ array(idx) = new CartesianSplit(idx, rdd1, rdd2, s1.index, s2.index)
}
array
}
- override def splits = splits_
+ override def getSplits = splits_
- override def preferredLocations(split: Split) = {
- if (isCheckpointed) {
- checkpointRDD.preferredLocations(split)
- } else {
- val currSplit = split.asInstanceOf[CartesianSplit]
- rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
- }
+ override def getPreferredLocations(split: Split) = {
+ val currSplit = split.asInstanceOf[CartesianSplit]
+ rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
}
override def compute(split: Split, context: TaskContext) = {
@@ -57,11 +68,11 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
}
)
- override def dependencies = deps_
+ override def getDependencies = deps_
- override protected def changeDependencies(newRDD: RDD[_]) {
- deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]]))
- splits_ = newRDD.splits
+ override def clearDependencies() {
+ deps_ = Nil
+ splits_ = null
rdd1 = null
rdd2 = null
}
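The CartesianSplit change above follows a pattern repeated in UnionSplit, CoalescedRDDSplit and ZippedSplit below: a split keeps only its parent's index as stable state and re-resolves the parent Split inside writeObject, so tasks serialized after a checkpoint pick up the replacement splits. A stripped-down sketch of the pattern (the class name is illustrative):

import java.io.{IOException, ObjectOutputStream}
import spark.{RDD, Split}

class RefreshingSplit(idx: Int, @transient rdd: RDD[_], parentIndex: Int) extends Split {
  var parentSplit: Split = rdd.splits(parentIndex)
  override val index: Int = idx

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream) {
    // Re-read the parent split at serialization time; checkpointing may have replaced it.
    parentSplit = rdd.splits(parentIndex)
    oos.defaultWriteObject()
  }
}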
diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
new file mode 100644
index 0000000000..86c63ca2f4
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
@@ -0,0 +1,128 @@
+package spark.rdd
+
+import spark._
+import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io.{NullWritable, BytesWritable}
+import org.apache.hadoop.util.ReflectionUtils
+import org.apache.hadoop.fs.Path
+import java.io.{File, IOException, EOFException}
+import java.text.NumberFormat
+
+private[spark] class CheckpointRDDSplit(idx: Int, val splitFile: String) extends Split {
+ override val index: Int = idx
+}
+
+/**
+ * This RDD represents a RDD checkpoint file (similar to HadoopRDD).
+ */
+private[spark]
+class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String)
+ extends RDD[T](sc, Nil) {
+
+ @transient val path = new Path(checkpointPath)
+ @transient val fs = path.getFileSystem(new Configuration())
+
+ @transient val splits_ : Array[Split] = {
+ val splitFiles = fs.listStatus(path).map(_.getPath.toString).filter(_.contains("part-")).sorted
+ splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray
+ }
+
+ checkpointData = Some(new RDDCheckpointData[T](this))
+ checkpointData.get.cpFile = Some(checkpointPath)
+
+ override def getSplits = splits_
+
+ override def getPreferredLocations(split: Split): Seq[String] = {
+ val status = fs.getFileStatus(path)
+ val locations = fs.getFileBlockLocations(status, 0, status.getLen)
+ locations.firstOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
+ }
+
+ override def compute(split: Split, context: TaskContext): Iterator[T] = {
+ CheckpointRDD.readFromFile(split.asInstanceOf[CheckpointRDDSplit].splitFile, context)
+ }
+
+ override def checkpoint() {
+ // Do nothing. CheckpointRDD should not be checkpointed again.
+ }
+}
+
+private[spark] object CheckpointRDD extends Logging {
+
+ def splitIdToFileName(splitId: Int): String = {
+ val numfmt = NumberFormat.getInstance()
+ numfmt.setMinimumIntegerDigits(5)
+ numfmt.setGroupingUsed(false)
+ "part-" + numfmt.format(splitId)
+ }
+
+ def writeToFile[T](path: String, blockSize: Int = -1)(context: TaskContext, iterator: Iterator[T]) {
+ val outputDir = new Path(path)
+ val fs = outputDir.getFileSystem(new Configuration())
+
+ val finalOutputName = splitIdToFileName(context.splitId)
+ val finalOutputPath = new Path(outputDir, finalOutputName)
+ val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptId)
+
+ if (fs.exists(tempOutputPath)) {
+ throw new IOException("Checkpoint failed: temporary path " +
+ tempOutputPath + " already exists")
+ }
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+
+ val fileOutputStream = if (blockSize < 0) {
+ fs.create(tempOutputPath, false, bufferSize)
+ } else {
+ // This is mainly for testing purposes
+ fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize)
+ }
+ val serializer = SparkEnv.get.serializer.newInstance()
+ val serializeStream = serializer.serializeStream(fileOutputStream)
+ serializeStream.writeAll(iterator)
+ fileOutputStream.close()
+
+ if (!fs.rename(tempOutputPath, finalOutputPath)) {
+ if (!fs.delete(finalOutputPath, true)) {
+ throw new IOException("Checkpoint failed: failed to delete earlier output of task "
+ + context.attemptId);
+ }
+ if (!fs.rename(tempOutputPath, finalOutputPath)) {
+ throw new IOException("Checkpoint failed: failed to save output of task: "
+ + context.attemptId)
+ }
+ }
+ }
+
+ def readFromFile[T](path: String, context: TaskContext): Iterator[T] = {
+ val inputPath = new Path(path)
+ val fs = inputPath.getFileSystem(new Configuration())
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+ val fileInputStream = fs.open(inputPath, bufferSize)
+ val serializer = SparkEnv.get.serializer.newInstance()
+ val deserializeStream = serializer.deserializeStream(fileInputStream)
+
+ // Register an on-task-completion callback to close the input stream.
+ context.addOnCompleteCallback(() => deserializeStream.close())
+
+ deserializeStream.asIterator.asInstanceOf[Iterator[T]]
+ }
+
+ // Test whether CheckpointRDD generates the expected number of splits even when
+ // each split file spans multiple blocks. This needs to be run on a
+ // cluster (Mesos or standalone) using HDFS.
+ def main(args: Array[String]) {
+ import spark._
+
+ val Array(cluster, hdfsPath) = args
+ val sc = new SparkContext(cluster, "CheckpointRDD Test")
+ val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
+ val path = new Path(hdfsPath, "temp")
+ val fs = path.getFileSystem(new Configuration())
+ sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 10) _)
+ val cpRDD = new CheckpointRDD[Int](sc, path.toString)
+ assert(cpRDD.splits.length == rdd.splits.length, "Number of splits is not the same")
+ assert(cpRDD.collect.toList == rdd.collect.toList, "Data of splits not the same")
+ fs.delete(path)
+ }
+}
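For context, a minimal driver-side use of the checkpointing path that CheckpointRDD reads back, assuming SparkContext.setCheckpointDir and RDD.checkpoint as exercised by CheckpointSuite later in this patch; the directory is a placeholder:

import spark.SparkContext

val sc = new SparkContext("local[2]", "checkpoint-demo")
sc.setCheckpointDir("/tmp/spark-checkpoints")  // should be an HDFS path on a real cluster

val rdd = sc.makeRDD(1 to 1000, 10).map(_ * 2)
rdd.checkpoint()               // mark for checkpointing; files are written by the first job
rdd.count()                    // materializes the RDD and writes part-00000 ... part-00009
println(rdd.getCheckpointFile) // Some(<checkpoint dir>) once checkpointing has completed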
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index ef8673909b..759bea5e9d 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -11,16 +11,17 @@ import spark.{Dependency, OneToOneDependency, ShuffleDependency}
private[spark] sealed trait CoGroupSplitDep extends Serializable
-private[spark]
-case class NarrowCoGroupSplitDep(rdd: RDD[_], splitIndex: Int, var split: Split = null)
- extends CoGroupSplitDep {
+private[spark] case class NarrowCoGroupSplitDep(
+ rdd: RDD[_],
+ splitIndex: Int,
+ var split: Split
+ ) extends CoGroupSplitDep {
+
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream) {
- rdd.synchronized {
- // Update the reference to parent split at the time of task serialization
- split = rdd.splits(splitIndex)
- oos.defaultWriteObject()
- }
+ // Update the reference to parent split at the time of task serialization
+ split = rdd.splits(splitIndex)
+ oos.defaultWriteObject()
}
}
@@ -39,7 +40,6 @@ private[spark] class CoGroupAggregator
{ (b1, b2) => b1 ++ b2 })
with Serializable
-
class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) with Logging {
@@ -49,19 +49,19 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
var deps_ = {
val deps = new ArrayBuffer[Dependency[_]]
for ((rdd, index) <- rdds.zipWithIndex) {
- val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
- if (mapSideCombinedRDD.partitioner == Some(part)) {
- logInfo("Adding one-to-one dependency with " + mapSideCombinedRDD)
- deps += new OneToOneDependency(mapSideCombinedRDD)
+ if (rdd.partitioner == Some(part)) {
+ logInfo("Adding one-to-one dependency with " + rdd)
+ deps += new OneToOneDependency(rdd)
} else {
logInfo("Adding shuffle dependency with " + rdd)
+ val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
deps += new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part)
}
}
deps.toList
}
- override def dependencies = deps_
+ override def getDependencies = deps_
@transient
var splits_ : Array[Split] = {
@@ -72,15 +72,15 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
case s: ShuffleDependency[_, _] =>
new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep
case _ =>
- new NarrowCoGroupSplitDep(r, i): CoGroupSplitDep
+ new NarrowCoGroupSplitDep(r, i, r.splits(i)): CoGroupSplitDep
}
}.toList)
}
array
}
- override def splits = splits_
-
+ override def getSplits = splits_
+
override val partitioner = Some(part)
override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
@@ -111,9 +111,9 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
map.iterator
}
- override protected def changeDependencies(newRDD: RDD[_]) {
- deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]]))
- splits_ = newRDD.splits
+ override def clearDependencies() {
+ deps_ = null
+ splits_ = null
rdds = null
}
}
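The dependency logic above now avoids a shuffle for inputs that are already partitioned by the target partitioner. A condensed sketch of that decision (the helper name is illustrative and the map-side combine step is elided):

import scala.collection.mutable.ArrayBuffer
import spark.{Dependency, OneToOneDependency, Partitioner, RDD, ShuffleDependency}

def cogroupDeps(rdds: Seq[RDD[(Any, Any)]], part: Partitioner): List[Dependency[_]] = {
  val deps = new ArrayBuffer[Dependency[_]]
  for (rdd <- rdds) {
    if (rdd.partitioner == Some(part)) {
      deps += new OneToOneDependency(rdd)                 // already co-partitioned: narrow dependency
    } else {
      deps += new ShuffleDependency[Any, Any](rdd, part)  // otherwise repartition through a shuffle
    }
  }
  deps.toList
}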
diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
index c5e2300d26..167755bbba 100644
--- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
@@ -1,11 +1,22 @@
package spark.rdd
-import java.lang.ref.WeakReference
-
import spark.{Dependency, OneToOneDependency, NarrowDependency, RDD, Split, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
+private[spark] case class CoalescedRDDSplit(
+ index: Int,
+ @transient rdd: RDD[_],
+ parentsIndices: Array[Int]
+ ) extends Split {
+ var parents: Seq[Split] = parentsIndices.map(rdd.splits(_))
-private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ parents = parentsIndices.map(rdd.splits(_))
+ oos.defaultWriteObject()
+ }
+}
/**
* Coalesce the partitions of a parent RDD (`prev`) into fewer partitions, so that each partition of
@@ -23,17 +34,17 @@ class CoalescedRDD[T: ClassManifest](
@transient var splits_ : Array[Split] = {
val prevSplits = prev.splits
if (prevSplits.length < maxPartitions) {
- prevSplits.zipWithIndex.map{ case (s, idx) => new CoalescedRDDSplit(idx, Array(s)) }
+ prevSplits.map(_.index).map{idx => new CoalescedRDDSplit(idx, prev, Array(idx)) }
} else {
(0 until maxPartitions).map { i =>
val rangeStart = (i * prevSplits.length) / maxPartitions
val rangeEnd = ((i + 1) * prevSplits.length) / maxPartitions
- new CoalescedRDDSplit(i, prevSplits.slice(rangeStart, rangeEnd))
+ new CoalescedRDDSplit(i, prev, (rangeStart until rangeEnd).toArray)
}.toArray
}
}
- override def splits = splits_
+ override def getSplits = splits_
override def compute(split: Split, context: TaskContext): Iterator[T] = {
split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { parentSplit =>
@@ -44,15 +55,15 @@ class CoalescedRDD[T: ClassManifest](
var deps_ : List[Dependency[_]] = List(
new NarrowDependency(prev) {
def getParents(id: Int): Seq[Int] =
- splits(id).asInstanceOf[CoalescedRDDSplit].parents.map(_.index)
+ splits(id).asInstanceOf[CoalescedRDDSplit].parentsIndices
}
)
- override def dependencies = deps_
+ override def getDependencies() = deps_
- override protected def changeDependencies(newRDD: RDD[_]) {
- deps_ = List(new OneToOneDependency(newRDD))
- splits_ = newRDD.splits
+ override def clearDependencies() {
+ deps_ = Nil
+ splits_ = null
prev = null
}
}
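A worked example of the index grouping computed in splits_ above, assuming 10 parent splits coalesced into 3 partitions:

val prevSplits = 10
val maxPartitions = 3
val groups = (0 until maxPartitions).map { i =>
  val rangeStart = (i * prevSplits) / maxPartitions
  val rangeEnd = ((i + 1) * prevSplits) / maxPartitions
  (rangeStart until rangeEnd).toList
}
// groups == Vector(List(0, 1, 2), List(3, 4, 5), List(6, 7, 8, 9))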
diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala
index 70c4be7903..b80e9bc07b 100644
--- a/core/src/main/scala/spark/rdd/FilteredRDD.scala
+++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala
@@ -1,14 +1,14 @@
package spark.rdd
-import java.lang.ref.WeakReference
-
import spark.{OneToOneDependency, RDD, Split, TaskContext}
-private[spark]
-class FilteredRDD[T: ClassManifest](prev: WeakReference[RDD[T]], f: T => Boolean)
- extends RDD[T](prev.get) {
+private[spark] class FilteredRDD[T: ClassManifest](
+ prev: RDD[T],
+ f: T => Boolean)
+ extends RDD[T](prev) {
+
+ override def getSplits = firstParent[T].splits
- override def splits = firstParent[T].splits
override def compute(split: Split, context: TaskContext) =
firstParent[T].iterator(split, context).filter(f)
-} \ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
index 1ebbb4c9bd..1b604c66e2 100644
--- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
@@ -1,17 +1,16 @@
package spark.rdd
-import java.lang.ref.WeakReference
-
import spark.{RDD, Split, TaskContext}
private[spark]
class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
- prev: WeakReference[RDD[T]],
+ prev: RDD[T],
f: T => TraversableOnce[U])
- extends RDD[U](prev.get) {
+ extends RDD[U](prev) {
+
+ override def getSplits = firstParent[T].splits
- override def splits = firstParent[T].splits
override def compute(split: Split, context: TaskContext) =
firstParent[T].iterator(split, context).flatMap(f)
}
diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala
index 43661ae3f8..051bffed19 100644
--- a/core/src/main/scala/spark/rdd/GlommedRDD.scala
+++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala
@@ -1,13 +1,12 @@
package spark.rdd
-import java.lang.ref.WeakReference
-
import spark.{RDD, Split, TaskContext}
-private[spark]
-class GlommedRDD[T: ClassManifest](prev: WeakReference[RDD[T]])
- extends RDD[Array[T]](prev.get) {
- override def splits = firstParent[T].splits
+private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T])
+ extends RDD[Array[T]](prev) {
+
+ override def getSplits = firstParent[T].splits
+
override def compute(split: Split, context: TaskContext) =
Array(firstParent[T].iterator(split, context).toArray).iterator
}
diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala
index 7b5f8ac3e9..f547f53812 100644
--- a/core/src/main/scala/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala
@@ -22,9 +22,8 @@ import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskCo
* A Spark split class that wraps around a Hadoop InputSplit.
*/
private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit)
- extends Split
- with Serializable {
-
+ extends Split {
+
val inputSplit = new SerializableWritable[InputSplit](s)
override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt
@@ -64,7 +63,7 @@ class HadoopRDD[K, V](
.asInstanceOf[InputFormat[K, V]]
}
- override def splits = splits_
+ override def getSplits = splits_
override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[HadoopSplit]
@@ -110,11 +109,13 @@ class HadoopRDD[K, V](
}
}
- override def preferredLocations(split: Split) = {
+ override def getPreferredLocations(split: Split) = {
// TODO: Filtering out "localhost" in case of file:// URLs
val hadoopSplit = split.asInstanceOf[HadoopSplit]
hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost")
}
- override val isCheckpointable = false
+ override def checkpoint() {
+ // Do nothing. Hadoop RDD should not be checkpointed.
+ }
}
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
index 991f4be73f..073f7d7d2a 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
@@ -1,24 +1,20 @@
package spark.rdd
-
-import spark.OneToOneDependency
-import spark.RDD
-import spark.Split
-import java.lang.ref.WeakReference
-
import spark.{RDD, Split, TaskContext}
private[spark]
class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
- prev: WeakReference[RDD[T]],
+ prev: RDD[T],
f: Iterator[T] => Iterator[U],
preservesPartitioning: Boolean = false)
- extends RDD[U](prev.get) {
+ extends RDD[U](prev) {
+
+ override val partitioner =
+ if (preservesPartitioning) firstParent[T].partitioner else None
- override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
+ override def getSplits = firstParent[T].splits
- override def splits = firstParent[T].splits
override def compute(split: Split, context: TaskContext) =
f(firstParent[T].iterator(split, context))
} \ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
index e2e7753cde..2ddc3d01b6 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
@@ -1,7 +1,5 @@
package spark.rdd
-import java.lang.ref.WeakReference
-
import spark.{RDD, Split, TaskContext}
@@ -12,13 +10,15 @@ import spark.{RDD, Split, TaskContext}
*/
private[spark]
class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
- prev: WeakReference[RDD[T]],
+ prev: RDD[T],
f: (Int, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean)
- extends RDD[U](prev.get) {
+ preservesPartitioning: Boolean
+ ) extends RDD[U](prev) {
+
+ override def getSplits = firstParent[T].splits
+
+ override val partitioner = if (preservesPartitioning) prev.partitioner else None
- override val partitioner = if (preservesPartitioning) prev.get.partitioner else None
- override def splits = firstParent[T].splits
override def compute(split: Split, context: TaskContext) =
f(split.index, firstParent[T].iterator(split, context))
} \ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala
index 986cf35291..c6ceb272cd 100644
--- a/core/src/main/scala/spark/rdd/MappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/MappedRDD.scala
@@ -1,17 +1,15 @@
package spark.rdd
-import java.lang.ref.WeakReference
-
import spark.{RDD, Split, TaskContext}
-
private[spark]
class MappedRDD[U: ClassManifest, T: ClassManifest](
- prev: WeakReference[RDD[T]],
+ prev: RDD[T],
f: T => U)
- extends RDD[U](prev.get) {
+ extends RDD[U](prev) {
+
+ override def getSplits = firstParent[T].splits
- override def splits = firstParent[T].splits
override def compute(split: Split, context: TaskContext) =
firstParent[T].iterator(split, context).map(f)
} \ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
index c7cc8d4685..bb22db073c 100644
--- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
@@ -52,7 +52,7 @@ class NewHadoopRDD[K, V](
result
}
- override def splits = splits_
+ override def getSplits = splits_
override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[NewHadoopSplit]
@@ -87,10 +87,8 @@ class NewHadoopRDD[K, V](
}
}
- override def preferredLocations(split: Split) = {
+ override def getPreferredLocations(split: Split) = {
val theSplit = split.asInstanceOf[NewHadoopSplit]
theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost")
}
-
- override val isCheckpointable = false
}
diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala
index 076c6a64a0..6631f83510 100644
--- a/core/src/main/scala/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/spark/rdd/PipedRDD.scala
@@ -1,7 +1,6 @@
package spark.rdd
import java.io.PrintWriter
-import java.lang.ref.WeakReference
import java.util.StringTokenizer
import scala.collection.Map
@@ -17,18 +16,18 @@ import spark.{RDD, SparkEnv, Split, TaskContext}
* (printing them one per line) and returns the output as a collection of strings.
*/
class PipedRDD[T: ClassManifest](
- prev: WeakReference[RDD[T]],
+ prev: RDD[T],
command: Seq[String],
envVars: Map[String, String])
- extends RDD[String](prev.get) {
+ extends RDD[String](prev) {
- def this(prev: WeakReference[RDD[T]], command: Seq[String]) = this(prev, command, Map())
+ def this(prev: RDD[T], command: Seq[String]) = this(prev, command, Map())
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
- def this(prev: WeakReference[RDD[T]], command: String) = this(prev, PipedRDD.tokenize(command))
+ def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command))
- override def splits = firstParent[T].splits
+ override def getSplits = firstParent[T].splits
override def compute(split: Split, context: TaskContext): Iterator[String] = {
val pb = new ProcessBuilder(command)
diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala
index 0dc83c127f..1bc9c96112 100644
--- a/core/src/main/scala/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/spark/rdd/SampledRDD.scala
@@ -1,6 +1,5 @@
package spark.rdd
-import java.lang.ref.WeakReference
import java.util.Random
import cern.jet.random.Poisson
@@ -14,21 +13,21 @@ class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Seriali
}
class SampledRDD[T: ClassManifest](
- prev: WeakReference[RDD[T]],
- withReplacement: Boolean,
+ prev: RDD[T],
+ withReplacement: Boolean,
frac: Double,
seed: Int)
- extends RDD[T](prev.get) {
+ extends RDD[T](prev) {
@transient
- val splits_ = {
+ var splits_ : Array[Split] = {
val rg = new Random(seed)
firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt))
}
- override def splits = splits_.asInstanceOf[Array[Split]]
+ override def getSplits = splits_.asInstanceOf[Array[Split]]
- override def preferredLocations(split: Split) =
+ override def getPreferredLocations(split: Split) =
firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev)
override def compute(splitIn: Split, context: TaskContext) = {
@@ -50,4 +49,8 @@ class SampledRDD[T: ClassManifest](
firstParent[T].iterator(split.prev, context).filter(x => (rand.nextDouble <= frac))
}
}
+
+ override def clearDependencies() {
+ splits_ = null
+ }
}
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index 7d592ffc5e..1b219473e0 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -1,9 +1,7 @@
package spark.rdd
-import java.lang.ref.WeakReference
-
import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Split, TaskContext}
-
+import spark.SparkContext._
private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
override val index = idx
@@ -12,30 +10,29 @@ private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
/**
* The resulting RDD from a shuffle (e.g. repartitioning of data).
- * @param parent the parent RDD.
+ * @param prev the parent RDD.
* @param part the partitioner used to partition the RDD
* @tparam K the key class.
* @tparam V the value class.
*/
class ShuffledRDD[K, V](
- @transient prev: WeakReference[RDD[(K, V)]],
+ prev: RDD[(K, V)],
part: Partitioner)
- extends RDD[(K, V)](prev.get.context, List(new ShuffleDependency(prev.get, part))) {
+ extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part))) {
override val partitioner = Some(part)
@transient
var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
- override def splits = splits_
+ override def getSplits = splits_
override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index)
}
- override def changeDependencies(newRDD: RDD[_]) {
- dependencies_ = Nil
- splits_ = newRDD.splits
+ override def clearDependencies() {
+ splits_ = null
}
}
diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala
index 21965d3f5d..24a085df02 100644
--- a/core/src/main/scala/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/spark/rdd/UnionRDD.scala
@@ -1,20 +1,26 @@
package spark.rdd
-import java.lang.ref.WeakReference
import scala.collection.mutable.ArrayBuffer
-import spark.{Dependency, OneToOneDependency, RangeDependency, RDD, SparkContext, Split, TaskContext}
+import spark.{Dependency, RangeDependency, RDD, SparkContext, Split, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
+private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int)
+ extends Split {
-private[spark] class UnionSplit[T: ClassManifest](
- idx: Int,
- rdd: RDD[T],
- split: Split)
- extends Split
- with Serializable {
+ var split: Split = rdd.splits(splitIndex)
def iterator(context: TaskContext) = rdd.iterator(split, context)
+
def preferredLocations() = rdd.preferredLocations(split)
+
override val index: Int = idx
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ split = rdd.splits(splitIndex)
+ oos.defaultWriteObject()
+ }
}
class UnionRDD[T: ClassManifest](
@@ -27,13 +33,13 @@ class UnionRDD[T: ClassManifest](
val array = new Array[Split](rdds.map(_.splits.size).sum)
var pos = 0
for (rdd <- rdds; split <- rdd.splits) {
- array(pos) = new UnionSplit(pos, rdd, split)
+ array(pos) = new UnionSplit(pos, rdd, split.index)
pos += 1
}
array
}
- override def splits = splits_
+ override def getSplits = splits_
@transient var deps_ = {
val deps = new ArrayBuffer[Dependency[_]]
@@ -45,22 +51,17 @@ class UnionRDD[T: ClassManifest](
deps.toList
}
- override def dependencies = deps_
+ override def getDependencies = deps_
override def compute(s: Split, context: TaskContext): Iterator[T] =
s.asInstanceOf[UnionSplit[T]].iterator(context)
- override def preferredLocations(s: Split): Seq[String] = {
- if (isCheckpointed) {
- checkpointRDD.preferredLocations(s)
- } else {
- s.asInstanceOf[UnionSplit[T]].preferredLocations()
- }
- }
+ override def getPreferredLocations(s: Split): Seq[String] =
+ s.asInstanceOf[UnionSplit[T]].preferredLocations()
- override protected def changeDependencies(newRDD: RDD[_]) {
- deps_ = List(new OneToOneDependency(newRDD))
- splits_ = newRDD.splits
+ override def clearDependencies() {
+ deps_ = null
+ splits_ = null
rdds = null
}
}
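A quick illustration of the split layout built above (assumes a SparkContext named sc): the union of RDDs with 2 and 3 splits yields 5 UnionSplits indexed in parent order.

val a = sc.makeRDD(1 to 10, 2)
val b = sc.makeRDD(11 to 20, 3)
val u = a.union(b)
assert(u.splits.length == a.splits.length + b.splits.length)   // 2 + 3 = 5
assert(u.splits.map(_.index).toList == (0 until 5).toList)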
diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala
index 33b64e2d24..16e6cc0f1b 100644
--- a/core/src/main/scala/spark/rdd/ZippedRDD.scala
+++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala
@@ -1,55 +1,66 @@
package spark.rdd
import spark.{OneToOneDependency, RDD, SparkContext, Split, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest](
idx: Int,
- rdd1: RDD[T],
- rdd2: RDD[U],
- split1: Split,
- split2: Split)
- extends Split
- with Serializable {
+ @transient rdd1: RDD[T],
+ @transient rdd2: RDD[U]
+ ) extends Split {
- def iterator(context: TaskContext): Iterator[(T, U)] =
- rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context))
+ var split1 = rdd1.splits(idx)
+ var split2 = rdd2.splits(idx)
+ override val index: Int = idx
- def preferredLocations(): Seq[String] =
- rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
+ def splits = (split1, split2)
- override val index: Int = idx
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ split1 = rdd1.splits(idx)
+ split2 = rdd2.splits(idx)
+ oos.defaultWriteObject()
+ }
}
class ZippedRDD[T: ClassManifest, U: ClassManifest](
sc: SparkContext,
- @transient rdd1: RDD[T],
- @transient rdd2: RDD[U])
- extends RDD[(T, U)](sc, Nil)
+ var rdd1: RDD[T],
+ var rdd2: RDD[U])
+ extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2)))
with Serializable {
// TODO: FIX THIS.
@transient
- val splits_ : Array[Split] = {
+ var splits_ : Array[Split] = {
if (rdd1.splits.size != rdd2.splits.size) {
throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
}
val array = new Array[Split](rdd1.splits.size)
for (i <- 0 until rdd1.splits.size) {
- array(i) = new ZippedSplit(i, rdd1, rdd2, rdd1.splits(i), rdd2.splits(i))
+ array(i) = new ZippedSplit(i, rdd1, rdd2)
}
array
}
- override def splits = splits_
+ override def getSplits = splits_
- @transient
- override val dependencies = List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))
+ override def compute(s: Split, context: TaskContext): Iterator[(T, U)] = {
+ val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits
+ rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context))
+ }
- override def compute(s: Split, context: TaskContext): Iterator[(T, U)] =
- s.asInstanceOf[ZippedSplit[T, U]].iterator(context)
+ override def getPreferredLocations(s: Split): Seq[String] = {
+ val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits
+ rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
+ }
- override def preferredLocations(s: Split): Seq[String] =
- s.asInstanceOf[ZippedSplit[T, U]].preferredLocations()
+ override def clearDependencies() {
+ splits_ = null
+ rdd1 = null
+ rdd2 = null
+ }
}
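The constructor above rejects inputs with unequal split counts; a small illustration (assumes a SparkContext named sc):

import spark.rdd.ZippedRDD

val xs = sc.makeRDD(1 to 100, 4)
val ys = xs.map(_.toString)        // same number of splits, so zipping is allowed
val zipped = new ZippedRDD(sc, xs, ys)

val wrong = sc.makeRDD(1 to 100, 5)
// new ZippedRDD(sc, xs, wrong)    // would throw IllegalArgumentException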
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index 9387ba19a3..59f2099e91 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -599,15 +599,15 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
def cleanup(cleanupTime: Long) {
var sizeBefore = idToStage.size
- idToStage.cleanup(cleanupTime)
+ idToStage.clearOldValues(cleanupTime)
logInfo("idToStage " + sizeBefore + " --> " + idToStage.size)
sizeBefore = shuffleToMapStage.size
- shuffleToMapStage.cleanup(cleanupTime)
+ shuffleToMapStage.clearOldValues(cleanupTime)
logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size)
sizeBefore = pendingTasks.size
- pendingTasks.cleanup(cleanupTime)
+ pendingTasks.clearOldValues(cleanupTime)
logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size)
}
diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala
index e492279b4e..74a63c1af1 100644
--- a/core/src/main/scala/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/spark/scheduler/ResultTask.scala
@@ -1,17 +1,74 @@
package spark.scheduler
import spark._
+import java.io._
+import util.{MetadataCleaner, TimeStampedHashMap}
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+
+private[spark] object ResultTask {
+
+ // A simple map from the stage id to the serialized byte array of a task.
+ // It serves as a cache for task serialization, because serialization can be
+ // expensive on the master node if it needs to launch thousands of tasks.
+ val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
+
+ val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.clearOldValues)
+
+ def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = {
+ synchronized {
+ val old = serializedInfoCache.get(stageId).orNull
+ if (old != null) {
+ return old
+ } else {
+ val out = new ByteArrayOutputStream
+ val ser = SparkEnv.get.closureSerializer.newInstance
+ val objOut = ser.serializeStream(new GZIPOutputStream(out))
+ objOut.writeObject(rdd)
+ objOut.writeObject(func)
+ objOut.close()
+ val bytes = out.toByteArray
+ serializedInfoCache.put(stageId, bytes)
+ return bytes
+ }
+ }
+ }
+
+ def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = {
+ synchronized {
+ val loader = Thread.currentThread.getContextClassLoader
+ val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+ val ser = SparkEnv.get.closureSerializer.newInstance
+ val objIn = ser.deserializeStream(in)
+ val rdd = objIn.readObject().asInstanceOf[RDD[_]]
+ val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
+ return (rdd, func)
+ }
+ }
+
+ def clearCache() {
+ synchronized {
+ serializedInfoCache.clear()
+ }
+ }
+}
+
private[spark] class ResultTask[T, U](
stageId: Int,
- rdd: RDD[T],
- func: (TaskContext, Iterator[T]) => U,
- val partition: Int,
+ var rdd: RDD[T],
+ var func: (TaskContext, Iterator[T]) => U,
+ var partition: Int,
@transient locs: Seq[String],
val outputId: Int)
- extends Task[U](stageId) {
+ extends Task[U](stageId) with Externalizable {
- val split = rdd.splits(partition)
+ def this() = this(0, null, null, 0, null, 0)
+
+ var split = if (rdd == null) {
+ null
+ } else {
+ rdd.splits(partition)
+ }
override def run(attemptId: Long): U = {
val context = new TaskContext(stageId, partition, attemptId)
@@ -23,4 +80,31 @@ private[spark] class ResultTask[T, U](
override def preferredLocations: Seq[String] = locs
override def toString = "ResultTask(" + stageId + ", " + partition + ")"
+
+ override def writeExternal(out: ObjectOutput) {
+ RDDCheckpointData.synchronized {
+ split = rdd.splits(partition)
+ out.writeInt(stageId)
+ val bytes = ResultTask.serializeInfo(
+ stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ out.writeInt(partition)
+ out.writeInt(outputId)
+ out.writeObject(split)
+ }
+ }
+
+ override def readExternal(in: ObjectInput) {
+ val stageId = in.readInt()
+ val numBytes = in.readInt()
+ val bytes = new Array[Byte](numBytes)
+ in.readFully(bytes)
+ val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
+ rdd = rdd_.asInstanceOf[RDD[T]]
+ func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
+ partition = in.readInt()
+ val outputId = in.readInt()
+ split = in.readObject().asInstanceOf[Split]
+ }
}
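The ResultTask object above caches one gzip-compressed serialization of (rdd, func) per stage so that launching thousands of tasks does not repeat the work. A simplified sketch of the caching idea using plain Java serialization (the real code goes through Spark's closure serializer, and the names here are illustrative):

import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import java.util.zip.GZIPOutputStream
import scala.collection.mutable

val stageInfoCache = new mutable.HashMap[Int, Array[Byte]]

def serializedStageInfo(stageId: Int, payload: AnyRef): Array[Byte] = stageInfoCache.synchronized {
  stageInfoCache.getOrElseUpdate(stageId, {
    val buffer = new ByteArrayOutputStream
    val out = new ObjectOutputStream(new GZIPOutputStream(buffer))
    out.writeObject(payload)  // serialize once, reuse the bytes for every task in the stage
    out.close()
    buffer.toByteArray
  })
}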
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 7fdc178d4b..19f5328eee 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -23,7 +23,7 @@ private[spark] object ShuffleMapTask {
// expensive on the master node if it needs to launch thousands of tasks.
val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
- val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.cleanup)
+ val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.clearOldValues)
def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = {
synchronized {
@@ -90,13 +90,16 @@ private[spark] class ShuffleMapTask(
}
override def writeExternal(out: ObjectOutput) {
- out.writeInt(stageId)
- val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
- out.writeInt(bytes.length)
- out.write(bytes)
- out.writeInt(partition)
- out.writeLong(generation)
- out.writeObject(split)
+ RDDCheckpointData.synchronized {
+ split = rdd.splits(partition)
+ out.writeInt(stageId)
+ val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ out.writeInt(partition)
+ out.writeLong(generation)
+ out.writeObject(split)
+ }
}
override def readExternal(in: ObjectInput) {
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index eb20fe41b2..dff550036d 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -81,7 +81,10 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
ser.serialize(Accumulators.values))
logInfo("Finished task " + idInJob)
- listener.taskEnded(task, Success, resultToReturn, accumUpdates)
+
+ // If the threadpool has not already been shutdown, notify DAGScheduler
+ if (!Thread.currentThread().isInterrupted)
+ listener.taskEnded(task, Success, resultToReturn, accumUpdates)
} catch {
case t: Throwable => {
logError("Exception in task " + idInJob, t)
@@ -91,7 +94,8 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
submitTask(task, idInJob)
} else {
// TODO: Do something nicer here to return all the way to the user
- listener.taskEnded(task, new ExceptionFailure(t), null, null)
+ if (!Thread.currentThread().isInterrupted)
+ listener.taskEnded(task, new ExceptionFailure(t), null, null)
}
}
}
@@ -108,22 +112,24 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
* SparkContext. Also adds any new JARs we fetched to the class loader.
*/
private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
- // Fetch missing dependencies
- for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File("."))
- currentFiles(name) = timestamp
- }
- for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File("."))
- currentJars(name) = timestamp
- // Add it to our class loader
- val localName = name.split("/").last
- val url = new File(".", localName).toURI.toURL
- if (!classLoader.getURLs.contains(url)) {
- logInfo("Adding " + url + " to class loader")
- classLoader.addURL(url)
+ synchronized {
+ // Fetch missing dependencies
+ for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ Utils.fetchFile(name, new File("."))
+ currentFiles(name) = timestamp
+ }
+ for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ Utils.fetchFile(name, new File("."))
+ currentJars(name) = timestamp
+ // Add it to our class loader
+ val localName = name.split("/").last
+ val url = new File(".", localName).toURI.toURL
+ if (!classLoader.getURLs.contains(url)) {
+ logInfo("Adding " + url + " to class loader")
+ classLoader.addURL(url)
+ }
}
}
}
diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala
index 7e785182ea..bb7c5c01c8 100644
--- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala
+++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala
@@ -7,7 +7,7 @@ import scala.collection.mutable.Map
/**
* This is a custom implementation of scala.collection.mutable.Map which stores the insertion
* time stamp along with each key-value pair. Key-value pairs that are older than a particular
- * threshold time can them be removed using the cleanup method. This is intended to be a drop-in
+ * threshold time can then be removed using the clearOldValues method. This is intended to be a drop-in
* replacement of scala.collection.mutable.HashMap.
*/
class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging {
@@ -74,7 +74,10 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging {
}
}
- def cleanup(threshTime: Long) {
+ /**
+ * Removes old key-value pairs that have a timestamp earlier than `threshTime`.
+ */
+ def clearOldValues(threshTime: Long) {
val iterator = internalMap.entrySet().iterator()
while(iterator.hasNext) {
val entry = iterator.next()
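Hypothetical use of the renamed clearOldValues method: drop entries inserted more than ten minutes ago (the cache name and contents are examples only).

import spark.util.TimeStampedHashMap

val taskInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
taskInfoCache.put(1, "payload".getBytes)
// ... later ...
val tenMinutesAgo = System.currentTimeMillis() - 10 * 60 * 1000
taskInfoCache.clearOldValues(tenMinutesAgo)  // removes pairs whose insertion timestamp is older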
diff --git a/core/src/main/scala/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/spark/util/TimeStampedHashSet.scala
index 539dd75844..5f1cc93752 100644
--- a/core/src/main/scala/spark/util/TimeStampedHashSet.scala
+++ b/core/src/main/scala/spark/util/TimeStampedHashSet.scala
@@ -52,7 +52,10 @@ class TimeStampedHashSet[A] extends Set[A] {
}
}
- def cleanup(threshTime: Long) {
+ /**
+ * Removes old values that have a timestamp earlier than `threshTime`.
+ */
+ def clearOldValues(threshTime: Long) {
val iterator = internalMap.entrySet().iterator()
while(iterator.hasNext) {
val entry = iterator.next()
diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties
index 4c99e450bc..6ec89c0184 100644
--- a/core/src/test/resources/log4j.properties
+++ b/core/src/test/resources/log4j.properties
@@ -1,8 +1,8 @@
-# Set everything to be logged to the console
+# Set everything to be logged to the file core/target/unit-tests.log
log4j.rootCategory=INFO, file
log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
-log4j.appender.file.file=spark-tests.log
+log4j.appender.file.file=core/target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala
index 302a95db66..51573254ca 100644
--- a/core/src/test/scala/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/spark/CheckpointSuite.scala
@@ -2,17 +2,16 @@ package spark
import org.scalatest.{BeforeAndAfter, FunSuite}
import java.io.File
-import rdd.{BlockRDD, CoalescedRDD, MapPartitionsWithSplitRDD}
+import spark.rdd._
import spark.SparkContext._
import storage.StorageLevel
-import java.util.concurrent.Semaphore
-import collection.mutable.ArrayBuffer
class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
initLogging()
var sc: SparkContext = _
var checkpointDir: File = _
+ val partitioner = new HashPartitioner(2)
before {
checkpointDir = File.createTempFile("temp", "")
@@ -35,14 +34,30 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
}
}
+ test("RDDs with one-to-one dependencies") {
+ testCheckpointing(_.map(x => x.toString))
+ testCheckpointing(_.flatMap(x => 1 to x))
+ testCheckpointing(_.filter(_ % 2 == 0))
+ testCheckpointing(_.sample(false, 0.5, 0))
+ testCheckpointing(_.glom())
+ testCheckpointing(_.mapPartitions(_.map(_.toString)))
+ testCheckpointing(r => new MapPartitionsWithSplitRDD(r,
+ (i: Int, iter: Iterator[Int]) => iter.map(_.toString), false ))
+ testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
+ testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
+ testCheckpointing(_.pipe(Seq("cat")))
+ }
+
test("ParallelCollection") {
- val parCollection = sc.makeRDD(1 to 4)
+ val parCollection = sc.makeRDD(1 to 4, 2)
+ val numSplits = parCollection.splits.size
parCollection.checkpoint()
assert(parCollection.dependencies === Nil)
val result = parCollection.collect()
- sleep(parCollection) // slightly extra time as loading classes for the first can take some time
- assert(sc.objectFile[Int](parCollection.checkpointFile).collect() === result)
+ assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result)
assert(parCollection.dependencies != Nil)
+ assert(parCollection.splits.length === numSplits)
+ assert(parCollection.splits.toList === parCollection.checkpointData.get.getSplits.toList)
assert(parCollection.collect() === result)
}
@@ -51,163 +66,292 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
val blockManager = SparkEnv.get.blockManager
blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY)
val blockRDD = new BlockRDD[String](sc, Array(blockId))
+ val numSplits = blockRDD.splits.size
blockRDD.checkpoint()
val result = blockRDD.collect()
- sleep(blockRDD)
- assert(sc.objectFile[String](blockRDD.checkpointFile).collect() === result)
+ assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result)
assert(blockRDD.dependencies != Nil)
+ assert(blockRDD.splits.length === numSplits)
+ assert(blockRDD.splits.toList === blockRDD.checkpointData.get.getSplits.toList)
assert(blockRDD.collect() === result)
}
- test("RDDs with one-to-one dependencies") {
- testCheckpointing(_.map(x => x.toString))
- testCheckpointing(_.flatMap(x => 1 to x))
- testCheckpointing(_.filter(_ % 2 == 0))
- testCheckpointing(_.sample(false, 0.5, 0))
- testCheckpointing(_.glom())
- testCheckpointing(_.mapPartitions(_.map(_.toString)))
- testCheckpointing(r => new MapPartitionsWithSplitRDD(r,
- (i: Int, iter: Iterator[Int]) => iter.map(_.toString), false))
- testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString), 1000)
- testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x), 1000)
- testCheckpointing(_.pipe(Seq("cat")))
- }
-
test("ShuffledRDD") {
- testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _))
+ testCheckpointing(rdd => {
+ // Creating ShuffledRDD directly, as PairRDDFunctions.combineByKey produces a MapPartitionsRDD
+ new ShuffledRDD(rdd.map(x => (x % 2, 1)), partitioner)
+ })
}
test("UnionRDD") {
- testCheckpointing(_.union(sc.makeRDD(5 to 6, 4)))
+ def otherRDD = sc.makeRDD(1 to 10, 1)
+
+ // Test whether the serialized size of the UnionSplits decreases after the parent RDD is checkpointed.
+ // The current implementation of UnionRDD holds only transient references to its parent RDDs,
+ // so only the splits shrink in serialized size, not the RDD itself.
+ testCheckpointing(_.union(otherRDD), false, true)
+ testParentCheckpointing(_.union(otherRDD), false, true)
}
test("CartesianRDD") {
- testCheckpointing(_.cartesian(sc.makeRDD(5 to 6, 4)), 1000)
+ def otherRDD = sc.makeRDD(1 to 10, 1)
+ testCheckpointing(new CartesianRDD(sc, _, otherRDD))
+
+ // Test whether the serialized size of the CartesianRDD decreases after the parent RDD is checkpointed.
+ // The current implementation of CartesianSplit holds only transient references to the parent RDDs,
+ // so only the RDD shrinks in serialized size, not the splits.
+ testParentCheckpointing(new CartesianRDD(sc, _, otherRDD), true, false)
+
+ // Test that the CartesianSplit updates its parent splits (CartesianSplit.s1/s2) after
+ // the parent RDD has been checkpointed and its splits have been changed to HadoopSplits.
+ // Note that this test is very specific to the current implementation of CartesianRDD.
+ val ones = sc.makeRDD(1 to 100, 10).map(x => x)
+ ones.checkpoint // checkpoint that MappedRDD
+ val cartesian = new CartesianRDD(sc, ones, ones)
+ val splitBeforeCheckpoint =
+ serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit])
+ cartesian.count() // do the checkpointing
+ val splitAfterCheckpoint =
+ serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit])
+ assert(
+ (splitAfterCheckpoint.s1 != splitBeforeCheckpoint.s1) &&
+ (splitAfterCheckpoint.s2 != splitBeforeCheckpoint.s2),
+ "CartesianRDD.parents not updated after parent RDD checkpointed"
+ )
}
test("CoalescedRDD") {
testCheckpointing(new CoalescedRDD(_, 2))
+
+ // Test whether the serialized size of the CoalescedRDD decreases after the parent RDD is checkpointed.
+ // The current implementation of CoalescedRDDSplit holds only a transient reference to the parent RDD,
+ // so only the RDD shrinks in serialized size, not the splits.
+ testParentCheckpointing(new CoalescedRDD(_, 2), true, false)
+
+ // Test that the CoalescedRDDSplit updates its parent splits (CoalescedRDDSplit.parents) after
+ // the parent RDD has been checkpointed and its splits have been changed to HadoopSplits.
+ // Note that this test is very specific to the current implementation of CoalescedRDDSplit.
+ val ones = sc.makeRDD(1 to 100, 10).map(x => x)
+ ones.checkpoint // checkpoint that MappedRDD
+ val coalesced = new CoalescedRDD(ones, 2)
+ val splitBeforeCheckpoint =
+ serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit])
+ coalesced.count() // do the checkpointing
+ val splitAfterCheckpoint =
+ serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit])
+ assert(
+ splitAfterCheckpoint.parents.head != splitBeforeCheckpoint.parents.head,
+ "CoalescedRDDSplit.parents not updated after parent RDD checkpointed"
+ )
}
test("CoGroupedRDD") {
- val rdd2 = sc.makeRDD(5 to 6, 4).map(x => (x % 2, 1))
- testCheckpointing(rdd1 => rdd1.map(x => (x % 2, 1)).cogroup(rdd2))
- testCheckpointing(rdd1 => rdd1.map(x => (x % 2, x)).join(rdd2))
+ val longLineageRDD1 = generateLongLineageRDDForCoGroupedRDD()
+ testCheckpointing(rdd => {
+ CheckpointSuite.cogroup(longLineageRDD1, rdd.map(x => (x % 2, 1)), partitioner)
+ }, false, true)
- // Special test to make sure that the CoGroupSplit of CoGroupedRDD do not
- // hold on to the splits of its parent RDDs, as the splits of parent RDDs
- // may change while checkpointing. Rather the splits of parent RDDs must
- // be fetched at the time of serialization to ensure the latest splits to
- // be sent along with the task.
+ val longLineageRDD2 = generateLongLineageRDDForCoGroupedRDD()
+ testParentCheckpointing(rdd => {
+ CheckpointSuite.cogroup(
+ longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner)
+ }, false, true)
+ }
- val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _)
+ test("ZippedRDD") {
+ testCheckpointing(
+ rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
+
+ // Test whether size of ZippedRDD reduce in size after parent RDD is checkpointed
+ // Current implementation of ZippedRDDSplit has transient references to parent RDDs,
+ // so only the RDD will reduce in serialized size, not the splits.
+ testParentCheckpointing(
+ rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
+
+ }
- val ones = sc.parallelize(1 to 100, 1).map(x => (x,1))
- val reduced = ones.reduceByKey(_ + _)
- val seqOfCogrouped = new ArrayBuffer[RDD[(Int, Int)]]()
- seqOfCogrouped += reduced.cogroup(ones).mapValues[Int](add)
- for(i <- 1 to 10) {
- seqOfCogrouped += seqOfCogrouped.last.cogroup(ones).mapValues(add)
- }
- val finalCogrouped = seqOfCogrouped.last
- val intermediateCogrouped = seqOfCogrouped(5)
-
- val bytesBeforeCheckpoint = Utils.serialize(finalCogrouped.splits)
- intermediateCogrouped.checkpoint()
- finalCogrouped.count()
- sleep(intermediateCogrouped)
- val bytesAfterCheckpoint = Utils.serialize(finalCogrouped.splits)
- println("Before = " + bytesBeforeCheckpoint.size + ", after = " + bytesAfterCheckpoint.size)
- assert(bytesAfterCheckpoint.size < bytesBeforeCheckpoint.size,
- "CoGroupedSplits still holds on to the splits of its parent RDDs")
- }
- /*
/**
- * This test forces two ResultTasks of the same job to be launched before and after
- * the checkpointing of job's RDD is completed.
+ * Test checkpointing of the final RDD generated by the given operation. By default,
+ * this method tests whether the size of serialized RDD has reduced after checkpointing or not.
+ * It can also test whether the size of serialized RDD splits has reduced after checkpointing or
+ * not, but this is not done by default as usually the splits do not refer to any RDD and
+ * therefore never store the lineage.
*/
- test("Threading - ResultTasks") {
- val op1 = (parCollection: RDD[Int]) => {
- parCollection.map(x => { println("1st map running on " + x); Thread.sleep(500); (x % 2, x) })
+ def testCheckpointing[U: ClassManifest](
+ op: (RDD[Int]) => RDD[U],
+ testRDDSize: Boolean = true,
+ testRDDSplitSize: Boolean = false
+ ) {
+ // Generate the final RDD using given RDD operation
+ val baseRDD = generateLongLineageRDD
+ val operatedRDD = op(baseRDD)
+ val parentRDD = operatedRDD.dependencies.headOption.orNull
+ val rddType = operatedRDD.getClass.getSimpleName
+ val numSplits = operatedRDD.splits.length
+
+ // Find serialized sizes before and after the checkpoint
+ val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
+ operatedRDD.checkpoint()
+ val result = operatedRDD.collect()
+ val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+
+ // Test whether the checkpoint file has been created
+ assert(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get).collect() === result)
+
+ // Test whether dependencies have been changed from its earlier parent RDD
+ assert(operatedRDD.dependencies.head.rdd != parentRDD)
+
+ // Test whether the splits have been changed to the new Hadoop splits
+ assert(operatedRDD.splits.toList === operatedRDD.checkpointData.get.getSplits.toList)
+
+ // Test whether the number of splits is same as before
+ assert(operatedRDD.splits.length === numSplits)
+
+ // Test whether the data in the checkpointed RDD is same as original
+ assert(operatedRDD.collect() === result)
+
+ // Test whether serialized size of the RDD has reduced. If the RDD
+ // does not have any dependency to another RDD (e.g., ParallelCollection,
+ // ShuffleRDD with ShuffleDependency), it may not reduce in size after checkpointing.
+ if (testRDDSize) {
+ logInfo("Size of " + rddType +
+ "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]")
+ assert(
+ rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
+ "Size of " + rddType + " did not reduce after checkpointing " +
+ "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
+ )
}
- val op2 = (firstRDD: RDD[(Int, Int)]) => {
- firstRDD.map(x => { println("2nd map running on " + x); Thread.sleep(500); x })
+
+ // Test whether serialized size of the splits has reduced. If the splits
+ // do not have any non-transient reference to another RDD or another RDD's splits, they
+ // do not refer to a lineage and therefore may not reduce in size after checkpointing.
+ // However, if the original splits before checkpointing do refer to a parent RDD, the splits
+ // must be forgotten after checkpointing (to remove all reference to parent RDDs) and
+ // replaced with the HadoopSplits of the checkpointed RDD.
+ if (testRDDSplitSize) {
+ logInfo("Size of " + rddType + " splits "
+ + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]")
+ assert(
+ splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint,
+ "Size of " + rddType + " splits did not reduce after checkpointing " +
+ "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]"
+ )
}
- testThreading(op1, op2)
}
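+
+  // Hypothetical usage sketch (for illustration only; the concrete test cases in this
+  // suite may invoke the helper differently):
+  //   test("MappedRDD") { testCheckpointing(_.map(x => x.toString)) }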
/**
- * This test forces two ShuffleMapTasks of the same job to be launched before and after
- * the checkpointing of job's RDD is completed.
+ * Test whether checkpointing the parent of the generated RDD also
+ * truncates the lineage. Some RDDs, like CoGroupedRDD, hold on to the splits of their
+ * parent RDDs. So even if the parent RDD is checkpointed and its splits are changed,
+ * this RDD will still remember those splits and therefore potentially the whole lineage.
*/
- test("Threading - ShuffleMapTasks") {
- val op1 = (parCollection: RDD[Int]) => {
- parCollection.map(x => { println("1st map running on " + x); Thread.sleep(500); (x % 2, x) })
+ def testParentCheckpointing[U: ClassManifest](
+ op: (RDD[Int]) => RDD[U],
+ testRDDSize: Boolean,
+ testRDDSplitSize: Boolean
+ ) {
+    // Generate the final RDD using the given RDD operation
+ val baseRDD = generateLongLineageRDD
+ val operatedRDD = op(baseRDD)
+ val parentRDD = operatedRDD.dependencies.head.rdd
+ val rddType = operatedRDD.getClass.getSimpleName
+ val parentRDDType = parentRDD.getClass.getSimpleName
+
+ // Find serialized sizes before and after the checkpoint
+ val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
+ parentRDD.checkpoint() // checkpoint the parent RDD, not the generated one
+ val result = operatedRDD.collect()
+ val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+
+    // Test whether the data in the checkpointed RDD is the same as the original
+ assert(operatedRDD.collect() === result)
+
+    // Test whether the serialized size of the RDD has been reduced because its parent was
+    // checkpointed. If this RDD or its parent RDD does not have any dependency
+    // on another RDD (e.g., ParallelCollection, ShuffleRDD with ShuffleDependency), its size
+    // may not be reduced after checkpointing.
+ if (testRDDSize) {
+ assert(
+ rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
+        "Size of " + rddType + " did not reduce after checkpointing parent " + parentRDDType +
+ "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
+ )
}
- val op2 = (firstRDD: RDD[(Int, Int)]) => {
- firstRDD.groupByKey(2).map(x => { println("2nd map running on " + x); Thread.sleep(500); x })
+
+    // Test whether the serialized size of the splits has been reduced because the parent was
+    // checkpointed. If the splits do not have any non-transient reference to another RDD
+    // or another RDD's splits, they do not refer to a lineage and therefore may not shrink
+    // after checkpointing. However, if the splits do refer to the *splits* of a parent
+    // RDD, then these splits must update their reference to the parent RDD's splits, as those
+    // splits must have changed after checkpointing.
+ if (testRDDSplitSize) {
+ assert(
+ splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint,
+ "Size of " + rddType + " splits did not reduce after checkpointing parent " + parentRDDType +
+ "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]"
+ )
}
- testThreading(op1, op2)
+
}
- */
- def testCheckpointing[U: ClassManifest](op: (RDD[Int]) => RDD[U], sleepTime: Long = 500) {
- val parCollection = sc.makeRDD(1 to 4, 4)
- val operatedRDD = op(parCollection)
- operatedRDD.checkpoint()
- val parentRDD = operatedRDD.dependencies.head.rdd
- val result = operatedRDD.collect()
- sleep(operatedRDD)
- //println(parentRDD + ", " + operatedRDD.dependencies.head.rdd )
- assert(sc.objectFile[U](operatedRDD.checkpointFile).collect() === result)
- assert(operatedRDD.dependencies.head.rdd != parentRDD)
- assert(operatedRDD.collect() === result)
+ /**
+ * Generate an RDD with a long lineage of one-to-one dependencies.
+ */
+ def generateLongLineageRDD(): RDD[Int] = {
+ var rdd = sc.makeRDD(1 to 100, 4)
+ for (i <- 1 to 50) {
+ rdd = rdd.map(x => x + 1)
+ }
+ rdd
}
- /*
- def testThreading[U: ClassManifest, V: ClassManifest](op1: (RDD[Int]) => RDD[U], op2: (RDD[U]) => RDD[V]) {
-
- val parCollection = sc.makeRDD(1 to 2, 2)
-
- // This is the RDD that is to be checkpointed
- val firstRDD = op1(parCollection)
- val parentRDD = firstRDD.dependencies.head.rdd
- firstRDD.checkpoint()
-
- // This the RDD that uses firstRDD. This is designed to launch a
- // ShuffleMapTask that uses firstRDD.
- val secondRDD = op2(firstRDD)
-
- // Starting first job, to initiate the checkpointing
- logInfo("\nLaunching 1st job to initiate checkpointing\n")
- firstRDD.collect()
-
- // Checkpointing has started but not completed yet
- Thread.sleep(100)
- assert(firstRDD.dependencies.head.rdd === parentRDD)
-
- // Starting second job; first task of this job will be
- // launched _before_ firstRDD is marked as checkpointed
- // and the second task will be launched _after_ firstRDD
- // is marked as checkpointed
- logInfo("\nLaunching 2nd job that is designed to launch tasks " +
- "before and after checkpointing is complete\n")
- val result = secondRDD.collect()
-
- // Check whether firstRDD has been successfully checkpointed
- assert(firstRDD.dependencies.head.rdd != parentRDD)
-
- logInfo("\nRecomputing 2nd job to verify the results of the previous computation\n")
- // Check whether the result in the previous job was correct or not
- val correctResult = secondRDD.collect()
- assert(result === correctResult)
- }
- */
- def sleep(rdd: RDD[_]) {
- val startTime = System.currentTimeMillis()
- val maxWaitTime = 5000
- while(rdd.isCheckpointed == false && System.currentTimeMillis() < startTime + maxWaitTime) {
- Thread.sleep(50)
+
+ /**
+ * Generate an RDD with a long lineage specifically for CoGroupedRDD.
+   * A CoGroupedRDD can have a long lineage only if one of its parents has a long lineage
+   * and a narrow dependency with this RDD. This method generates such an RDD through a sequence
+   * of cogroups and mapValues, which creates a long lineage of narrow dependencies.
+ */
+ def generateLongLineageRDDForCoGroupedRDD() = {
+ val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _)
+
+ def ones: RDD[(Int, Int)] = sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _)
+
+ var cogrouped: RDD[(Int, (Seq[Int], Seq[Int]))] = ones.cogroup(ones)
+ for(i <- 1 to 10) {
+ cogrouped = cogrouped.mapValues(add).cogroup(ones)
}
- assert(rdd.isCheckpointed === true, "Waiting for checkpoint to complete took more than " + maxWaitTime + " ms")
+ cogrouped.mapValues(add)
+ }
+
+ /**
+ * Get serialized sizes of the RDD and its splits
+ */
+ def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
+ (Utils.serialize(rdd).size, Utils.serialize(rdd.splits).size)
+ }
+
+ /**
+   * Serialize and deserialize an object. This is useful to verify the object's
+ * contents after deserialization (e.g., the contents of an RDD split after
+ * it is sent to a slave along with a task)
+ */
+ def serializeDeserialize[T](obj: T): T = {
+ val bytes = Utils.serialize(obj)
+ Utils.deserialize[T](bytes)
}
}
+
+
+object CheckpointSuite {
+ // This is a custom cogroup function that does not use mapValues like
+ // the PairRDDFunctions.cogroup()
+ def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = {
+ //println("First = " + first + ", second = " + second)
+ new CoGroupedRDD[K](
+ Seq(first.asInstanceOf[RDD[(_, _)]], second.asInstanceOf[RDD[(_, _)]]),
+ part
+ ).asInstanceOf[RDD[(K, Seq[Seq[V]])]]
+ }
+
+}
diff --git a/core/src/test/scala/spark/ClosureCleanerSuite.scala b/core/src/test/scala/spark/ClosureCleanerSuite.scala
index 7c0334d957..dfa2de80e6 100644
--- a/core/src/test/scala/spark/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/spark/ClosureCleanerSuite.scala
@@ -47,6 +47,8 @@ object TestObject {
val nums = sc.parallelize(Array(1, 2, 3, 4))
val answer = nums.map(_ + x).reduce(_ + _)
sc.stop()
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.master.port")
return answer
}
}
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index 46a0b68f89..33d5fc2d89 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -131,6 +131,17 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void lookup() {
+ JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList(
+ new Tuple2<String, String>("Apples", "Fruit"),
+ new Tuple2<String, String>("Oranges", "Fruit"),
+ new Tuple2<String, String>("Oranges", "Citrus")
+ ));
+ Assert.assertEquals(2, categories.lookup("Oranges").size());
+ Assert.assertEquals(2, categories.groupByKey().lookup("Oranges").get(0).size());
+ }
+
+ @Test
public void groupBy() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13));
Function<Integer, Boolean> isOdd = new Function<Integer, Boolean>() {
diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala
index 3dadc7acec..f09b602a7b 100644
--- a/core/src/test/scala/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/spark/PartitioningSuite.scala
@@ -107,4 +107,25 @@ class PartitioningSuite extends FunSuite with BeforeAndAfter {
assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner)
}
+
+ test("partitioning Java arrays should fail") {
+ sc = new SparkContext("local", "test")
+ val arrs: RDD[Array[Int]] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => Array(x))
+ val arrPairs: RDD[(Array[Int], Int)] =
+ sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x))
+
+ assert(intercept[SparkException]{ arrs.distinct() }.getMessage.contains("array"))
+ // We can't catch all usages of arrays, since they might occur inside other collections:
+ //assert(fails { arrPairs.distinct() })
+ assert(intercept[SparkException]{ arrPairs.partitionBy(new HashPartitioner(2)) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.join(arrPairs) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.leftOuterJoin(arrPairs) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.rightOuterJoin(arrPairs) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.groupByKey() }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.countByKey() }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.countByKeyApprox(1) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.cogroup(arrPairs) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array"))
+ }
}
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 4614901d78..e5a59dc7d6 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -8,9 +8,9 @@ import spark.rdd.CoalescedRDD
import SparkContext._
class RDDSuite extends FunSuite with BeforeAndAfter {
-
+
var sc: SparkContext = _
-
+
after {
if (sc != null) {
sc.stop()
@@ -24,6 +24,10 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(nums.collect().toList === List(1, 2, 3, 4))
+ val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
+ assert(dups.distinct.count === 4)
+ assert(dups.distinct().collect === dups.distinct.collect)
+ assert(dups.distinct(2).collect === dups.distinct.collect)
assert(nums.reduce(_ + _) === 10)
assert(nums.fold(0)(_ + _) === 10)
assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4"))
@@ -97,6 +101,29 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
assert(rdd.collect().toList === List(1, 2, 3, 4))
}
+ test("caching with failures") {
+ sc = new SparkContext("local", "test")
+ val onlySplit = new Split { override def index: Int = 0 }
+ var shouldFail = true
+ val rdd = new RDD[Int](sc, Nil) {
+ override def getSplits: Array[Split] = Array(onlySplit)
+ override val getDependencies = List[Dependency[_]]()
+ override def compute(split: Split, context: TaskContext): Iterator[Int] = {
+ if (shouldFail) {
+ throw new Exception("injected failure")
+ } else {
+ return Array(1, 2, 3, 4).iterator
+ }
+ }
+ }.cache()
+ val thrown = intercept[Exception]{
+ rdd.collect()
+ }
+ assert(thrown.getMessage.contains("injected failure"))
+ shouldFail = false
+ assert(rdd.collect().toList === List(1, 2, 3, 4))
+ }
+
test("coalesced RDDs") {
sc = new SparkContext("local", "test")
val data = sc.parallelize(1 to 10, 10)
@@ -134,7 +161,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
val zipped = nums.zip(nums.map(_ + 1.0))
assert(zipped.glom().map(_.toList).collect().toList ===
List(List((1, 2.0), (2, 3.0)), List((3, 4.0), (4, 5.0))))
-
+
intercept[IllegalArgumentException] {
nums.zip(sc.parallelize(1 to 4, 1)).collect()
}
diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html
index d656b3e3de..a8be52f23e 100755
--- a/docs/_layouts/global.html
+++ b/docs/_layouts/global.html
@@ -47,11 +47,19 @@
<li><a href="quick-start.html">Quick Start</a></li>
<li><a href="scala-programming-guide.html">Scala</a></li>
<li><a href="java-programming-guide.html">Java</a></li>
- <li><a href="streaming-programming-guide.html">Spark Streaming (Alpha)</a></li>
+ <li><a href="streaming-programming-guide.html">Spark Streaming</a></li>
</ul>
</li>
- <li><a href="api/core/index.html">API (Scaladoc)</a></li>
+ <li class="dropdown">
+ <a href="#" class="dropdown-toggle" data-toggle="dropdown">API (Scaladoc)<b class="caret"></b></a>
+ <ul class="dropdown-menu">
+ <li><a href="api/core/index.html">Spark</a></li>
+ <li><a href="api/examples/index.html">Spark Examples</a></li>
+ <li><a href="api/streaming/index.html">Spark Streaming</a></li>
+ <li><a href="api/bagel/index.html">Bagel</a></li>
+ </ul>
+ </li>
<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown">Deploying<b class="caret"></b></a>
diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb
index e61c105449..7654511eeb 100644
--- a/docs/_plugins/copy_api_dirs.rb
+++ b/docs/_plugins/copy_api_dirs.rb
@@ -2,7 +2,7 @@ require 'fileutils'
include FileUtils
if ENV['SKIP_SCALADOC'] != '1'
- projects = ["core", "examples", "repl", "bagel"]
+ projects = ["core", "examples", "repl", "bagel", "streaming"]
puts "Moving to project root and building scaladoc."
curr_dir = pwd
@@ -11,7 +11,7 @@ if ENV['SKIP_SCALADOC'] != '1'
puts "Running sbt/sbt doc from " + pwd + "; this may take a few minutes..."
puts `sbt/sbt doc`
- puts "moving back into docs dir."
+ puts "Moving back into docs dir."
cd("docs")
# Copy over the scaladoc from each project into the docs directory.
diff --git a/docs/api.md b/docs/api.md
index 43548b223c..3a13e30f82 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -5,6 +5,7 @@ title: Spark API documentation (Scaladoc)
Here you can find links to the Scaladoc generated for the Spark sbt subprojects. If the following links don't work, try running `sbt/sbt doc` from the Spark project home directory.
-- [Core](api/core/index.html)
-- [Examples](api/examples/index.html)
+- [Spark](api/core/index.html)
+- [Spark Examples](api/examples/index.html)
+- [Spark Streaming](api/streaming/index.html)
- [Bagel](api/bagel/index.html)
diff --git a/docs/configuration.md b/docs/configuration.md
index d8317ea97c..87cb4a6797 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -219,6 +219,17 @@ Apart from these, the following properties are also available, and may be useful
Port for the master to listen on.
</td>
</tr>
+<tr>
+ <td>spark.cleaner.delay</td>
+  <td>(disabled)</td>
+ <td>
+    Duration (in minutes) for which Spark will remember any metadata (stages generated, tasks generated, etc.).
+    Periodic cleanups will ensure that metadata older than this duration is forgotten. This is
+    useful for running Spark for many hours or days (for example, when running Spark Streaming
+    applications 24/7). Note that any RDD that persists in memory for more than this duration will be cleared as well.
+ </td>
+</tr>
+
</table>
# Configuring Logging
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 90916545bc..05a88ce7bd 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -1,31 +1,48 @@
---
layout: global
-title: Streaming (Alpha) Programming Guide
+title: Spark Streaming Programming Guide
---
+
+* This will become a table of contents (this text will be scraped).
+{:toc}
+
+# Overview
+A Spark Streaming application is very similar to a Spark application; it consists of a *driver program* that runs the user's `main` function and continuously executes various *parallel operations* on input streams of data. The main abstraction Spark Streaming provides is a *discretized stream* (DStream), which is a continuous sequence of RDDs (distributed collections of elements) representing a continuous stream of data. DStreams can be created from live incoming data (such as data from a socket, Kafka, etc.) or they can be generated by transforming existing DStreams using parallel operators like map, reduce, and window. The basic processing model is as follows:
+(i) While a Spark Streaming driver program is running, the system receives data from various sources and divides the data into batches. Each batch of data is treated as an RDD, that is, an immutable, parallel collection of data. These input data RDDs are automatically persisted in memory (serialized by default) and replicated to two nodes for fault-tolerance. This sequence of RDDs is collectively referred to as an InputDStream.
+(ii) Data received by InputDStreams are processed using DStream operations. Since all data is represented as RDDs and all DStream operations as RDD operations, data is automatically recovered in the event of node failures.
+
+This guide shows how to start programming with DStreams.
+
# Initializing Spark Streaming
-The first thing a Spark Streaming program must do is create a `StreamingContext` object, which tells Spark how to access a cluster. A `StreamingContext` can be created from an existing `SparkContext`, or directly:
+The first thing a Spark Streaming program must do is create a `StreamingContext` object, which tells Spark how to access a cluster. A `StreamingContext` can be created by using
{% highlight scala %}
-new StreamingContext(master, jobName, [sparkHome], [jars])
-new StreamingContext(sparkContext)
+new StreamingContext(master, jobName, batchDuration)
{% endhighlight %}
-Once a context is instantiated, the batch interval must be set:
+The `master` parameter is either the [Mesos master URL](running-on-mesos.html) (for running on a cluster) or the special "local" string (for local mode) that is used to create a SparkContext. For more information about this, please refer to the [Spark programming guide](scala-programming-guide.html). The `jobName` is the name of the streaming job, which is the same as the jobName used in SparkContext. It is used to identify this job in the Mesos web UI. The `batchDuration` is the size of the batches (as explained earlier). This must be set carefully such that the cluster can keep up with the processing of the data streams. Starting with something conservative like 5 seconds is usually a good idea. See the [Performance Tuning](#setting-the-right-batch-size) section for a detailed discussion.
+This constructor creates a SparkContext object using the given `master` and `jobName` parameters. However, if you already have a SparkContext or you need to create a custom SparkContext by specifying a list of JARs, then a StreamingContext can be created from the existing SparkContext by using
{% highlight scala %}
-context.setBatchDuration(Milliseconds(2000))
+new StreamingContext(sparkContext, batchDuration)
{% endhighlight %}
-# DStreams - Discretized Streams
-The primary abstraction in Spark Streaming is a DStream. A DStream represents distributed collection which is computed periodically according to a specified batch interval. DStream's can be chained together to create complex chains of transformation on streaming data. DStreams can be created by operating on existing DStreams or from an input source. To creating DStreams from an input source, use the StreamingContext:
+
+# Attaching Input Sources - InputDStreams
+The StreamingContext is used to create InputDStreams from input sources:
{% highlight scala %}
-context.neworkStream(host, port) // A stream that reads from a socket
-context.flumeStream(hosts, ports) // A stream populated by a Flume flow
+// Assuming ssc is the StreamingContext
+ssc.networkStream(hostname, port) // Creates a stream that uses a TCP socket to read data from hostname:port
+ssc.textFileStream(directory) // Creates a stream by monitoring and processing new files in a HDFS directory
{% endhighlight %}
-# DStream Operators
+A complete list of input sources is available in the [StreamingContext API documentation](api/streaming/index.html#spark.streaming.StreamingContext). Data received from these sources can be processed using DStream operations, which are explained next.
+
+
+
+# DStream Operations
Once an input stream has been created, you can transform it using _stream operators_. Most of these operators return new DStreams which you can further transform. Eventually, you'll need to call an _output operator_, which forces evaluation of the stream by writing data out to an external source.
## Transformations
@@ -73,20 +90,13 @@ DStreams support many of the transformations available on normal Spark RDD's:
<td> <b>cogroup</b>(<i>otherStream</i>, [<i>numTasks</i>]) </td>
<td> When called on streams of type (K, V) and (K, W), returns a stream of (K, Seq[V], Seq[W]) tuples. This operation is also called <code>groupWith</code>. </td>
</tr>
-</table>
-
-DStreams also support the following additional transformations:
-
-<table class="table">
<tr>
<td> <b>reduce</b>(<i>func</i>) </td>
<td> Create a new single-element stream by aggregating the elements of the stream using a function func (which takes two arguments and returns one). The function should be associative so that it can be computed correctly in parallel. </td>
</tr>
</table>
-
-## Windowed Transformations
-Spark streaming features windowed computations, which allow you to report statistics over a sliding window of data. All window functions take a <i>windowTime</i>, which represents the width of the window and a <i>slideTime</i>, which represents the frequency during which the window is calculated.
+Spark Streaming features windowed computations, which allow you to report statistics over a sliding window of data. All window functions take a <i>windowTime</i>, which represents the width of the window, and a <i>slideTime</i>, which represents the frequency at which the window is computed.
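+
+For example, here is a minimal sketch of counting words over a sliding window. It assumes a DStream of words named `words` and the `reduceByKeyAndWindow(reduceFunc, windowTime, slideTime)` operation described below; the durations are purely illustrative.
+
+{% highlight scala %}
+// Count each word over the last 30 seconds of data, recomputed every 10 seconds
+val windowedWordCounts = words.map(x => (x, 1)).reduceByKeyAndWindow(_ + _, Seconds(30), Seconds(10))
+{% endhighlight %}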
<table class="table">
<tr><th style="width:25%">Transformation</th><th>Meaning</th></tr>
@@ -128,7 +138,7 @@ Spark streaming features windowed computations, which allow you to report statis
</table>
-## Output Operators
+## Output Operations
When an output operator is called, it triggers the computation of a stream. Currently the following output operators are defined:
<table class="table">
@@ -140,24 +150,159 @@ When an output operator is called, it triggers the computation of a stream. Curr
<tr>
<td> <b>print</b>() </td>
- <td> Prints the contents of this DStream on the driver. At each interval, this will take at most ten elements from the DStream's RDD and print them. </td>
+  <td> Prints the first ten elements of every batch of data in a DStream on the driver. </td>
</tr>
<tr>
- <td> <b>saveAsObjectFile</b>(<i>prefix</i>, [<i>suffix</i>]) </td>
+ <td> <b>saveAsObjectFiles</b>(<i>prefix</i>, [<i>suffix</i>]) </td>
<td> Save this DStream's contents as a <code>SequenceFile</code> of serialized objects. The file name at each batch interval is calculated based on <i>prefix</i> and <i>suffix</i>: <i>"prefix-TIME_IN_MS[.suffix]"</i>.
</td>
</tr>
<tr>
- <td> <b>saveAsTextFile</b>(<i>prefix</i>, <i>suffix</i>) </td>
+ <td> <b>saveAsTextFiles</b>(<i>prefix</i>, [<i>suffix</i>]) </td>
<td> Save this DStream's contents as a text files. The file name at each batch interval is calculated based on <i>prefix</i> and <i>suffix</i>: <i>"prefix-TIME_IN_MS[.suffix]"</i>. </td>
</tr>
<tr>
- <td> <b>saveAsHadoopFiles</b>(<i>prefix</i>, <i>suffix</i>) </td>
+ <td> <b>saveAsHadoopFiles</b>(<i>prefix</i>, [<i>suffix</i>]) </td>
<td> Save this DStream's contents as a Hadoop file. The file name at each batch interval is calculated based on <i>prefix</i> and <i>suffix</i>: <i>"prefix-TIME_IN_MS[.suffix]"</i>. </td>
</tr>
</table>
+## DStream Persistence
+Similar to RDDs, DStreams also allow developers to persist the stream's data in memory. That is, using the `persist()` method on a DStream will automatically persist every RDD of that DStream in memory. This is useful if the data in the DStream will be computed multiple times (e.g., multiple DStream operations on the same data). For window-based operations like `reduceByWindow` and `reduceByKeyAndWindow` and state-based operations like `updateStateByKey`, this is implicitly true. Hence, DStreams generated by window-based operations are automatically persisted in memory, without the developer calling `persist()`.
+
+Note that, unlike RDDs, the default persistence level of DStreams keeps the data serialized in memory. This is further discussed in the [Performance Tuning](#memory-tuning) section. More information on different persistence levels can be found in the [Spark Programming Guide](scala-programming-guide.html#rdd-persistence).
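+
+As an illustration, here is a minimal sketch (the hostname and port are placeholders) in which the same input DStream feeds two different computations and is therefore persisted explicitly:
+
+{% highlight scala %}
+val lines = ssc.networkTextStream("localhost", 9999)
+lines.persist()   // both computations below reuse the persisted RDDs of `lines`
+
+val wordCounts = lines.flatMap(_.split(" ")).map(x => (x, 1)).reduceByKey(_ + _)
+val charCounts = lines.map(_.length).reduce(_ + _)
+{% endhighlight %}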
+
+# Starting the Streaming computation
+All the above DStream operations are completely lazy, that is, the operations will start executing only after the context is started by using
+{% highlight scala %}
+ssc.start()
+{% endhighlight %}
+
+Conversely, the computation can be stopped by using
+{% highlight scala %}
+ssc.stop()
+{% endhighlight %}
+
+# Example - NetworkWordCount.scala
+A good example to start with is spark.streaming.examples.NetworkWordCount. This example counts the words received from a network server every second. Given below are the relevant sections of the source code. You can find the full source code in <Spark repo>/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala.
+
+{% highlight scala %}
+import spark.streaming.{Seconds, StreamingContext}
+import spark.streaming.StreamingContext._
+...
+
+// Create the context and set up a network input stream to receive from a host:port
+val ssc = new StreamingContext(args(0), "NetworkWordCount", Seconds(1))
+val lines = ssc.networkTextStream(args(1), args(2).toInt)
+
+// Split the lines into words, count them, and print some of the counts on the master
+val words = lines.flatMap(_.split(" "))
+val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
+wordCounts.print()
+
+// Start the computation
+ssc.start()
+{% endhighlight %}
+
+To run this example on your local machine, you need to first run a Netcat server by using
+
+{% highlight bash %}
+$ nc -lk 9999
+{% endhighlight %}
+
+Then, in a different terminal, you can start NetworkWordCount by using
+
+{% highlight bash %}
+$ ./run spark.streaming.examples.NetworkWordCount local[2] localhost 9999
+{% endhighlight %}
+
+This will make NetworkWordCount connect to the netcat server. Any lines typed in the terminal running the netcat server will be counted and printed on screen.
+
+<table>
+<td>
+{% highlight bash %}
+# TERMINAL 1
+# RUNNING NETCAT
+
+$ nc -lk 9999
+hello world
+
+
+
+
+
+...
+{% endhighlight %}
+</td>
+<td>
+{% highlight bash %}
+# TERMINAL 2: RUNNING NetworkWordCount
+...
+2012-12-31 18:47:10,446 INFO SparkContext: Job finished: run at ThreadPoolExecutor.java:886, took 0.038817 s
+-------------------------------------------
+Time: 1357008430000 ms
+-------------------------------------------
+(hello,1)
+(world,1)
+
+2012-12-31 18:47:10,447 INFO JobManager: Total delay: 0.44700 s for job 8 (execution: 0.44000 s)
+...
+{% endhighlight %}
+</td>
+</table>
+
+
+
+# Performance Tuning
+Getting the best performance of a Spark Streaming application on a cluster requires a bit of tuning. This section explains a number of the parameters and configurations that can be tuned to improve the performance of your application. At a high level, you need to consider two things:
+<ol>
+<li>Reducing the processing time of each batch of data by efficiently using cluster resources.</li>
+<li>Setting the right batch size such that the data processing can keep up with the data ingestion.</li>
+</ol>
+
+## Reducing the Processing Time of each Batch
+There are a number of optimizations that can be done in Spark to minimize the processing time of each batch. These have been discussed in detail in the [Tuning Guide](tuning.html). This section highlights some of the most important ones.
+
+### Level of Parallelism
+Cluster resources may be underutilized if the number of parallel tasks used in any stage of the computation is not high enough. For example, for distributed reduce operations like `reduceByKey` and `reduceByKeyAndWindow`, the default number of parallel tasks is 8. You can pass the level of parallelism as an argument (see the [`spark.PairDStreamFunctions`](api/streaming/index.html#spark.PairDStreamFunctions) documentation), or set the system property `spark.default.parallelism` to change the default.
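+
+As a sketch (the numbers are illustrative, and the numTasks argument is the one described in the `spark.PairDStreamFunctions` documentation linked above):
+
+{% highlight scala %}
+// Use 16 reduce tasks for this shuffle instead of the default 8
+val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _, 16)
+
+// Or change the default for all operations (set before creating the StreamingContext)
+System.setProperty("spark.default.parallelism", "16")
+{% endhighlight %}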
+
+### Data Serialization
+The overhead of data serialization can be significant, especially when sub-second batch sizes are to be achieved. There are two aspects to it (a configuration sketch follows this list):
+* Serialization of RDD data in Spark: Please refer to the detailed discussion on data serialization in the [Tuning Guide](tuning.html). However, note that unlike in Spark, RDDs generated by Spark Streaming are persisted as serialized byte arrays by default to minimize pauses related to GC.
+* Serialization of input data: To ingest external data into Spark, data received as bytes (say, from the network) needs to be deserialized and then re-serialized into Spark's serialization format. Hence, the deserialization overhead of input data may be a bottleneck.
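+
+For example, a minimal sketch of switching Spark's serialization to Kryo for a streaming application (the job name, batch size, and registrator class are illustrative):
+
+{% highlight scala %}
+// Set these before creating the StreamingContext
+System.setProperty("spark.serializer", "spark.KryoSerializer")
+System.setProperty("spark.kryo.registrator", "mypackage.MyRegistrator")
+
+val ssc = new StreamingContext("local[2]", "KryoExample", Seconds(1))
+{% endhighlight %}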
+
+### Task Launching Overheads
+If the number of tasks launched per second is high (say, 50 or more per second), then the overhead of sending out tasks to the slaves may be significant and will make it hard to achieve sub-second latencies. The overhead can be reduced by the following changes:
+* Task Serialization: Using Kryo serialization for serializing tasks can reduce the task sizes, and therefore reduce the time taken to send them to the slaves.
+* Execution mode: Running Spark in Standalone mode or coarse-grained Mesos mode leads to better task launch times than the fine-grained Mesos mode. Please refer to the [Running on Mesos guide](running-on-mesos.html) for more details.
+These changes may reduce batch processing time by hundreds of milliseconds, thus making sub-second batch sizes viable.
+
+## Setting the Right Batch Size
+For a Spark Streaming application running on a cluster to be stable, the processing of the data streams must keep up with the rate of ingestion of the data streams. Depending on the type of computation, the batch size used may have a significant impact on the rate of ingestion that can be sustained by the Spark Streaming application on fixed cluster resources. For example, let us consider the earlier NetworkWordCount example. For a particular data rate, the system may be able to keep up with reporting word counts every 2 seconds (i.e., a batch size of 2 seconds), but not every 500 milliseconds.
+
+A good approach to figuring out the right batch size for your application is to test it with a conservative batch size (say, 5-10 seconds) and a low data rate. To verify whether the system is able to keep up with the data rate, you can check the value of the end-to-end delay experienced by each processed batch (in the Spark master logs, find the line having the phrase "Total delay"). If the delay is maintained at less than the batch size, then the system is stable. Otherwise, if the delay is continuously increasing, it means that the system is unable to keep up and is therefore unstable. Once you have an idea of a stable configuration, you can try increasing the data rate and/or reducing the batch size. Note that a momentary increase in the delay due to temporary data rate increases may be fine as long as the delay reduces back to a low value (i.e., less than the batch size).
+
+## 24/7 Operation
+By default, Spark does not forget any of the metadata (RDDs generated, stages processed, etc.). But for a Spark Streaming application to operate 24/7, it is necessary for Spark to do periodic cleanup of its metadata. This can be enabled by setting the Java system property `spark.cleaner.delay` to the number of minutes you want any metadata to persist. For example, setting `spark.cleaner.delay` to 10 would cause Spark to periodically clean up all metadata and persisted RDDs that are older than 10 minutes. Note that this property needs to be set before the SparkContext is created.
+
+This value is closely tied to any window operation that is being used. Any window operation requires the input data to be persisted in memory for at least the duration of the window. Hence, it is necessary to set this delay to at least the value of the largest window operation used in the Spark Streaming application. If the delay is set too low, the application will throw an exception saying so.
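+
+For example, here is a minimal sketch (the job name, batch size, and delay value are illustrative; the delay must be at least as large as the largest window used by the application):
+
+{% highlight scala %}
+// Remember metadata and persisted RDDs for 30 minutes; set this before creating the context
+System.setProperty("spark.cleaner.delay", "30")
+
+val ssc = new StreamingContext("local[2]", "TwentyFourBySeven", Seconds(10))
+{% endhighlight %}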
+
+## Memory Tuning
+Tuning the memory usage and GC behavior of Spark applications has been discussed in great detail in the [Tuning Guide](tuning.html). It is recommended that you read it. In this section, we highlight a few customizations that are strongly recommended to minimize GC-related pauses in Spark Streaming applications and achieve more consistent batch processing times.
+
+* <b>Default persistence level of DStreams</b>: Unlike RDDs, the default persistence level of DStreams serializes the data in memory (that is, [StorageLevel.MEMORY_ONLY_SER](api/core/index.html#spark.storage.StorageLevel$) for DStreams compared to [StorageLevel.MEMORY_ONLY](api/core/index.html#spark.storage.StorageLevel$) for RDDs). Even though keeping the data serialized incurs higher serialization overhead, it significantly reduces GC pauses.
+
+* <b>Concurrent garbage collector</b>: Using the concurrent mark-and-sweep GC further minimizes the variability of GC pauses. Even though concurrent GC is known to reduce the overall processing throughput of the system, its use is still recommended to achieve more consistent batch processing times.
+
+# Master Fault-tolerance (Alpha)
+TODO
+
+* Checkpointing of DStream graph
+
+* Recovery from master faults
+
+* Current state and future directions \ No newline at end of file
diff --git a/docs/tuning.md b/docs/tuning.md
index f18de8ff3a..9aaa53cd65 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -33,7 +33,7 @@ in your operations) and performance. It provides two serialization libraries:
Java serialization is flexible but often quite slow, and leads to large
serialized formats for many classes.
* [Kryo serialization](http://code.google.com/p/kryo/wiki/V1Documentation): Spark can also use
- the Kryo library (currently just version 1) to serialize objects more quickly. Kryo is significantly
+ the Kryo library (version 2) to serialize objects more quickly. Kryo is significantly
faster and more compact than Java serialization (often as much as 10x), but does not support all
`Serializable` types and requires you to *register* the classes you'll use in the program in advance
for best performance.
@@ -47,6 +47,8 @@ Finally, to register your classes with Kryo, create a public class that extends
`spark.kryo.registrator` system property to point to it, as follows:
{% highlight scala %}
+import com.esotericsoftware.kryo.Kryo
+
class MyRegistrator extends KryoRegistrator {
override def registerClasses(kryo: Kryo) {
kryo.register(classOf[MyClass1])
@@ -60,7 +62,7 @@ System.setProperty("spark.kryo.registrator", "mypackage.MyRegistrator")
val sc = new SparkContext(...)
{% endhighlight %}
-The [Kryo documentation](http://code.google.com/p/kryo/wiki/V1Documentation) describes more advanced
+The [Kryo documentation](http://code.google.com/p/kryo/) describes more advanced
registration options, such as adding custom serialization code.
If your objects are large, you may also need to increase the `spark.kryoserializer.buffer.mb`
@@ -147,7 +149,7 @@ the space allocated to the RDD cache to mitigate this.
**Measuring the Impact of GC**
-The first step in GC tuning is to collect statistics on how frequently garbage collection occurs and the amount of
+The first step in GC tuning is to collect statistics on how frequently garbage collection occurs and the amount of
time spent GC. This can be done by adding `-verbose:gc -XX:+PrintGCDetails -XX:+PrintGCTimeStamps` to your
`SPARK_JAVA_OPTS` environment variable. Next time your Spark job is run, you will see messages printed in the worker's logs
each time a garbage collection occurs. Note these logs will be on your cluster's worker nodes (in the `stdout` files in
@@ -155,15 +157,15 @@ their work directories), *not* on your driver program.
**Cache Size Tuning**
-One important configuration parameter for GC is the amount of memory that should be used for
-caching RDDs. By default, Spark uses 66% of the configured memory (`SPARK_MEM`) to cache RDDs. This means that
+One important configuration parameter for GC is the amount of memory that should be used for
+caching RDDs. By default, Spark uses 66% of the configured memory (`SPARK_MEM`) to cache RDDs. This means that
33% of memory is available for any objects created during task execution.
In case your tasks slow down and you find that your JVM is garbage-collecting frequently or running out of
-memory, lowering this value will help reduce the memory consumption. To change this to say 50%, you can call
-`System.setProperty("spark.storage.memoryFraction", "0.5")`. Combined with the use of serialized caching,
-using a smaller cache should be sufficient to mitigate most of the garbage collection problems.
-In case you are interested in further tuning the Java GC, continue reading below.
+memory, lowering this value will help reduce the memory consumption. To change this to say 50%, you can call
+`System.setProperty("spark.storage.memoryFraction", "0.5")`. Combined with the use of serialized caching,
+using a smaller cache should be sufficient to mitigate most of the garbage collection problems.
+In case you are interested in further tuning the Java GC, continue reading below.
**Advanced GC Tuning**
@@ -172,9 +174,9 @@ To further tune garbage collection, we first need to understand some basic infor
* Java Heap space is divided in to two regions Young and Old. The Young generation is meant to hold short-lived objects
while the Old generation is intended for objects with longer lifetimes.
-* The Young generation is further divided into three regions [Eden, Survivor1, Survivor2].
+* The Young generation is further divided into three regions [Eden, Survivor1, Survivor2].
-* A simplified description of the garbage collection procedure: When Eden is full, a minor GC is run on Eden and objects
+* A simplified description of the garbage collection procedure: When Eden is full, a minor GC is run on Eden and objects
that are alive from Eden and Survivor1 are copied to Survivor2. The Survivor regions are swapped. If an object is old
enough or Survivor2 is full, it is moved to Old. Finally when Old is close to full, a full GC is invoked.
@@ -186,7 +188,7 @@ temporary objects created during task execution. Some steps which may be useful
before a task completes, it means that there isn't enough memory available for executing tasks.
* In the GC stats that are printed, if the OldGen is close to being full, reduce the amount of memory used for caching.
- This can be done using the `spark.storage.memoryFraction` property. It is better to cache fewer objects than to slow
+ This can be done using the `spark.storage.memoryFraction` property. It is better to cache fewer objects than to slow
down task execution!
* If there are too many minor collections but not many major GCs, allocating more memory for Eden would help. You
@@ -195,8 +197,8 @@ temporary objects created during task execution. Some steps which may be useful
up by 4/3 is to account for space used by survivor regions as well.)
* As an example, if your task is reading data from HDFS, the amount of memory used by the task can be estimated using
- the size of the data block read from HDFS. Note that the size of a decompressed block is often 2 or 3 times the
- size of the block. So if we wish to have 3 or 4 tasks worth of working space, and the HDFS block size is 64 MB,
+ the size of the data block read from HDFS. Note that the size of a decompressed block is often 2 or 3 times the
+ size of the block. So if we wish to have 3 or 4 tasks worth of working space, and the HDFS block size is 64 MB,
we can estimate size of Eden to be `4*3*64MB`.
* Monitor how the frequency and time taken by garbage collection changes with the new settings.
diff --git a/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala b/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala
index e60ce483a3..461929fba2 100644
--- a/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala
+++ b/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala
@@ -5,7 +5,7 @@ import spark.storage.StorageLevel
import spark.streaming._
/**
- * Produce a streaming count of events received from Flume.
+ * Produces a count of events received from Flume.
*
* This should be used in conjunction with an AvroSink in Flume. It will start
* an Avro server on at the request host:port address and listen for requests.
diff --git a/examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala b/examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala
new file mode 100644
index 0000000000..8530f5c175
--- /dev/null
+++ b/examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala
@@ -0,0 +1,36 @@
+package spark.streaming.examples
+
+import spark.streaming.{Seconds, StreamingContext}
+import spark.streaming.StreamingContext._
+
+
+/**
+ * Counts words in new text files created in the given directory
+ * Usage: HdfsWordCount <master> <directory>
+ * <master> is the Spark master URL.
+ * <directory> is the directory that Spark Streaming will use to find and read new text files.
+ *
+ * To run this on your local machine on directory `localdir`, run this example
+ * `$ ./run spark.streaming.examples.HdfsWordCount local[2] localdir`
+ * Then create a text file in `localdir` and the words in the file will get counted.
+ */
+object HdfsWordCount {
+ def main(args: Array[String]) {
+ if (args.length < 2) {
+ System.err.println("Usage: HdfsWordCount <master> <directory>")
+ System.exit(1)
+ }
+
+ // Create the context
+ val ssc = new StreamingContext(args(0), "HdfsWordCount", Seconds(2))
+
+ // Create the FileInputDStream on the directory and use the
+ // stream to count words in new files created
+ val lines = ssc.textFileStream(args(1))
+ val words = lines.flatMap(_.split(" "))
+ val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
+ wordCounts.print()
+ ssc.start()
+ }
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
index fe55db6e2c..fe55db6e2c 100644
--- a/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
+++ b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala b/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala
index eadda60563..43c01d5db2 100644
--- a/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala
+++ b/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala
@@ -3,16 +3,27 @@ package spark.streaming.examples
import spark.streaming.{Seconds, StreamingContext}
import spark.streaming.StreamingContext._
-object WordCountNetwork {
+/**
+ * Counts words in UTF8 encoded, '\n' delimited text received from the network every second.
+ * Usage: NetworkWordCount <master> <hostname> <port>
+ * <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1.
+ * <hostname> and <port> describe the TCP server that Spark Streaming would connect to in order to receive data.
+ *
+ * To run this on your local machine, you need to first run a Netcat server
+ * `$ nc -lk 9999`
+ * and then run the example
+ * `$ ./run spark.streaming.examples.NetworkWordCount local[2] localhost 9999`
+ */
+object NetworkWordCount {
def main(args: Array[String]) {
if (args.length < 2) {
- System.err.println("Usage: WordCountNetwork <master> <hostname> <port>\n" +
+ System.err.println("Usage: NetworkWordCount <master> <hostname> <port>\n" +
"In local mode, <master> should be 'local[n]' with n > 1")
System.exit(1)
}
// Create the context and set the batch size
- val ssc = new StreamingContext(args(0), "WordCountNetwork", Seconds(1))
+ val ssc = new StreamingContext(args(0), "NetworkWordCount", Seconds(1))
// Create a NetworkInputDStream on target ip:port and count the
// words in input stream of \n delimited test (eg. generated by 'nc')
diff --git a/streaming/src/main/scala/spark/streaming/examples/QueueStream.scala b/examples/src/main/scala/spark/streaming/examples/QueueStream.scala
index 2a265d021d..2a265d021d 100644
--- a/streaming/src/main/scala/spark/streaming/examples/QueueStream.scala
+++ b/examples/src/main/scala/spark/streaming/examples/QueueStream.scala
diff --git a/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala b/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala
new file mode 100644
index 0000000000..2eec777c54
--- /dev/null
+++ b/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala
@@ -0,0 +1,46 @@
+package spark.streaming.examples
+
+import spark.util.IntParam
+import spark.storage.StorageLevel
+
+import spark.streaming._
+import spark.streaming.util.RawTextHelper
+
+/**
+ * Receives text from multiple rawNetworkStreams and counts how many '\n' delimited
+ * lines have the word 'the' in them. This is useful for benchmarking purposes. This
+ * will only work with spark.streaming.util.RawTextSender running on all worker nodes
+ * and with Spark using Kryo serialization (set Java property "spark.serializer" to
+ * "spark.KryoSerializer").
+ * Usage: RawNetworkGrep <master> <numStreams> <host> <port> <batchMillis>
+ * <master> is the Spark master URL
+ * <numStreams> is the number of rawNetworkStreams, which should be the same as the number
+ * of worker nodes in the cluster
+ * <host> is "localhost".
+ * <port> is the port on which RawTextSender is running on the worker nodes.
+ * <batchMillis> is the Spark Streaming batch duration in milliseconds.
+ */
+
+object RawNetworkGrep {
+ def main(args: Array[String]) {
+ if (args.length != 5) {
+ System.err.println("Usage: RawNetworkGrep <master> <numStreams> <host> <port> <batchMillis>")
+ System.exit(1)
+ }
+
+ val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args
+
+ // Create the context
+ val ssc = new StreamingContext(master, "RawNetworkGrep", Milliseconds(batchMillis))
+
+ // Warm up the JVMs on master and slave for JIT compilation to kick in
+ RawTextHelper.warmUp(ssc.sc)
+
+ val rawStreams = (1 to numStreams).map(_ =>
+ ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray
+ val union = ssc.union(rawStreams)
+ union.filter(_.contains("the")).count().foreach(r =>
+ println("Grep count: " + r.collect().mkString))
+ ssc.start()
+ }
+}
diff --git a/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala b/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala
index 4c6e08bc74..4c6e08bc74 100644
--- a/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala
+++ b/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala
diff --git a/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala b/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala
index 68be6b7893..a191321d91 100644
--- a/streaming/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala
+++ b/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala
@@ -72,7 +72,7 @@ object PageViewStream {
case "popularUsersSeen" =>
// Look for users in our existing dataset and print it out if we have a match
pageViews.map(view => (view.userID, 1))
- .foreachRDD((rdd, time) => rdd.join(userList)
+ .foreach((rdd, time) => rdd.join(userList)
.map(_._2._2)
.take(10)
.foreach(u => println("Saw user %s at time %s".format(u, time))))
diff --git a/pom.xml b/pom.xml
index 52a4e9d932..b33cee26b8 100644
--- a/pom.xml
+++ b/pom.xml
@@ -185,7 +185,7 @@
<dependency>
<groupId>de.javakaffee</groupId>
<artifactId>kryo-serializers</artifactId>
- <version>0.9</version>
+ <version>0.20</version>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index af32a63df8..d5cda347a4 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -23,7 +23,7 @@ object SparkBuild extends Build {
lazy val repl = Project("repl", file("repl"), settings = replSettings) dependsOn (core)
- lazy val examples = Project("examples", file("examples"), settings = examplesSettings) dependsOn (core)
+ lazy val examples = Project("examples", file("examples"), settings = examplesSettings) dependsOn (core) dependsOn (streaming)
lazy val bagel = Project("bagel", file("bagel"), settings = bagelSettings) dependsOn (core)
@@ -129,7 +129,7 @@ object SparkBuild extends Build {
"org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION,
"asm" % "asm-all" % "3.3.1",
"com.google.protobuf" % "protobuf-java" % "2.4.1",
- "de.javakaffee" % "kryo-serializers" % "0.9",
+ "de.javakaffee" % "kryo-serializers" % "0.20",
"com.typesafe.akka" % "akka-actor" % "2.0.3",
"com.typesafe.akka" % "akka-remote" % "2.0.3",
"com.typesafe.akka" % "akka-slf4j" % "2.0.3",
diff --git a/repl/src/test/resources/log4j.properties b/repl/src/test/resources/log4j.properties
index 4c99e450bc..cfb1a390e6 100644
--- a/repl/src/test/resources/log4j.properties
+++ b/repl/src/test/resources/log4j.properties
@@ -1,8 +1,8 @@
-# Set everything to be logged to the console
+# Set everything to be logged to the repl/target/unit-tests.log
log4j.rootCategory=INFO, file
log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
-log4j.appender.file.file=spark-tests.log
+log4j.appender.file.file=repl/target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
diff --git a/run b/run
index 27506c33e2..2f61cb2a87 100755
--- a/run
+++ b/run
@@ -13,10 +13,6 @@ if [ -e $FWDIR/conf/spark-env.sh ] ; then
. $FWDIR/conf/spark-env.sh
fi
-if [ -e $FWDIR/conf/streaming-env.sh ] ; then
- . $FWDIR/conf/streaming-env.sh
-fi
-
if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then
if [ `command -v scala` ]; then
RUNNER="scala"
diff --git a/sentences.txt b/sentences.txt
deleted file mode 100644
index fedf96c66e..0000000000
--- a/sentences.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-Hello world!
-What's up?
-There is no cow level
diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
index 770f7b0cc0..11a7232d7b 100644
--- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
@@ -8,6 +8,7 @@ import org.apache.hadoop.conf.Configuration
import java.io._
+private[streaming]
class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
extends Logging with Serializable {
val master = ssc.sc.master
@@ -30,6 +31,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
/**
* Convenience class to speed up the writing of graph checkpoint to file
*/
+private[streaming]
class CheckpointWriter(checkpointDir: String) extends Logging {
val file = new Path(checkpointDir, "graph")
val conf = new Configuration()
@@ -65,7 +67,7 @@ class CheckpointWriter(checkpointDir: String) extends Logging {
}
-
+private[streaming]
object CheckpointReader extends Logging {
def read(path: String): Checkpoint = {
@@ -103,6 +105,7 @@ object CheckpointReader extends Logging {
}
}
+private[streaming]
class ObjectInputStreamWithLoader(inputStream_ : InputStream, loader: ClassLoader) extends ObjectInputStream(inputStream_) {
override def resolveClass(desc: ObjectStreamClass): Class[_] = {
try {
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala
index 792c129be8..beba9cfd4f 100644
--- a/streaming/src/main/scala/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/spark/streaming/DStream.scala
@@ -1,17 +1,15 @@
package spark.streaming
+import spark.streaming.dstream._
import StreamingContext._
import Time._
-import spark._
-import spark.SparkContext._
-import spark.rdd._
+import spark.{RDD, Logging}
import spark.storage.StorageLevel
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
-import java.util.concurrent.ArrayBlockingQueue
import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
import org.apache.hadoop.fs.Path
@@ -21,7 +19,7 @@ import org.apache.hadoop.conf.Configuration
* A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous
* sequence of RDDs (of the same type) representing a continuous stream of data (see [[spark.RDD]]
* for more details on RDDs). DStreams can either be created from live data (such as, data from
- * HDFS. Kafka or Flume) or it can be generated by transformation existing DStreams using operations
+ * HDFS, Kafka or Flume) or it can be generated by transforming existing DStreams using operations
* such as `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming program is running, each
* DStream periodically generates a RDD, either from live data or by transforming the RDD generated
* by a parent DStream.
@@ -38,33 +36,28 @@ import org.apache.hadoop.conf.Configuration
* - A function that is used to generate an RDD after each time interval
*/
-case class DStreamCheckpointData(rdds: HashMap[Time, Any])
-
-abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext)
-extends Serializable with Logging {
+abstract class DStream[T: ClassManifest] (
+ @transient protected[streaming] var ssc: StreamingContext
+ ) extends Serializable with Logging {
initLogging()
- /**
- * ----------------------------------------------
- * Methods that must be implemented by subclasses
- * ----------------------------------------------
- */
+ // =======================================================================
+ // Methods that should be implemented by subclasses of DStream
+ // =======================================================================
- // Time interval at which the DStream generates an RDD
+ /** Time interval after which the DStream generates a RDD */
def slideTime: Time
- // List of parent DStreams on which this DStream depends on
+  /** List of parent DStreams on which this DStream depends */
def dependencies: List[DStream[_]]
- // Key method that computes RDD for a valid time
+ /** Method that generates a RDD for the given time */
def compute (validTime: Time): Option[RDD[T]]
- /**
- * ---------------------------------------
- * Other general fields and methods of DStream
- * ---------------------------------------
- */
+ // =======================================================================
+ // Methods and fields available on all DStreams
+ // =======================================================================
// RDDs generated, marked as protected[streaming] so that testsuites can access it
@transient
@@ -87,12 +80,15 @@ extends Serializable with Logging {
// Reference to whole DStream graph
protected[streaming] var graph: DStreamGraph = null
- def isInitialized = (zeroTime != null)
+ protected[streaming] def isInitialized = (zeroTime != null)
// Duration for which the DStream requires its parent DStream to remember each RDD created
- def parentRememberDuration = rememberDuration
+ protected[streaming] def parentRememberDuration = rememberDuration
+
+ /** Returns the StreamingContext associated with this DStream */
+ def context() = ssc
- // Set caching level for the RDDs created by this DStream
+ /** Persists the RDDs of this DStream with the given storage level */
def persist(level: StorageLevel): DStream[T] = {
if (this.isInitialized) {
throw new UnsupportedOperationException(
@@ -102,11 +98,16 @@ extends Serializable with Logging {
this
}
+ /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */
def persist(): DStream[T] = persist(StorageLevel.MEMORY_ONLY_SER)
-
- // Turn on the default caching level for this RDD
+
+ /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */
def cache(): DStream[T] = persist()
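A minimal sketch of how persist() and cache() might be used before the context is started (assumes an existing StreamingContext `ssc`; stream names and endpoints are illustrative only):

    import spark.storage.StorageLevel

    val lines = ssc.networkTextStream("localhost", 9999)
    // Keep the generated RDDs serialized in memory, replicated on two nodes.
    lines.persist(StorageLevel.MEMORY_ONLY_SER_2)
    // cache() is shorthand for persist() with the default level MEMORY_ONLY_SER.
    val cached = ssc.networkTextStream("localhost", 9998).cache()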
+ /**
+ * Enable periodic checkpointing of RDDs of this DStream
+ * @param interval Time interval after which the generated RDDs will be checkpointed
+ */
def checkpoint(interval: Time): DStream[T] = {
if (isInitialized) {
throw new UnsupportedOperationException(
@@ -188,13 +189,13 @@ extends Serializable with Logging {
val metadataCleanerDelay = spark.util.MetadataCleaner.getDelaySeconds
logInfo("metadataCleanupDelay = " + metadataCleanerDelay)
assert(
- metadataCleanerDelay < 0 || rememberDuration < metadataCleanerDelay * 1000,
+ metadataCleanerDelay < 0 || rememberDuration.milliseconds < metadataCleanerDelay * 1000,
"It seems you are doing some DStream window operation or setting a checkpoint interval " +
"which requires " + this.getClass.getSimpleName + " to remember generated RDDs for more " +
"than " + rememberDuration.milliseconds + " milliseconds. But the Spark's metadata cleanup" +
"delay is set to " + (metadataCleanerDelay / 60.0) + " minutes, which is not sufficient. Please set " +
"the Java property 'spark.cleaner.delay' to more than " +
- math.ceil(rememberDuration.millis.toDouble / 60000.0).toInt + " minutes."
+ math.ceil(rememberDuration.milliseconds.toDouble / 60000.0).toInt + " minutes."
)
dependencies.foreach(_.validate())
@@ -285,7 +286,7 @@ extends Serializable with Logging {
* Generates a SparkStreaming job for the given time. This is an internal method that
* should not be called directly. This default implementation creates a job
* that materializes the corresponding RDD. Subclasses of DStream may override this
- * (eg. PerRDDForEachDStream).
+ * (e.g. ForEachDStream).
*/
protected[streaming] def generateJob(time: Time): Option[Job] = {
getOrCompute(time) match {
@@ -334,20 +335,22 @@ extends Serializable with Logging {
* this method to save custom checkpoint data.
*/
protected[streaming] def updateCheckpointData(currentTime: Time) {
-
logInfo("Updating checkpoint data for time " + currentTime)
// Get the checkpointed RDDs from the generated RDDs
- val newRdds = generatedRDDs.filter(_._2.getCheckpointData() != null)
- .map(x => (x._1, x._2.getCheckpointData()))
- // Make a copy of the existing checkpoint data
+ val newRdds = generatedRDDs.filter(_._2.getCheckpointFile.isDefined)
+ .map(x => (x._1, x._2.getCheckpointFile.get))
+
+ // Make a copy of the existing checkpoint data (checkpointed RDDs)
val oldRdds = checkpointData.rdds.clone()
- // If the new checkpoint has checkpoints then replace existing with the new one
+
+ // If the new checkpoint data has checkpoints then replace existing with the new one
if (newRdds.size > 0) {
checkpointData.rdds.clear()
checkpointData.rdds ++= newRdds
}
- // Make dependencies update their checkpoint data
+
+ // Make parent DStreams update their checkpoint data
dependencies.foreach(_.updateCheckpointData(currentTime))
// TODO: remove this, this is just for debugging
@@ -381,9 +384,7 @@ extends Serializable with Logging {
checkpointData.rdds.foreach {
case(time, data) => {
logInfo("Restoring checkpointed RDD for time " + time + " from file '" + data.toString + "'")
- val rdd = ssc.sc.objectFile[T](data.toString)
- // Set the checkpoint file name to identify this RDD as a checkpointed RDD by updateCheckpointData()
- rdd.checkpointFile = data.toString
+ val rdd = ssc.sc.checkpointFile[T](data.toString)
generatedRDDs += ((time, rdd))
}
}
@@ -420,65 +421,96 @@ extends Serializable with Logging {
generatedRDDs = new HashMap[Time, RDD[T]] ()
}
- /**
- * --------------
- * DStream operations
- * --------------
- */
+ // =======================================================================
+ // DStream operations
+ // =======================================================================
+
+ /** Returns a new DStream by applying a function to all elements of this DStream. */
def map[U: ClassManifest](mapFunc: T => U): DStream[U] = {
new MappedDStream(this, ssc.sc.clean(mapFunc))
}
+ /**
+ * Returns a new DStream by applying a function to all elements of this DStream,
+ * and then flattening the results
+ */
def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]): DStream[U] = {
new FlatMappedDStream(this, ssc.sc.clean(flatMapFunc))
}
+ /** Returns a new DStream containing only the elements that satisfy a predicate. */
def filter(filterFunc: T => Boolean): DStream[T] = new FilteredDStream(this, filterFunc)
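For illustration, a short sketch of the element-wise operations above, assuming a hypothetical DStream[String] named `lines` created elsewhere (e.g. from networkTextStream):

    val words    = lines.flatMap(_.split(" "))   // split each line into words
    val nonEmpty = words.filter(_.nonEmpty)      // drop empty tokens
    val lengths  = nonEmpty.map(_.length)        // one Int per word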
+ /**
+ * Return a new DStream in which each RDD is generated by applying glom() to each RDD of
+ * this DStream. Applying glom() to an RDD coalesces all elements within each partition into
+ * an array.
+ */
def glom(): DStream[Array[T]] = new GlommedDStream(this)
- def mapPartitions[U: ClassManifest](mapPartFunc: Iterator[T] => Iterator[U]): DStream[U] = {
- new MapPartitionedDStream(this, ssc.sc.clean(mapPartFunc))
+ /**
+ * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDD
+ * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition
+ * of the RDD.
+ */
+ def mapPartitions[U: ClassManifest](
+ mapPartFunc: Iterator[T] => Iterator[U],
+ preservePartitioning: Boolean = false
+ ): DStream[U] = {
+ new MapPartitionedDStream(this, ssc.sc.clean(mapPartFunc), preservePartitioning)
}
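A hedged usage sketch of mapPartitions(), again assuming an illustrative DStream `lines`:

    // Per-partition element counts; the iterator is consumed once per partition.
    val perPartitionCounts = lines.mapPartitions(iter => Iterator(iter.size))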
- def reduce(reduceFunc: (T, T) => T): DStream[T] = this.map(x => (null, x)).reduceByKey(reduceFunc, 1).map(_._2)
+ /**
+ * Returns a new DStream in which each RDD has a single element generated by reducing each RDD
+ * of this DStream.
+ */
+ def reduce(reduceFunc: (T, T) => T): DStream[T] =
+ this.map(x => (null, x)).reduceByKey(reduceFunc, 1).map(_._2)
+ /**
+ * Returns a new DStream in which each RDD has a single element generated by counting each RDD
+ * of this DStream.
+ */
def count(): DStream[Int] = this.map(_ => 1).reduce(_ + _)
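For example, reduce() and count() could be combined with map() like this (sketch only, `lines` is an assumed DStream[String]):

    val totalChars = lines.map(_.length).reduce(_ + _)  // one Int per batch
    val numLines   = lines.count()                       // number of elements per batch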
-
- def collect(): DStream[Seq[T]] = this.map(x => (null, x)).groupByKey(1).map(_._2)
- def foreach(foreachFunc: T => Unit) {
- val newStream = new PerElementForEachDStream(this, ssc.sc.clean(foreachFunc))
- ssc.registerOutputStream(newStream)
- newStream
- }
-
- def foreachRDD(foreachFunc: RDD[T] => Unit) {
- foreachRDD((r: RDD[T], t: Time) => foreachFunc(r))
+ /**
+ * Applies a function to each RDD in this DStream. This is an output operator, so
+ * this DStream will be registered as an output stream and therefore materialized.
+ */
+ def foreach(foreachFunc: RDD[T] => Unit) {
+ foreach((r: RDD[T], t: Time) => foreachFunc(r))
}
- def foreachRDD(foreachFunc: (RDD[T], Time) => Unit) {
- val newStream = new PerRDDForEachDStream(this, ssc.sc.clean(foreachFunc))
+ /**
+ * Applies a function to each RDD in this DStream. This is an output operator, so
+ * this DStream will be registered as an output stream and therefore materialized.
+ */
+ def foreach(foreachFunc: (RDD[T], Time) => Unit) {
+ val newStream = new ForEachDStream(this, ssc.sc.clean(foreachFunc))
ssc.registerOutputStream(newStream)
newStream
}
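A minimal sketch of foreach as an output operator, assuming a hypothetical DStream[(String, Int)] named `wordCounts`:

    // Output operator: registers the stream and runs the closure once per batch.
    wordCounts.foreach((rdd, time) => {
      println("Batch at " + time.milliseconds + " ms has " + rdd.count() + " records")
    })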
- def transformRDD[U: ClassManifest](transformFunc: RDD[T] => RDD[U]): DStream[U] = {
- transformRDD((r: RDD[T], t: Time) => transformFunc(r))
+ /**
+ * Returns a new DStream in which each RDD is generated by applying a function
+ * on each RDD of this DStream.
+ */
+ def transform[U: ClassManifest](transformFunc: RDD[T] => RDD[U]): DStream[U] = {
+ transform((r: RDD[T], t: Time) => transformFunc(r))
}
- def transformRDD[U: ClassManifest](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = {
+ /**
+ * Returns a new DStream in which each RDD is generated by applying a function
+ * on each RDD of this DStream.
+ */
+ def transform[U: ClassManifest](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = {
new TransformedDStream(this, ssc.sc.clean(transformFunc))
}
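As an illustrative sketch of transform() (here using the (RDD, Time) variant to access the batch time; `lines` is an assumed DStream[String]):

    // Tag every element with the time of the batch that produced it.
    val stamped = lines.transform((rdd, time) => rdd.map(line => (time.milliseconds, line)))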
- def toBlockingQueue() = {
- val queue = new ArrayBlockingQueue[RDD[T]](10000)
- this.foreachRDD(rdd => {
- queue.add(rdd)
- })
- queue
- }
-
+ /**
+ * Prints the first ten elements of each RDD generated in this DStream. This is an output
+ * operator, so this DStream will be registered as an output stream and therefore materialized.
+ */
def print() {
def foreachFunc = (rdd: RDD[T], time: Time) => {
val first11 = rdd.take(11)
@@ -489,18 +521,42 @@ extends Serializable with Logging {
if (first11.size > 10) println("...")
println()
}
- val newStream = new PerRDDForEachDStream(this, ssc.sc.clean(foreachFunc))
+ val newStream = new ForEachDStream(this, ssc.sc.clean(foreachFunc))
ssc.registerOutputStream(newStream)
}
+ /**
+ * Return a new DStream which is computed based on windowed batches of this DStream.
+ * The new DStream generates RDDs with the same interval as this DStream.
+ * @param windowTime width of the window; must be a multiple of this DStream's interval.
+ * @return a new DStream computed from windowed batches of this DStream
+ */
def window(windowTime: Time): DStream[T] = window(windowTime, this.slideTime)
+ /**
+ * Return a new DStream which is computed based on windowed batches of this DStream.
+ * @param windowTime duration (i.e., width) of the window;
+ * must be a multiple of this DStream's interval
+ * @param slideTime sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's interval
+ */
def window(windowTime: Time, slideTime: Time): DStream[T] = {
new WindowedDStream(this, windowTime, slideTime)
}
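For example, a sliding window over an assumed DStream `lines` might look like this (sketch only):

    // 30-second windows of data, recomputed every 10 seconds.
    val windowed = lines.window(Seconds(30), Seconds(10))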
+ /**
+ * Returns a new DStream which is computed based on a tumbling window on this DStream.
+ * This is equivalent to window(batchTime, batchTime).
+ * @param batchTime tumbling window duration; must be a multiple of this DStream's interval
+ */
def tumble(batchTime: Time): DStream[T] = window(batchTime, batchTime)
+ /**
+ * Returns a new DStream in which each RDD has a single element generated by reducing all
+ * elements in a window over this DStream. windowTime and slideTime are as defined in the
+ * window() operation. This is equivalent to window(windowTime, slideTime).reduce(reduceFunc)
+ */
def reduceByWindow(reduceFunc: (T, T) => T, windowTime: Time, slideTime: Time): DStream[T] = {
this.window(windowTime, slideTime).reduce(reduceFunc)
}
@@ -516,17 +572,31 @@ extends Serializable with Logging {
.map(_._2)
}
+ /**
+ * Returns a new DStream in which each RDD has a single element generated by counting the number
+ * of elements in a window over this DStream. windowTime and slideTime are as defined in the
+ * window() operation. This is equivalent to window(windowTime, slideTime).count()
+ */
def countByWindow(windowTime: Time, slideTime: Time): DStream[Int] = {
this.map(_ => 1).reduceByWindow(_ + _, _ - _, windowTime, slideTime)
}
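A hedged sketch of the windowed aggregations above, again over an assumed DStream[String] `lines`:

    val charsPerWindow = lines.map(_.length).reduceByWindow(_ + _, Seconds(30), Seconds(10))
    val linesPerWindow = lines.countByWindow(Seconds(30), Seconds(10))
    val tumbling       = lines.tumble(Seconds(30))   // non-overlapping 30-second windows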
+ /**
+ * Returns a new DStream by unifying data of another DStream with this DStream.
+ * @param that Another DStream having the same interval (i.e., slideTime) as this DStream.
+ */
def union(that: DStream[T]): DStream[T] = new UnionDStream[T](Array(this, that))
- def slice(interval: Interval): Seq[RDD[T]] = {
+ /**
+ * Returns all the RDDs defined by the Interval object (both end times included)
+ */
+ protected[streaming] def slice(interval: Interval): Seq[RDD[T]] = {
slice(interval.beginTime, interval.endTime)
}
- // Get all the RDDs between fromTime to toTime (both included)
+ /**
+ * Returns all the RDDs between 'fromTime' and 'toTime' (both included)
+ */
def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = {
val rdds = new ArrayBuffer[RDD[T]]()
var time = toTime.floor(slideTime)
@@ -540,20 +610,26 @@ extends Serializable with Logging {
rdds.toSeq
}
+ /**
+ * Saves each RDD in this DStream as a Sequence file of serialized objects.
+ */
def saveAsObjectFiles(prefix: String, suffix: String = "") {
val saveFunc = (rdd: RDD[T], time: Time) => {
val file = rddToFileName(prefix, suffix, time)
rdd.saveAsObjectFile(file)
}
- this.foreachRDD(saveFunc)
+ this.foreach(saveFunc)
}
+ /**
+ * Saves each RDD in this DStream as a text file, using the string representation of elements.
+ */
def saveAsTextFiles(prefix: String, suffix: String = "") {
val saveFunc = (rdd: RDD[T], time: Time) => {
val file = rddToFileName(prefix, suffix, time)
rdd.saveAsTextFile(file)
}
- this.foreachRDD(saveFunc)
+ this.foreach(saveFunc)
}
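For illustration (paths and stream names are hypothetical), each batch is written out under a name built from the prefix, the batch time in milliseconds, and the optional suffix:

    wordCounts.saveAsObjectFiles("hdfs://namenode:9000/counts/obj")
    lines.saveAsTextFiles("hdfs://namenode:9000/lines", "txt")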
def register() {
@@ -561,293 +637,6 @@ extends Serializable with Logging {
}
}
+private[streaming]
+case class DStreamCheckpointData(rdds: HashMap[Time, Any])
-abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext)
- extends DStream[T](ssc_) {
-
- override def dependencies = List()
-
- override def slideTime = {
- if (ssc == null) throw new Exception("ssc is null")
- if (ssc.graph.batchDuration == null) throw new Exception("batchDuration is null")
- ssc.graph.batchDuration
- }
-
- def start()
-
- def stop()
-}
-
-
-/**
- * TODO
- */
-
-class MappedDStream[T: ClassManifest, U: ClassManifest] (
- parent: DStream[T],
- mapFunc: T => U
- ) extends DStream[U](parent.ssc) {
-
- override def dependencies = List(parent)
-
- override def slideTime: Time = parent.slideTime
-
- override def compute(validTime: Time): Option[RDD[U]] = {
- parent.getOrCompute(validTime).map(_.map[U](mapFunc))
- }
-}
-
-
-/**
- * TODO
- */
-
-class FlatMappedDStream[T: ClassManifest, U: ClassManifest](
- parent: DStream[T],
- flatMapFunc: T => Traversable[U]
- ) extends DStream[U](parent.ssc) {
-
- override def dependencies = List(parent)
-
- override def slideTime: Time = parent.slideTime
-
- override def compute(validTime: Time): Option[RDD[U]] = {
- parent.getOrCompute(validTime).map(_.flatMap(flatMapFunc))
- }
-}
-
-
-/**
- * TODO
- */
-
-class FilteredDStream[T: ClassManifest](
- parent: DStream[T],
- filterFunc: T => Boolean
- ) extends DStream[T](parent.ssc) {
-
- override def dependencies = List(parent)
-
- override def slideTime: Time = parent.slideTime
-
- override def compute(validTime: Time): Option[RDD[T]] = {
- parent.getOrCompute(validTime).map(_.filter(filterFunc))
- }
-}
-
-
-/**
- * TODO
- */
-
-class MapPartitionedDStream[T: ClassManifest, U: ClassManifest](
- parent: DStream[T],
- mapPartFunc: Iterator[T] => Iterator[U]
- ) extends DStream[U](parent.ssc) {
-
- override def dependencies = List(parent)
-
- override def slideTime: Time = parent.slideTime
-
- override def compute(validTime: Time): Option[RDD[U]] = {
- parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc))
- }
-}
-
-
-/**
- * TODO
- */
-
-class GlommedDStream[T: ClassManifest](parent: DStream[T])
- extends DStream[Array[T]](parent.ssc) {
-
- override def dependencies = List(parent)
-
- override def slideTime: Time = parent.slideTime
-
- override def compute(validTime: Time): Option[RDD[Array[T]]] = {
- parent.getOrCompute(validTime).map(_.glom())
- }
-}
-
-
-/**
- * TODO
- */
-
-class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest](
- parent: DStream[(K,V)],
- createCombiner: V => C,
- mergeValue: (C, V) => C,
- mergeCombiner: (C, C) => C,
- partitioner: Partitioner
- ) extends DStream [(K,C)] (parent.ssc) {
-
- override def dependencies = List(parent)
-
- override def slideTime: Time = parent.slideTime
-
- override def compute(validTime: Time): Option[RDD[(K,C)]] = {
- parent.getOrCompute(validTime) match {
- case Some(rdd) =>
- Some(rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner, partitioner))
- case None => None
- }
- }
-}
-
-
-/**
- * TODO
- */
-
-class MapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest](
- parent: DStream[(K, V)],
- mapValueFunc: V => U
- ) extends DStream[(K, U)](parent.ssc) {
-
- override def dependencies = List(parent)
-
- override def slideTime: Time = parent.slideTime
-
- override def compute(validTime: Time): Option[RDD[(K, U)]] = {
- parent.getOrCompute(validTime).map(_.mapValues[U](mapValueFunc))
- }
-}
-
-
-/**
- * TODO
- */
-
-class FlatMapValuesDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest](
- parent: DStream[(K, V)],
- flatMapValueFunc: V => TraversableOnce[U]
- ) extends DStream[(K, U)](parent.ssc) {
-
- override def dependencies = List(parent)
-
- override def slideTime: Time = parent.slideTime
-
- override def compute(validTime: Time): Option[RDD[(K, U)]] = {
- parent.getOrCompute(validTime).map(_.flatMapValues[U](flatMapValueFunc))
- }
-}
-
-
-
-/**
- * TODO
- */
-
-class UnionDStream[T: ClassManifest](parents: Array[DStream[T]])
- extends DStream[T](parents.head.ssc) {
-
- if (parents.length == 0) {
- throw new IllegalArgumentException("Empty array of parents")
- }
-
- if (parents.map(_.ssc).distinct.size > 1) {
- throw new IllegalArgumentException("Array of parents have different StreamingContexts")
- }
-
- if (parents.map(_.slideTime).distinct.size > 1) {
- throw new IllegalArgumentException("Array of parents have different slide times")
- }
-
- override def dependencies = parents.toList
-
- override def slideTime: Time = parents.head.slideTime
-
- override def compute(validTime: Time): Option[RDD[T]] = {
- val rdds = new ArrayBuffer[RDD[T]]()
- parents.map(_.getOrCompute(validTime)).foreach(_ match {
- case Some(rdd) => rdds += rdd
- case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime)
- })
- if (rdds.size > 0) {
- Some(new UnionRDD(ssc.sc, rdds))
- } else {
- None
- }
- }
-}
-
-
-/**
- * TODO
- */
-
-class PerElementForEachDStream[T: ClassManifest] (
- parent: DStream[T],
- foreachFunc: T => Unit
- ) extends DStream[Unit](parent.ssc) {
-
- override def dependencies = List(parent)
-
- override def slideTime: Time = parent.slideTime
-
- override def compute(validTime: Time): Option[RDD[Unit]] = None
-
- override def generateJob(time: Time): Option[Job] = {
- parent.getOrCompute(time) match {
- case Some(rdd) =>
- val jobFunc = () => {
- val sparkJobFunc = {
- (iterator: Iterator[T]) => iterator.foreach(foreachFunc)
- }
- ssc.sc.runJob(rdd, sparkJobFunc)
- }
- Some(new Job(time, jobFunc))
- case None => None
- }
- }
-}
-
-
-/**
- * TODO
- */
-
-class PerRDDForEachDStream[T: ClassManifest] (
- parent: DStream[T],
- foreachFunc: (RDD[T], Time) => Unit
- ) extends DStream[Unit](parent.ssc) {
-
- override def dependencies = List(parent)
-
- override def slideTime: Time = parent.slideTime
-
- override def compute(validTime: Time): Option[RDD[Unit]] = None
-
- override def generateJob(time: Time): Option[Job] = {
- parent.getOrCompute(time) match {
- case Some(rdd) =>
- val jobFunc = () => {
- foreachFunc(rdd, time)
- }
- Some(new Job(time, jobFunc))
- case None => None
- }
- }
-}
-
-
-/**
- * TODO
- */
-
-class TransformedDStream[T: ClassManifest, U: ClassManifest] (
- parent: DStream[T],
- transformFunc: (RDD[T], Time) => RDD[U]
- ) extends DStream[U](parent.ssc) {
-
- override def dependencies = List(parent)
-
- override def slideTime: Time = parent.slideTime
-
- override def compute(validTime: Time): Option[RDD[U]] = {
- parent.getOrCompute(validTime).map(transformFunc(_, validTime))
- }
-}
diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
index d0a9ade61d..c72429370e 100644
--- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
@@ -1,5 +1,6 @@
package spark.streaming
+import dstream.InputDStream
import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
import collection.mutable.ArrayBuffer
import spark.Logging
diff --git a/streaming/src/main/scala/spark/streaming/DataHandler.scala b/streaming/src/main/scala/spark/streaming/DataHandler.scala
deleted file mode 100644
index 05f307a8d1..0000000000
--- a/streaming/src/main/scala/spark/streaming/DataHandler.scala
+++ /dev/null
@@ -1,83 +0,0 @@
-package spark.streaming
-
-import java.util.concurrent.ArrayBlockingQueue
-import scala.collection.mutable.ArrayBuffer
-import spark.Logging
-import spark.streaming.util.{RecurringTimer, SystemClock}
-import spark.storage.StorageLevel
-
-
-/**
- * This is a helper object that manages the data received from the socket. It divides
- * the object received into small batches of 100s of milliseconds, pushes them as
- * blocks into the block manager and reports the block IDs to the network input
- * tracker. It starts two threads, one to periodically start a new batch and prepare
- * the previous batch of as a block, the other to push the blocks into the block
- * manager.
- */
- class DataHandler[T](receiver: NetworkReceiver[T], storageLevel: StorageLevel)
- extends Serializable with Logging {
-
- case class Block(id: String, iterator: Iterator[T], metadata: Any = null)
-
- val clock = new SystemClock()
- val blockInterval = 200L
- val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer)
- val blockStorageLevel = storageLevel
- val blocksForPushing = new ArrayBlockingQueue[Block](1000)
- val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } }
-
- var currentBuffer = new ArrayBuffer[T]
-
- def createBlock(blockId: String, iterator: Iterator[T]) : Block = {
- new Block(blockId, iterator)
- }
-
- def start() {
- blockIntervalTimer.start()
- blockPushingThread.start()
- logInfo("Data handler started")
- }
-
- def stop() {
- blockIntervalTimer.stop()
- blockPushingThread.interrupt()
- logInfo("Data handler stopped")
- }
-
- def += (obj: T) {
- currentBuffer += obj
- }
-
- def updateCurrentBuffer(time: Long) {
- try {
- val newBlockBuffer = currentBuffer
- currentBuffer = new ArrayBuffer[T]
- if (newBlockBuffer.size > 0) {
- val blockId = "input-" + receiver.streamId + "- " + (time - blockInterval)
- val newBlock = createBlock(blockId, newBlockBuffer.toIterator)
- blocksForPushing.add(newBlock)
- }
- } catch {
- case ie: InterruptedException =>
- logInfo("Block interval timer thread interrupted")
- case e: Exception =>
- receiver.stop()
- }
- }
-
- def keepPushingBlocks() {
- logInfo("Block pushing thread started")
- try {
- while(true) {
- val block = blocksForPushing.take()
- receiver.pushBlock(block.id, block.iterator, block.metadata, storageLevel)
- }
- } catch {
- case ie: InterruptedException =>
- logInfo("Block pushing thread interrupted")
- case e: Exception =>
- receiver.stop()
- }
- }
- } \ No newline at end of file
diff --git a/streaming/src/main/scala/spark/streaming/Interval.scala b/streaming/src/main/scala/spark/streaming/Interval.scala
index ffb7725ac9..fa0b7ce19d 100644
--- a/streaming/src/main/scala/spark/streaming/Interval.scala
+++ b/streaming/src/main/scala/spark/streaming/Interval.scala
@@ -1,5 +1,6 @@
package spark.streaming
+private[streaming]
case class Interval(beginTime: Time, endTime: Time) {
def this(beginMs: Long, endMs: Long) = this(Time(beginMs), new Time(endMs))
diff --git a/streaming/src/main/scala/spark/streaming/Job.scala b/streaming/src/main/scala/spark/streaming/Job.scala
index 0bcb6fd8dc..67bd8388bc 100644
--- a/streaming/src/main/scala/spark/streaming/Job.scala
+++ b/streaming/src/main/scala/spark/streaming/Job.scala
@@ -2,6 +2,7 @@ package spark.streaming
import java.util.concurrent.atomic.AtomicLong
+private[streaming]
class Job(val time: Time, func: () => _) {
val id = Job.getNewId()
def run(): Long = {
@@ -14,6 +15,7 @@ class Job(val time: Time, func: () => _) {
override def toString = "streaming job " + id + " @ " + time
}
+private[streaming]
object Job {
val id = new AtomicLong(0)
diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala
index 9bf9251519..3b910538e0 100644
--- a/streaming/src/main/scala/spark/streaming/JobManager.scala
+++ b/streaming/src/main/scala/spark/streaming/JobManager.scala
@@ -5,6 +5,7 @@ import spark.SparkEnv
import java.util.concurrent.Executors
+private[streaming]
class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging {
class JobHandler(ssc: StreamingContext, job: Job) extends Runnable {
@@ -13,7 +14,7 @@ class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging {
try {
val timeTaken = job.run()
logInfo("Total delay: %.5f s for job %s (execution: %.5f s)".format(
- (System.currentTimeMillis() - job.time) / 1000.0, job.id, timeTaken / 1000.0))
+ (System.currentTimeMillis() - job.time.milliseconds) / 1000.0, job.id, timeTaken / 1000.0))
} catch {
case e: Exception =>
logError("Running " + job + " failed", e)
diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
index b421f795ee..a6ab44271f 100644
--- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
+++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
@@ -1,5 +1,7 @@
package spark.streaming
+import spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver}
+import spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError}
import spark.Logging
import spark.SparkEnv
@@ -11,10 +13,10 @@ import akka.pattern.ask
import akka.util.duration._
import akka.dispatch._
-trait NetworkInputTrackerMessage
-case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage
-case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage
-case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage
+private[streaming] sealed trait NetworkInputTrackerMessage
+private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage
+private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage
+private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage
class NetworkInputTracker(
diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala
index 720e63bba0..b0a208e67f 100644
--- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala
+++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala
@@ -1,6 +1,9 @@
package spark.streaming
import spark.streaming.StreamingContext._
+import spark.streaming.dstream.{ReducedWindowedDStream, StateDStream}
+import spark.streaming.dstream.{CoGroupedDStream, ShuffledDStream}
+import spark.streaming.dstream.{MapValuedDStream, FlatMapValuedDStream}
import spark.{Manifests, RDD, Partitioner, HashPartitioner}
import spark.SparkContext._
@@ -218,13 +221,13 @@ extends Serializable {
def mapValues[U: ClassManifest](mapValuesFunc: V => U): DStream[(K, U)] = {
- new MapValuesDStream[K, V, U](self, mapValuesFunc)
+ new MapValuedDStream[K, V, U](self, mapValuesFunc)
}
def flatMapValues[U: ClassManifest](
flatMapValuesFunc: V => TraversableOnce[U]
): DStream[(K, U)] = {
- new FlatMapValuesDStream[K, V, U](self, flatMapValuesFunc)
+ new FlatMapValuedDStream[K, V, U](self, flatMapValuesFunc)
}
def cogroup[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (Seq[V], Seq[W]))] = {
@@ -281,7 +284,7 @@ extends Serializable {
val file = rddToFileName(prefix, suffix, time)
rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, conf)
}
- self.foreachRDD(saveFunc)
+ self.foreach(saveFunc)
}
def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]](
@@ -303,7 +306,7 @@ extends Serializable {
val file = rddToFileName(prefix, suffix, time)
rdd.saveAsNewAPIHadoopFile(file, keyClass, valueClass, outputFormatClass, conf)
}
- self.foreachRDD(saveFunc)
+ self.foreach(saveFunc)
}
private def getKeyClass() = implicitly[ClassManifest[K]].erasure
diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala
index 014021be61..eb40affe6d 100644
--- a/streaming/src/main/scala/spark/streaming/Scheduler.scala
+++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala
@@ -4,14 +4,8 @@ import util.{ManualClock, RecurringTimer, Clock}
import spark.SparkEnv
import spark.Logging
-import scala.collection.mutable.HashMap
-
-
-sealed trait SchedulerMessage
-case class InputGenerated(inputName: String, interval: Interval, reference: AnyRef = null) extends SchedulerMessage
-
-class Scheduler(ssc: StreamingContext)
-extends Logging {
+private[streaming]
+class Scheduler(ssc: StreamingContext) extends Logging {
initLogging()
@@ -28,7 +22,7 @@ extends Logging {
val clockClass = System.getProperty("spark.streaming.clock", "spark.streaming.util.SystemClock")
val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock]
- val timer = new RecurringTimer(clock, ssc.graph.batchDuration, generateRDDs(_))
+ val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, generateRDDs(_))
def start() {
// If context was started from checkpoint, then restart timer such that
diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
index ce47bcb2da..215246ba2e 100644
--- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
@@ -1,10 +1,10 @@
package spark.streaming
-import spark.RDD
-import spark.Logging
-import spark.SparkEnv
-import spark.SparkContext
+import spark.streaming.dstream._
+
+import spark.{RDD, Logging, SparkEnv, SparkContext}
import spark.storage.StorageLevel
+import spark.util.MetadataCleaner
import scala.collection.mutable.Queue
@@ -15,10 +15,8 @@ import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
-import org.apache.flume.source.avro.AvroFlumeEvent
import org.apache.hadoop.fs.Path
import java.util.UUID
-import spark.util.MetadataCleaner
/**
* A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
@@ -48,7 +46,7 @@ class StreamingContext private (
this(StreamingContext.createNewSparkContext(master, frameworkName), null, batchDuration)
/**
- * Recreates the StreamingContext from a checkpoint file.
+ * Re-creates a StreamingContext from a checkpoint file.
* @param path Path either to the directory that was specified as the checkpoint directory, or
* to the checkpoint file 'graph' or 'graph.bk'.
*/
@@ -61,7 +59,7 @@ class StreamingContext private (
"both SparkContext and checkpoint as null")
}
- val isCheckpointPresent = (cp_ != null)
+ protected[streaming] val isCheckpointPresent = (cp_ != null)
val sc: SparkContext = {
if (isCheckpointPresent) {
@@ -71,9 +69,9 @@ class StreamingContext private (
}
}
- val env = SparkEnv.get
+ protected[streaming] val env = SparkEnv.get
- val graph: DStreamGraph = {
+ protected[streaming] val graph: DStreamGraph = {
if (isCheckpointPresent) {
cp_.graph.setContext(this)
cp_.graph.restoreCheckpointData()
@@ -86,10 +84,10 @@ class StreamingContext private (
}
}
- private[streaming] val nextNetworkInputStreamId = new AtomicInteger(0)
- private[streaming] var networkInputTracker: NetworkInputTracker = null
+ protected[streaming] val nextNetworkInputStreamId = new AtomicInteger(0)
+ protected[streaming] var networkInputTracker: NetworkInputTracker = null
- private[streaming] var checkpointDir: String = {
+ protected[streaming] var checkpointDir: String = {
if (isCheckpointPresent) {
sc.setCheckpointDir(StreamingContext.getSparkCheckpointDir(cp_.checkpointDir), true)
cp_.checkpointDir
@@ -98,18 +96,31 @@ class StreamingContext private (
}
}
- private[streaming] var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null
- private[streaming] var receiverJobThread: Thread = null
- private[streaming] var scheduler: Scheduler = null
+ protected[streaming] var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null
+ protected[streaming] var receiverJobThread: Thread = null
+ protected[streaming] var scheduler: Scheduler = null
+ /**
+ * Sets each DStream in this context to remember the RDDs it generated in the last given duration.
+ * DStreams remember RDDs only for a limited duration of time and release them for garbage
+ * collection. This method allows the developer to specify how long to remember the RDDs
+ * (if the developer wishes to query old data outside the DStream computation).
+ * @param duration Minimum duration that each DStream should remember its RDDs
+ */
def remember(duration: Time) {
graph.remember(duration)
}
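For example (sketch, assuming an existing StreamingContext `ssc`):

    // Keep the RDDs of every DStream around for at least one minute,
    // e.g. so they can be queried interactively outside the streaming computation.
    ssc.remember(Minutes(1))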
- def checkpoint(dir: String, interval: Time = null) {
- if (dir != null) {
- sc.setCheckpointDir(StreamingContext.getSparkCheckpointDir(dir))
- checkpointDir = dir
+ /**
+ * Sets the context to periodically checkpoint the DStream operations for master
+ * fault-tolerance. By default, the graph will be checkpointed every batch interval.
+ * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored
+ * @param interval checkpoint interval
+ */
+ def checkpoint(directory: String, interval: Time = null) {
+ if (directory != null) {
+ sc.setCheckpointDir(StreamingContext.getSparkCheckpointDir(directory))
+ checkpointDir = directory
checkpointInterval = interval
} else {
checkpointDir = null
@@ -117,16 +128,15 @@ class StreamingContext private (
}
}
- private[streaming] def getInitialCheckpoint(): Checkpoint = {
+ protected[streaming] def getInitialCheckpoint(): Checkpoint = {
if (isCheckpointPresent) cp_ else null
}
- private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement()
+ protected[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement()
- /**
+ /**
* Create an input stream that pulls messages form a Kafka Broker.
- *
- * @param host Zookeper hostname.
+ * @param hostname Zookeeper hostname.
* @param port Zookeper port.
* @param groupId The group id for this consumer.
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
@@ -144,10 +154,19 @@ class StreamingContext private (
storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2
): DStream[T] = {
val inputStream = new KafkaInputDStream[T](this, hostname, port, groupId, topics, initialOffsets, storageLevel)
- graph.addInputStream(inputStream)
+ registerInputStream(inputStream)
inputStream
}
+ /**
+ * Create an input stream from a network source hostname:port. Data is received using
+ * a TCP socket and the received bytes are interpreted as UTF8-encoded, \n-delimited
+ * lines.
+ * @param hostname Hostname to connect to for receiving data
+ * @param port Port to connect to for receiving data
+ * @param storageLevel Storage level to use for storing the received objects
+ * (default: StorageLevel.MEMORY_AND_DISK_SER_2)
+ */
def networkTextStream(
hostname: String,
port: Int,
@@ -156,6 +175,16 @@ class StreamingContext private (
networkStream[String](hostname, port, SocketReceiver.bytesToLines, storageLevel)
}
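Putting the pieces together, a minimal word-count sketch using networkTextStream (hostnames, ports, and the app name are illustrative; assumes the one-argument reduceByKey on pair DStreams):

    import spark.streaming.{StreamingContext, Seconds}
    import spark.streaming.StreamingContext._   // brings in the pair-DStream operations

    val ssc = new StreamingContext("local[2]", "NetworkWordCount", Seconds(1))
    val lines = ssc.networkTextStream("localhost", 9999)
    val wordCounts = lines.flatMap(_.split(" ")).map(w => (w, 1)).reduceByKey(_ + _)
    wordCounts.print()
    ssc.start()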
+ /**
+ * Create an input stream from a network source hostname:port. Data is received using
+ * a TCP socket and the received bytes are interpreted as objects using the given
+ * converter.
+ * @param hostname Hostname to connect to for receiving data
+ * @param port Port to connect to for receiving data
+ * @param converter Function to convert the byte stream to objects
+ * @param storageLevel Storage level to use for storing the received objects
+ * @tparam T Type of the objects received (after converting bytes to objects)
+ */
def networkStream[T: ClassManifest](
hostname: String,
port: Int,
@@ -163,51 +192,103 @@ class StreamingContext private (
storageLevel: StorageLevel
): DStream[T] = {
val inputStream = new SocketInputDStream[T](this, hostname, port, converter, storageLevel)
- graph.addInputStream(inputStream)
+ registerInputStream(inputStream)
inputStream
}
+ /**
+ * Creates an input stream from a Flume source.
+ * @param hostname Hostname of the slave machine to which the flume data will be sent
+ * @param port Port of the slave machine to which the flume data will be sent
+ * @param storageLevel Storage level to use for storing the received objects
+ */
def flumeStream (
- hostname: String,
- port: Int,
- storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2): DStream[SparkFlumeEvent] = {
+ hostname: String,
+ port: Int,
+ storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2
+ ): DStream[SparkFlumeEvent] = {
val inputStream = new FlumeInputDStream(this, hostname, port, storageLevel)
- graph.addInputStream(inputStream)
+ registerInputStream(inputStream)
inputStream
}
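A hedged sketch of flumeStream usage (host and port are placeholders for wherever the Flume agent is configured to send events):

    // Receive Flume events pushed to this worker on port 4141 and count them per batch.
    val flumeEvents = ssc.flumeStream("localhost", 4141)
    flumeEvents.count().print()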
-
+ /**
+ * Create an input stream from a network source hostname:port, where data is received
+ * as serialized blocks (serialized using Spark's serializer) that can be directly
+ * pushed into the block manager without deserializing them. This is the most efficient
+ * way to receive data.
+ * @param hostname Hostname to connect to for receiving data
+ * @param port Port to connect to for receiving data
+ * @param storageLevel Storage level to use for storing the received objects
+ * @tparam T Type of the objects in the received blocks
+ */
def rawNetworkStream[T: ClassManifest](
hostname: String,
port: Int,
storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2
): DStream[T] = {
val inputStream = new RawInputDStream[T](this, hostname, port, storageLevel)
- graph.addInputStream(inputStream)
+ registerInputStream(inputStream)
inputStream
}
/**
- * This function creates a input stream that monitors a Hadoop-compatible filesystem
- * for new files and executes the necessary processing on them.
+ * Creates an input stream that monitors a Hadoop-compatible filesystem
+ * for new files and reads them using the given key-value types and input format.
+ * File names starting with . are ignored.
+ * @param directory HDFS directory to monitor for new files
+ * @tparam K Key type for reading HDFS file
+ * @tparam V Value type for reading HDFS file
+ * @tparam F Input format for reading HDFS file
*/
def fileStream[
K: ClassManifest,
V: ClassManifest,
F <: NewInputFormat[K, V]: ClassManifest
- ](directory: String): DStream[(K, V)] = {
+ ] (directory: String): DStream[(K, V)] = {
val inputStream = new FileInputDStream[K, V, F](this, directory)
- graph.addInputStream(inputStream)
+ registerInputStream(inputStream)
inputStream
}
+ /**
+ * Creates an input stream that monitors a Hadoop-compatible filesystem
+ * for new files and reads them using the given key-value types and input format.
+ * @param directory HDFS directory to monitor for new files
+ * @param filter Function to filter paths to process
+ * @param newFilesOnly Whether to process only new files and ignore existing files in the directory
+ * @tparam K Key type for reading HDFS file
+ * @tparam V Value type for reading HDFS file
+ * @tparam F Input format for reading HDFS file
+ */
+ def fileStream[
+ K: ClassManifest,
+ V: ClassManifest,
+ F <: NewInputFormat[K, V]: ClassManifest
+ ] (directory: String, filter: Path => Boolean, newFilesOnly: Boolean): DStream[(K, V)] = {
+ val inputStream = new FileInputDStream[K, V, F](this, directory, filter, newFilesOnly)
+ registerInputStream(inputStream)
+ inputStream
+ }
+
+
+ /**
+ * Creates an input stream that monitors a Hadoop-compatible filesystem
+ * for new files and reads them as text files (using key as LongWritable, value
+ * as Text and input format as TextInputFormat). File names starting with . are ignored.
+ * @param directory HDFS directory to monitor for new files
+ */
def textFileStream(directory: String): DStream[String] = {
fileStream[LongWritable, Text, TextInputFormat](directory).map(_._2.toString)
}
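For example (sketch only, the HDFS path is a placeholder):

    // Each batch contains the lines of any files that appeared in the directory since the last batch.
    val logLines = ssc.textFileStream("hdfs://namenode:9000/incoming/logs")
    logLines.filter(_.contains("ERROR")).print()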
/**
- * This function create a input stream from an queue of RDDs. In each batch,
- * it will process either one or all of the RDDs returned by the queue
+ * Creates an input stream from a queue of RDDs. In each batch,
+ * it will process either one or all of the RDDs returned by the queue.
+ * @param queue Queue of RDDs
+ * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval
+ * @param defaultRDD Default RDD is returned by the DStream when the queue is empty
+ * @tparam T Type of objects in the RDD
*/
def queueStream[T: ClassManifest](
queue: Queue[RDD[T]],
@@ -215,38 +296,33 @@ class StreamingContext private (
defaultRDD: RDD[T] = null
): DStream[T] = {
val inputStream = new QueueInputDStream(this, queue, oneAtATime, defaultRDD)
- graph.addInputStream(inputStream)
- inputStream
- }
-
- def queueStream[T: ClassManifest](array: Array[RDD[T]]): DStream[T] = {
- val queue = new Queue[RDD[T]]
- val inputStream = queueStream(queue, true, null)
- queue ++= array
+ registerInputStream(inputStream)
inputStream
}
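A minimal queueStream sketch, useful for testing (assumes an existing StreamingContext `ssc`; the explicit `true, null` arguments mirror the call pattern of the removed convenience overload above):

    import scala.collection.mutable.Queue
    import spark.RDD

    val rddQueue = new Queue[RDD[Int]]()
    rddQueue += ssc.sc.parallelize(1 to 100)              // pre-load one RDD
    val numbers = ssc.queueStream(rddQueue, true, null)   // consume one queued RDD per batch
    numbers.count().print()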
+ /**
+ * Create a unified DStream from multiple DStreams of the same type and same interval
+ */
def union[T: ClassManifest](streams: Seq[DStream[T]]): DStream[T] = {
new UnionDStream[T](streams.toArray)
}
/**
- * This function registers a InputDStream as an input stream that will be
- * started (InputDStream.start() called) to get the input data streams.
+ * Registers an input stream that will be started (InputDStream.start() called) to get the
+ * input data.
*/
def registerInputStream(inputStream: InputDStream[_]) {
graph.addInputStream(inputStream)
}
/**
- * This function registers a DStream as an output stream that will be
- * computed every interval.
+ * Registers an output stream that will be computed every interval
*/
def registerOutputStream(outputStream: DStream[_]) {
graph.addOutputStream(outputStream)
}
- def validate() {
+ protected def validate() {
assert(graph != null, "Graph is null")
graph.validate()
@@ -258,7 +334,7 @@ class StreamingContext private (
}
/**
- * This function starts the execution of the streams.
+ * Starts the execution of the streams.
*/
def start() {
if (checkpointDir != null && checkpointInterval == null && graph != null) {
@@ -286,7 +362,7 @@ class StreamingContext private (
}
/**
- * This function stops the execution of the streams.
+ * Stops the execution of the streams.
*/
def stop() {
try {
@@ -304,7 +380,11 @@ class StreamingContext private (
object StreamingContext {
- def createNewSparkContext(master: String, frameworkName: String): SparkContext = {
+ implicit def toPairDStreamFunctions[K: ClassManifest, V: ClassManifest](stream: DStream[(K,V)]) = {
+ new PairDStreamFunctions[K, V](stream)
+ }
+
+ protected[streaming] def createNewSparkContext(master: String, frameworkName: String): SparkContext = {
// Set the default cleaner delay to an hour if not already set.
// This should be sufficient for even 1 second interval.
@@ -314,13 +394,9 @@ object StreamingContext {
new SparkContext(master, frameworkName)
}
- implicit def toPairDStreamFunctions[K: ClassManifest, V: ClassManifest](stream: DStream[(K,V)]) = {
- new PairDStreamFunctions[K, V](stream)
- }
-
- def rddToFileName[T](prefix: String, suffix: String, time: Time): String = {
+ protected[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = {
if (prefix == null) {
- time.millis.toString
+ time.milliseconds.toString
} else if (suffix == null || suffix.length ==0) {
prefix + "-" + time.milliseconds
} else {
@@ -328,7 +404,7 @@ object StreamingContext {
}
}
- def getSparkCheckpointDir(sscCheckpointDir: String): String = {
+ protected[streaming] def getSparkCheckpointDir(sscCheckpointDir: String): String = {
new Path(sscCheckpointDir, UUID.randomUUID.toString).toString
}
}
diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala
index 480d292d7c..3c6fd5d967 100644
--- a/streaming/src/main/scala/spark/streaming/Time.scala
+++ b/streaming/src/main/scala/spark/streaming/Time.scala
@@ -1,11 +1,18 @@
package spark.streaming
-case class Time(millis: Long) {
+/**
+ * This is a simple class that represents time. Internally, it represents time as UTC.
+ * The recommended way to create instances of Time is to use helper objects
+ * [[spark.streaming.Milliseconds]], [[spark.streaming.Seconds]], and [[spark.streaming.Minutes]].
+ * @param millis Time in UTC.
+ */
+
+case class Time(private val millis: Long) {
def < (that: Time): Boolean = (this.millis < that.millis)
-
+
def <= (that: Time): Boolean = (this.millis <= that.millis)
-
+
def > (that: Time): Boolean = (this.millis > that.millis)
def >= (that: Time): Boolean = (this.millis >= that.millis)
@@ -15,7 +22,9 @@ case class Time(millis: Long) {
def - (that: Time): Time = Time(millis - that.millis)
def * (times: Int): Time = Time(millis * times)
-
+
+ def / (that: Time): Long = millis / that.millis
+
def floor(that: Time): Time = {
val t = that.millis
val m = math.floor(this.millis / t).toLong
@@ -38,23 +47,33 @@ case class Time(millis: Long) {
def milliseconds: Long = millis
}
-object Time {
+private[streaming] object Time {
val zero = Time(0)
implicit def toTime(long: Long) = Time(long)
-
- implicit def toLong(time: Time) = time.milliseconds
}
+/**
+ * Helper object that creates instances of [[spark.streaming.Time]] representing
+ * a given number of milliseconds.
+ */
object Milliseconds {
def apply(milliseconds: Long) = Time(milliseconds)
}
+/**
+ * Helper object that creates instances of [[spark.streaming.Time]] representing
+ * a given number of seconds.
+ */
object Seconds {
def apply(seconds: Long) = Time(seconds * 1000)
-}
+}
-object Minutes {
+/**
+ * Helper object that creates instances of [[spark.streaming.Time]] representing
+ * a given number of minutes.
+ */
+object Minutes {
def apply(minutes: Long) = Time(minutes * 60000)
}
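For illustration, a small sketch of the Time helpers and operators defined above:

    val batch   = Seconds(10)
    val window  = Minutes(1)
    val ratio   = window / batch                       // 6 batches per window
    val aligned = Milliseconds(12345).floor(batch)     // rounds down to a multiple of 10 s
    println(aligned.milliseconds)                      // prints 10000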
diff --git a/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala
index 61d088eddb..bc23d423d3 100644
--- a/streaming/src/main/scala/spark/streaming/CoGroupedDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala
@@ -1,8 +1,10 @@
-package spark.streaming
+package spark.streaming.dstream
import spark.{RDD, Partitioner}
import spark.rdd.CoGroupedRDD
+import spark.streaming.{Time, DStream}
+private[streaming]
class CoGroupedDStream[K : ClassManifest](
parents: Seq[DStream[(_, _)]],
partitioner: Partitioner
diff --git a/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala
index 80150708fd..41c3af4694 100644
--- a/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala
@@ -1,6 +1,7 @@
-package spark.streaming
+package spark.streaming.dstream
import spark.RDD
+import spark.streaming.{Time, StreamingContext}
/**
* An input stream that always returns the same RDD on each timestep. Useful for testing.
diff --git a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala
index 88856364d2..1e6ad84b44 100644
--- a/streaming/src/main/scala/spark/streaming/FileInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala
@@ -1,7 +1,8 @@
-package spark.streaming
+package spark.streaming.dstream
import spark.RDD
import spark.rdd.UnionRDD
+import spark.streaming.{StreamingContext, Time}
import org.apache.hadoop.fs.{FileSystem, Path, PathFilter}
import org.apache.hadoop.conf.Configuration
@@ -9,11 +10,11 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import scala.collection.mutable.HashSet
-
+private[streaming]
class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest](
@transient ssc_ : StreamingContext,
directory: String,
- filter: PathFilter = FileInputDStream.defaultPathFilter,
+ filter: Path => Boolean = FileInputDStream.defaultFilter,
newFilesOnly: Boolean = true)
extends InputDStream[(K, V)](ssc_) {
@@ -59,7 +60,7 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K
val latestModTimeFiles = new HashSet[String]()
def accept(path: Path): Boolean = {
- if (!filter.accept(path)) {
+ if (!filter(path)) {
return false
} else {
val modTime = fs.getFileStatus(path).getModificationTime()
@@ -94,16 +95,8 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K
}
}
+private[streaming]
object FileInputDStream {
- val defaultPathFilter = new PathFilter with Serializable {
- def accept(path: Path): Boolean = {
- val file = path.getName()
- if (file.startsWith(".") || file.endsWith("_tmp")) {
- return false
- } else {
- return true
- }
- }
- }
+ def defaultFilter(path: Path): Boolean = !path.getName().startsWith(".")
}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala
new file mode 100644
index 0000000000..1cbb4d536e
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala
@@ -0,0 +1,21 @@
+package spark.streaming.dstream
+
+import spark.streaming.{DStream, Time}
+import spark.RDD
+
+private[streaming]
+class FilteredDStream[T: ClassManifest](
+ parent: DStream[T],
+ filterFunc: T => Boolean
+ ) extends DStream[T](parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[T]] = {
+ parent.getOrCompute(validTime).map(_.filter(filterFunc))
+ }
+}
+
+
diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala
new file mode 100644
index 0000000000..11ed8cf317
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala
@@ -0,0 +1,20 @@
+package spark.streaming.dstream
+
+import spark.streaming.{DStream, Time}
+import spark.RDD
+import spark.SparkContext._
+
+private[streaming]
+class FlatMapValuedDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest](
+ parent: DStream[(K, V)],
+ flatMapValueFunc: V => TraversableOnce[U]
+ ) extends DStream[(K, U)](parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[(K, U)]] = {
+ parent.getOrCompute(validTime).map(_.flatMapValues[U](flatMapValueFunc))
+ }
+}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala
new file mode 100644
index 0000000000..a13b4c9ff9
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala
@@ -0,0 +1,20 @@
+package spark.streaming.dstream
+
+import spark.streaming.{DStream, Time}
+import spark.RDD
+
+private[streaming]
+class FlatMappedDStream[T: ClassManifest, U: ClassManifest](
+ parent: DStream[T],
+ flatMapFunc: T => Traversable[U]
+ ) extends DStream[U](parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[U]] = {
+ parent.getOrCompute(validTime).map(_.flatMap(flatMapFunc))
+ }
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala
index 2959ce4540..ca70e72e56 100644
--- a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala
@@ -1,17 +1,23 @@
-package spark.streaming
+package spark.streaming.dstream
-import java.io.{ObjectInput, ObjectOutput, Externalizable}
+import spark.streaming.StreamingContext
+
+import spark.Utils
import spark.storage.StorageLevel
+
import org.apache.flume.source.avro.AvroSourceProtocol
import org.apache.flume.source.avro.AvroFlumeEvent
import org.apache.flume.source.avro.Status
import org.apache.avro.ipc.specific.SpecificResponder
import org.apache.avro.ipc.NettyServer
+
+import scala.collection.JavaConversions._
+
import java.net.InetSocketAddress
-import collection.JavaConversions._
-import spark.Utils
+import java.io.{ObjectInput, ObjectOutput, Externalizable}
import java.nio.ByteBuffer
+private[streaming]
class FlumeInputDStream[T: ClassManifest](
@transient ssc_ : StreamingContext,
host: String,
@@ -79,7 +85,7 @@ class SparkFlumeEvent() extends Externalizable {
}
}
-object SparkFlumeEvent {
+private[streaming] object SparkFlumeEvent {
def fromAvroFlumeEvent(in : AvroFlumeEvent) : SparkFlumeEvent = {
val event = new SparkFlumeEvent
event.event = in
@@ -88,41 +94,43 @@ object SparkFlumeEvent {
}
/** A simple server that implements Flume's Avro protocol. */
+private[streaming]
class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol {
override def append(event : AvroFlumeEvent) : Status = {
- receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event)
+ receiver.blockGenerator += SparkFlumeEvent.fromAvroFlumeEvent(event)
Status.OK
}
override def appendBatch(events : java.util.List[AvroFlumeEvent]) : Status = {
events.foreach (event =>
- receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event))
+ receiver.blockGenerator += SparkFlumeEvent.fromAvroFlumeEvent(event))
Status.OK
}
}
/** A NetworkReceiver which listens for events using the
* Flume Avro interface.*/
+private[streaming]
class FlumeReceiver(
- streamId: Int,
- host: String,
- port: Int,
- storageLevel: StorageLevel
- ) extends NetworkReceiver[SparkFlumeEvent](streamId) {
+ streamId: Int,
+ host: String,
+ port: Int,
+ storageLevel: StorageLevel
+ ) extends NetworkReceiver[SparkFlumeEvent](streamId) {
- lazy val dataHandler = new DataHandler(this, storageLevel)
+ lazy val blockGenerator = new BlockGenerator(storageLevel)
protected override def onStart() {
val responder = new SpecificResponder(
classOf[AvroSourceProtocol], new FlumeEventServer(this));
val server = new NettyServer(responder, new InetSocketAddress(host, port));
- dataHandler.start()
+ blockGenerator.start()
server.start()
logInfo("Flume receiver started")
}
protected override def onStop() {
- dataHandler.stop()
+ blockGenerator.stop()
logInfo("Flume receiver stopped")
}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala
new file mode 100644
index 0000000000..41c629a225
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala
@@ -0,0 +1,28 @@
+package spark.streaming.dstream
+
+import spark.RDD
+import spark.streaming.{DStream, Job, Time}
+
+private[streaming]
+class ForEachDStream[T: ClassManifest] (
+ parent: DStream[T],
+ foreachFunc: (RDD[T], Time) => Unit
+ ) extends DStream[Unit](parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[Unit]] = None
+
+ override def generateJob(time: Time): Option[Job] = {
+ parent.getOrCompute(time) match {
+ case Some(rdd) =>
+ val jobFunc = () => {
+ foreachFunc(rdd, time)
+ }
+ Some(new Job(time, jobFunc))
+ case None => None
+ }
+ }
+}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala
new file mode 100644
index 0000000000..92ea503cae
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala
@@ -0,0 +1,17 @@
+package spark.streaming.dstream
+
+import spark.streaming.{DStream, Time}
+import spark.RDD
+
+private[streaming]
+class GlommedDStream[T: ClassManifest](parent: DStream[T])
+ extends DStream[Array[T]](parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[Array[T]]] = {
+ parent.getOrCompute(validTime).map(_.glom())
+ }
+}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala
new file mode 100644
index 0000000000..4959c66b06
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala
@@ -0,0 +1,19 @@
+package spark.streaming.dstream
+
+import spark.streaming.{StreamingContext, DStream}
+
+abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext)
+ extends DStream[T](ssc_) {
+
+ override def dependencies = List()
+
+ override def slideTime = {
+ if (ssc == null) throw new Exception("ssc is null")
+ if (ssc.graph.batchDuration == null) throw new Exception("batchDuration is null")
+ ssc.graph.batchDuration
+ }
+
+ def start()
+
+ def stop()
+}
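InputDStream pins slideTime to the graph's batch duration and leaves start()/stop() to concrete sources. As a hedged illustration of that contract (this class is not part of the patch), a trivial source that replays one prebuilt RDD on every batch could look like:

```scala
// Illustrative only: the smallest possible concrete InputDStream. start() and
// stop() are no-ops because there is no external source to manage.
import spark.RDD
import spark.streaming.{StreamingContext, Time}
import spark.streaming.dstream.InputDStream

class ConstantInputDStream[T: ClassManifest](ssc_ : StreamingContext, rdd: RDD[T])
  extends InputDStream[T](ssc_) {

  override def start() {}
  override def stop() {}

  override def compute(validTime: Time): Option[RDD[T]] = Some(rdd)
}
```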
diff --git a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala
index 7c642d4802..25988a2ce7 100644
--- a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala
@@ -1,30 +1,36 @@
-package spark.streaming
+package spark.streaming.dstream
+
+import spark.Logging
+import spark.storage.StorageLevel
+import spark.streaming.{Time, DStreamCheckpointData, StreamingContext}
import java.util.Properties
import java.util.concurrent.Executors
+
import kafka.consumer._
import kafka.message.{Message, MessageSet, MessageAndMetadata}
import kafka.serializer.StringDecoder
import kafka.utils.{Utils, ZKGroupTopicDirs}
import kafka.utils.ZkUtils._
+
import scala.collection.mutable.HashMap
import scala.collection.JavaConversions._
-import spark._
-import spark.RDD
-import spark.storage.StorageLevel
+
// Key for a specific Kafka Partition: (broker, topic, group, part)
case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int)
// NOT USED - Originally intended for fault-tolerance
// Metadata for a Kafka Stream that is sent to the Master
+private[streaming]
case class KafkaInputDStreamMetadata(timestamp: Long, data: Map[KafkaPartitionKey, Long])
// NOT USED - Originally intended for fault-tolerance
// Checkpoint data specific to a KafkaInputDstream
-case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any],
+private[streaming]
+case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any],
savedOffsets: Map[KafkaPartitionKey, Long]) extends DStreamCheckpointData(kafkaRdds)
/**
- * Input stream that pulls messages form a Kafka Broker.
+ * Input stream that pulls messages from a Kafka Broker.
*
* @param host Zookeeper hostname.
* @param port Zookeeper port.
@@ -35,6 +41,7 @@ case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any],
* By default the value is pulled from Zookeeper.
* @param storageLevel RDD storage level.
*/
+private[streaming]
class KafkaInputDStream[T: ClassManifest](
@transient ssc_ : StreamingContext,
host: String,
@@ -94,6 +101,7 @@ class KafkaInputDStream[T: ClassManifest](
}
}
+private[streaming]
class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String,
topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long],
storageLevel: StorageLevel) extends NetworkReceiver[Any](streamId) {
@@ -102,20 +110,19 @@ class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String,
val ZK_TIMEOUT = 10000
// Handles pushing data into the BlockManager
- lazy protected val dataHandler = new DataHandler(this, storageLevel)
+ lazy protected val blockGenerator = new BlockGenerator(storageLevel)
// Keeps track of the current offsets. Maps from (broker, topic, group, part) -> Offset
lazy val offsets = HashMap[KafkaPartitionKey, Long]()
// Connection to Kafka
var consumerConnector : ZookeeperConsumerConnector = null
def onStop() {
- dataHandler.stop()
+ blockGenerator.stop()
}
def onStart() {
- // Starting the DataHandler that buffers blocks and pushes them into them BlockManager
- dataHandler.start()
+ blockGenerator.start()
// In case we are using multiple Threads to handle Kafka Messages
val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _))
@@ -163,8 +170,8 @@ class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String,
private class MessageHandler(stream: KafkaStream[String]) extends Runnable {
def run() {
logInfo("Starting MessageHandler.")
- stream.takeWhile { msgAndMetadata =>
- dataHandler += msgAndMetadata.message
+ stream.takeWhile { msgAndMetadata =>
+ blockGenerator += msgAndMetadata.message
// Updating the offset. The key is (broker, topic, group, partition).
val key = KafkaPartitionKey(msgAndMetadata.topicInfo.brokerId, msgAndMetadata.topic,
@@ -181,7 +188,7 @@ class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String,
// NOT USED - Originally intended for fault-tolerance
// class KafkaDataHandler(receiver: KafkaReceiver, storageLevel: StorageLevel)
- // extends DataHandler[Any](receiver, storageLevel) {
+ // extends BufferingBlockCreator[Any](receiver, storageLevel) {
// override def createBlock(blockId: String, iterator: Iterator[Any]) : Block = {
// // Creates a new Block with Kafka-specific Metadata
diff --git a/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala
new file mode 100644
index 0000000000..daf78c6893
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala
@@ -0,0 +1,21 @@
+package spark.streaming.dstream
+
+import spark.streaming.{DStream, Time}
+import spark.RDD
+
+private[streaming]
+class MapPartitionedDStream[T: ClassManifest, U: ClassManifest](
+ parent: DStream[T],
+ mapPartFunc: Iterator[T] => Iterator[U],
+ preservePartitioning: Boolean
+ ) extends DStream[U](parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[U]] = {
+ parent.getOrCompute(validTime).map(_.mapPartitions[U](mapPartFunc, preservePartitioning))
+ }
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala
new file mode 100644
index 0000000000..689caeef0e
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala
@@ -0,0 +1,21 @@
+package spark.streaming.dstream
+
+import spark.streaming.{DStream, Time}
+import spark.RDD
+import spark.SparkContext._
+
+private[streaming]
+class MapValuedDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest](
+ parent: DStream[(K, V)],
+ mapValueFunc: V => U
+ ) extends DStream[(K, U)](parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[(K, U)]] = {
+ parent.getOrCompute(validTime).map(_.mapValues[U](mapValueFunc))
+ }
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala
new file mode 100644
index 0000000000..786b9966f2
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala
@@ -0,0 +1,20 @@
+package spark.streaming.dstream
+
+import spark.streaming.{DStream, Time}
+import spark.RDD
+
+private[streaming]
+class MappedDStream[T: ClassManifest, U: ClassManifest] (
+ parent: DStream[T],
+ mapFunc: T => U
+ ) extends DStream[U](parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[U]] = {
+ parent.getOrCompute(validTime).map(_.map[U](mapFunc))
+ }
+}
+
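MappedDStream, MapValuedDStream and MapPartitionedDStream all follow the same shape: keep the parent, report its slideTime, and apply the matching RDD operation inside compute. The public DStream methods are not shown in this patch, but map presumably delegates along these lines (a sketch that would live inside the spark.streaming package, since MappedDStream is private[streaming]):

```scala
// Sketch under assumption: DStream.map simply wraps the stream in a MappedDStream.
import spark.streaming.DStream
import spark.streaming.dstream.MappedDStream

def mapStream[T: ClassManifest, U: ClassManifest](stream: DStream[T])(f: T => U): DStream[U] =
  new MappedDStream(stream, f)
```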
diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala
index 4e4e9fc942..18e62a0e33 100644
--- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala
@@ -1,18 +1,21 @@
-package spark.streaming
+package spark.streaming.dstream
-import scala.collection.mutable.ArrayBuffer
+import spark.streaming.{Time, StreamingContext, AddBlocks, RegisterReceiver, DeregisterReceiver}
import spark.{Logging, SparkEnv, RDD}
import spark.rdd.BlockRDD
-import spark.streaming.util.{RecurringTimer, SystemClock}
import spark.storage.StorageLevel
+import scala.collection.mutable.ArrayBuffer
+
import java.nio.ByteBuffer
import akka.actor.{Props, Actor}
import akka.pattern.ask
import akka.dispatch.Await
import akka.util.duration._
+import spark.streaming.util.{RecurringTimer, SystemClock}
+import java.util.concurrent.ArrayBlockingQueue
abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : StreamingContext)
extends InputDStream[T](ssc_) {
@@ -40,10 +43,10 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming
}
-sealed trait NetworkReceiverMessage
-case class StopReceiver(msg: String) extends NetworkReceiverMessage
-case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage
-case class ReportError(msg: String) extends NetworkReceiverMessage
+private[streaming] sealed trait NetworkReceiverMessage
+private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage
+private[streaming] case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage
+private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage
abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Serializable with Logging {
@@ -153,4 +156,77 @@ abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Seri
tracker ! DeregisterReceiver(streamId, msg)
}
}
+
+ /**
+ * Batches objects created by a [[spark.streaming.dstream.NetworkReceiver]] and puts them into
+ * appropriately named blocks at regular intervals. This class starts two threads,
+ * one to periodically start a new batch and prepare the previous batch as a block,
+ * the other to push the blocks into the block manager.
+ */
+ class BlockGenerator(storageLevel: StorageLevel)
+ extends Serializable with Logging {
+
+ case class Block(id: String, iterator: Iterator[T], metadata: Any = null)
+
+ val clock = new SystemClock()
+ val blockInterval = 200L
+ val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer)
+ val blockStorageLevel = storageLevel
+ val blocksForPushing = new ArrayBlockingQueue[Block](1000)
+ val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } }
+
+ var currentBuffer = new ArrayBuffer[T]
+
+ def start() {
+ blockIntervalTimer.start()
+ blockPushingThread.start()
+ logInfo("Data handler started")
+ }
+
+ def stop() {
+ blockIntervalTimer.stop()
+ blockPushingThread.interrupt()
+ logInfo("Data handler stopped")
+ }
+
+ def += (obj: T) {
+ currentBuffer += obj
+ }
+
+ private def createBlock(blockId: String, iterator: Iterator[T]) : Block = {
+ new Block(blockId, iterator)
+ }
+
+ private def updateCurrentBuffer(time: Long) {
+ try {
+ val newBlockBuffer = currentBuffer
+ currentBuffer = new ArrayBuffer[T]
+ if (newBlockBuffer.size > 0) {
+ val blockId = "input-" + NetworkReceiver.this.streamId + "- " + (time - blockInterval)
+ val newBlock = createBlock(blockId, newBlockBuffer.toIterator)
+ blocksForPushing.add(newBlock)
+ }
+ } catch {
+ case ie: InterruptedException =>
+ logInfo("Block interval timer thread interrupted")
+ case e: Exception =>
+ NetworkReceiver.this.stop()
+ }
+ }
+
+ private def keepPushingBlocks() {
+ logInfo("Block pushing thread started")
+ try {
+ while(true) {
+ val block = blocksForPushing.take()
+ NetworkReceiver.this.pushBlock(block.id, block.iterator, block.metadata, storageLevel)
+ }
+ } catch {
+ case ie: InterruptedException =>
+ logInfo("Block pushing thread interrupted")
+ case e: Exception =>
+ NetworkReceiver.this.stop()
+ }
+ }
+ }
}
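BlockGenerator replaces the old DataHandler: receivers append single objects with +=, the timer thread cuts a new block every blockInterval milliseconds, and the pushing thread hands completed blocks to pushBlock. The Flume and socket receivers in this patch use it exactly that way; as a hedged sketch of the same pattern (the receiver below is hypothetical, and other NetworkReceiver members are assumed to have usable defaults):

```scala
// Hypothetical receiver showing the BlockGenerator usage pattern from this
// patch: start the generator in onStart, feed it objects as they arrive, and
// stop it in onStop.
import spark.storage.StorageLevel
import spark.streaming.dstream.NetworkReceiver

class IteratorReceiver[T: ClassManifest](
    streamId: Int,
    makeIterator: () => Iterator[T],
    storageLevel: StorageLevel
  ) extends NetworkReceiver[T](streamId) {

  lazy val blockGenerator = new BlockGenerator(storageLevel)

  protected def onStart() {
    blockGenerator.start()                       // launches the block-interval timer and the pushing thread
    makeIterator().foreach(blockGenerator += _)  // objects are grouped into blocks every blockInterval ms
  }

  protected def onStop() {
    blockGenerator.stop()
  }
}
```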
diff --git a/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala
index bb86e51932..024bf3bea4 100644
--- a/streaming/src/main/scala/spark/streaming/QueueInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala
@@ -1,10 +1,11 @@
-package spark.streaming
+package spark.streaming.dstream
import spark.RDD
import spark.rdd.UnionRDD
import scala.collection.mutable.Queue
import scala.collection.mutable.ArrayBuffer
+import spark.streaming.{Time, StreamingContext}
class QueueInputDStream[T: ClassManifest](
@transient ssc: StreamingContext,
diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala
index 6acaa9aab1..aa2f31cea8 100644
--- a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala
@@ -1,12 +1,15 @@
-package spark.streaming
+package spark.streaming.dstream
+
+import spark.{DaemonThread, Logging}
+import spark.storage.StorageLevel
+import spark.streaming.StreamingContext
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.{ReadableByteChannel, SocketChannel}
import java.io.EOFException
import java.util.concurrent.ArrayBlockingQueue
-import spark._
-import spark.storage.StorageLevel
+
/**
* An input stream that reads blocks of serialized objects from a given network address.
@@ -14,6 +17,7 @@ import spark.storage.StorageLevel
* data into Spark Streaming, though it requires the sender to batch data and serialize it
* in the format that the system is configured with.
*/
+private[streaming]
class RawInputDStream[T: ClassManifest](
@transient ssc_ : StreamingContext,
host: String,
@@ -26,6 +30,7 @@ class RawInputDStream[T: ClassManifest](
}
}
+private[streaming]
class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: StorageLevel)
extends NetworkReceiver[Any](streamId) {
diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala
index f63a9e0011..d289ed2a3f 100644
--- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala
@@ -1,17 +1,17 @@
-package spark.streaming
+package spark.streaming.dstream
import spark.streaming.StreamingContext._
import spark.RDD
-import spark.rdd.UnionRDD
import spark.rdd.CoGroupedRDD
import spark.Partitioner
import spark.SparkContext._
import spark.storage.StorageLevel
import scala.collection.mutable.ArrayBuffer
-import collection.SeqProxy
+import spark.streaming.{Interval, Time, DStream}
+private[streaming]
class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest](
parent: DStream[(K, V)],
reduceFunc: (V, V) => V,
diff --git a/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala
new file mode 100644
index 0000000000..6854bbe665
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala
@@ -0,0 +1,27 @@
+package spark.streaming.dstream
+
+import spark.{RDD, Partitioner}
+import spark.SparkContext._
+import spark.streaming.{DStream, Time}
+
+private[streaming]
+class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest](
+ parent: DStream[(K,V)],
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiner: (C, C) => C,
+ partitioner: Partitioner
+ ) extends DStream [(K,C)] (parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[(K,C)]] = {
+ parent.getOrCompute(validTime) match {
+ case Some(rdd) =>
+ Some(rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner, partitioner))
+ case None => None
+ }
+ }
+}
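ShuffledDStream is the per-batch analogue of RDD.combineByKey, and is presumably what the key-value DStream operations build on. For instance, a reduceByKey over (word, count) pairs corresponds to a combine in which the create, merge and combine steps are all the reduce function. A hedged sketch (only ShuffledDStream itself comes from this patch, and it would have to be used from inside the spark.streaming package because it is private[streaming]):

```scala
// Sketch: expressing reduceByKey in terms of ShuffledDStream.
import spark.HashPartitioner
import spark.streaming.DStream
import spark.streaming.dstream.ShuffledDStream

def reduceByKeySketch[K: ClassManifest, V: ClassManifest](
    pairs: DStream[(K, V)],
    reduceFunc: (V, V) => V,
    numPartitions: Int): DStream[(K, V)] = {
  new ShuffledDStream[K, V, V](
    pairs,
    (v: V) => v,   // createCombiner: the first value becomes the combiner
    reduceFunc,    // mergeValue: fold a new value into the combiner
    reduceFunc,    // mergeCombiner: merge two partial results
    new HashPartitioner(numPartitions))
}
```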
diff --git a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala
index a9e37c0ff0..8e4b20ea4c 100644
--- a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala
@@ -1,15 +1,12 @@
-package spark.streaming
+package spark.streaming.dstream
-import spark.streaming.util.{RecurringTimer, SystemClock}
+import spark.streaming.StreamingContext
import spark.storage.StorageLevel
import java.io._
import java.net.Socket
-import java.util.concurrent.ArrayBlockingQueue
-
-import scala.collection.mutable.ArrayBuffer
-import scala.Serializable
+private[streaming]
class SocketInputDStream[T: ClassManifest](
@transient ssc_ : StreamingContext,
host: String,
@@ -23,7 +20,7 @@ class SocketInputDStream[T: ClassManifest](
}
}
-
+private[streaming]
class SocketReceiver[T: ClassManifest](
streamId: Int,
host: String,
@@ -32,7 +29,7 @@ class SocketReceiver[T: ClassManifest](
storageLevel: StorageLevel
) extends NetworkReceiver[T](streamId) {
- lazy protected val dataHandler = new DataHandler(this, storageLevel)
+ lazy protected val blockGenerator = new BlockGenerator(storageLevel)
override def getLocationPreference = None
@@ -40,21 +37,21 @@ class SocketReceiver[T: ClassManifest](
logInfo("Connecting to " + host + ":" + port)
val socket = new Socket(host, port)
logInfo("Connected to " + host + ":" + port)
- dataHandler.start()
+ blockGenerator.start()
val iterator = bytesToObjects(socket.getInputStream())
while(iterator.hasNext) {
val obj = iterator.next
- dataHandler += obj
+ blockGenerator += obj
}
}
protected def onStop() {
- dataHandler.stop()
+ blockGenerator.stop()
}
}
-
+private[streaming]
object SocketReceiver {
/**
diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala
index b7e4c1c30c..175b3060c1 100644
--- a/streaming/src/main/scala/spark/streaming/StateDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala
@@ -1,12 +1,12 @@
-package spark.streaming
+package spark.streaming.dstream
import spark.RDD
-import spark.rdd.BlockRDD
import spark.Partitioner
-import spark.rdd.MapPartitionsRDD
import spark.SparkContext._
import spark.storage.StorageLevel
+import spark.streaming.{Time, DStream}
+private[streaming]
class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest](
parent: DStream[(K, V)],
updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
diff --git a/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala
new file mode 100644
index 0000000000..0337579514
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala
@@ -0,0 +1,19 @@
+package spark.streaming.dstream
+
+import spark.RDD
+import spark.streaming.{DStream, Time}
+
+private[streaming]
+class TransformedDStream[T: ClassManifest, U: ClassManifest] (
+ parent: DStream[T],
+ transformFunc: (RDD[T], Time) => RDD[U]
+ ) extends DStream[U](parent.ssc) {
+
+ override def dependencies = List(parent)
+
+ override def slideTime: Time = parent.slideTime
+
+ override def compute(validTime: Time): Option[RDD[U]] = {
+ parent.getOrCompute(validTime).map(transformFunc(_, validTime))
+ }
+}
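TransformedDStream generalizes the single-purpose wrappers above: the caller supplies an arbitrary RDD-to-RDD function that also sees the batch time. A hedged usage sketch (it would have to live inside the spark.streaming package, since the class is private[streaming], and sortByKey here is only an example transformation):

```scala
// Sketch: re-sort every batch of (word, count) pairs by count using an
// arbitrary per-batch RDD transformation.
import spark.RDD
import spark.SparkContext._
import spark.streaming.{DStream, Time}
import spark.streaming.dstream.TransformedDStream

def sortedByCount(counts: DStream[(String, Int)]): DStream[(String, Int)] =
  new TransformedDStream(counts, (rdd: RDD[(String, Int)], time: Time) =>
    rdd.map(_.swap).sortByKey(false).map(_.swap))
```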
diff --git a/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala
new file mode 100644
index 0000000000..3bf4c2ecea
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala
@@ -0,0 +1,40 @@
+package spark.streaming.dstream
+
+import spark.streaming.{DStream, Time}
+import spark.RDD
+import collection.mutable.ArrayBuffer
+import spark.rdd.UnionRDD
+
+private[streaming]
+class UnionDStream[T: ClassManifest](parents: Array[DStream[T]])
+ extends DStream[T](parents.head.ssc) {
+
+ if (parents.length == 0) {
+ throw new IllegalArgumentException("Empty array of parents")
+ }
+
+ if (parents.map(_.ssc).distinct.size > 1) {
+ throw new IllegalArgumentException("Array of parents have different StreamingContexts")
+ }
+
+ if (parents.map(_.slideTime).distinct.size > 1) {
+ throw new IllegalArgumentException("Array of parents have different slide times")
+ }
+
+ override def dependencies = parents.toList
+
+ override def slideTime: Time = parents.head.slideTime
+
+ override def compute(validTime: Time): Option[RDD[T]] = {
+ val rdds = new ArrayBuffer[RDD[T]]()
+ parents.map(_.getOrCompute(validTime)).foreach(_ match {
+ case Some(rdd) => rdds += rdd
+ case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime)
+ })
+ if (rdds.size > 0) {
+ Some(new UnionRDD(ssc.sc, rdds))
+ } else {
+ None
+ }
+ }
+}
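UnionDStream is the only multi-parent DStream added here: the guards require that all parents share a StreamingContext and slide time, and compute fails the batch if any parent cannot produce an RDD. The raw-stream examples removed later in this patch used it directly; the fragment below is patterned on them (private[streaming] access aside):

```scala
// Patterned on the deleted GrepRaw/WordCountRaw examples: union several raw
// network streams into a single DStream.
import spark.storage.StorageLevel
import spark.streaming.StreamingContext
import spark.streaming.dstream.UnionDStream

def unionOfRawStreams(ssc: StreamingContext, numStreams: Int, host: String, port: Int) = {
  val rawStreams = (1 to numStreams).map(_ =>
    ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray
  new UnionDStream(rawStreams)
}
```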
diff --git a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala
index e4d2a634f5..7718794cbf 100644
--- a/streaming/src/main/scala/spark/streaming/WindowedDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala
@@ -1,10 +1,11 @@
-package spark.streaming
+package spark.streaming.dstream
import spark.RDD
import spark.rdd.UnionRDD
import spark.storage.StorageLevel
+import spark.streaming.{Interval, Time, DStream}
-
+private[streaming]
class WindowedDStream[T: ClassManifest](
parent: DStream[T],
_windowTime: Time,
diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStream.scala b/streaming/src/main/scala/spark/streaming/examples/FileStream.scala
deleted file mode 100644
index 81938d30d4..0000000000
--- a/streaming/src/main/scala/spark/streaming/examples/FileStream.scala
+++ /dev/null
@@ -1,46 +0,0 @@
-package spark.streaming.examples
-
-import spark.streaming.StreamingContext
-import spark.streaming.StreamingContext._
-import spark.streaming.Seconds
-import org.apache.hadoop.fs.Path
-import org.apache.hadoop.conf.Configuration
-
-
-object FileStream {
- def main(args: Array[String]) {
- if (args.length < 2) {
- System.err.println("Usage: FileStream <master> <new HDFS compatible directory>")
- System.exit(1)
- }
-
- // Create the context
- val ssc = new StreamingContext(args(0), "FileStream", Seconds(1))
-
- // Create the new directory
- val directory = new Path(args(1))
- val fs = directory.getFileSystem(new Configuration())
- if (fs.exists(directory)) throw new Exception("This directory already exists")
- fs.mkdirs(directory)
- fs.deleteOnExit(directory)
-
- // Create the FileInputDStream on the directory and use the
- // stream to count words in new files created
- val inputStream = ssc.textFileStream(directory.toString)
- val words = inputStream.flatMap(_.split(" "))
- val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
- wordCounts.print()
- ssc.start()
-
- // Creating new files in the directory
- val text = "This is a text file"
- for (i <- 1 to 30) {
- ssc.sc.parallelize((1 to (i * 10)).map(_ => text), 10)
- .saveAsTextFile(new Path(directory, i.toString).toString)
- Thread.sleep(1000)
- }
- Thread.sleep(5000) // Waiting for the file to be processed
- ssc.stop()
- System.exit(0)
- }
-} \ No newline at end of file
diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala
deleted file mode 100644
index b7bc15a1d5..0000000000
--- a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala
+++ /dev/null
@@ -1,75 +0,0 @@
-package spark.streaming.examples
-
-import spark.streaming._
-import spark.streaming.StreamingContext._
-import org.apache.hadoop.fs.Path
-import org.apache.hadoop.conf.Configuration
-
-object FileStreamWithCheckpoint {
-
- def main(args: Array[String]) {
-
- if (args.size != 3) {
- println("FileStreamWithCheckpoint <master> <directory> <checkpoint dir>")
- println("FileStreamWithCheckpoint restart <directory> <checkpoint dir>")
- System.exit(-1)
- }
-
- val directory = new Path(args(1))
- val checkpointDir = args(2)
-
- val ssc: StreamingContext = {
-
- if (args(0) == "restart") {
-
- // Recreated streaming context from specified checkpoint file
- new StreamingContext(checkpointDir)
-
- } else {
-
- // Create directory if it does not exist
- val fs = directory.getFileSystem(new Configuration())
- if (!fs.exists(directory)) fs.mkdirs(directory)
-
- // Create new streaming context
- val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint", Seconds(1))
- ssc_.checkpoint(checkpointDir)
-
- // Setup the streaming computation
- val inputStream = ssc_.textFileStream(directory.toString)
- val words = inputStream.flatMap(_.split(" "))
- val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
- wordCounts.print()
-
- ssc_
- }
- }
-
- // Start the stream computation
- startFileWritingThread(directory.toString)
- ssc.start()
- }
-
- def startFileWritingThread(directory: String) {
-
- val fs = new Path(directory).getFileSystem(new Configuration())
-
- val fileWritingThread = new Thread() {
- override def run() {
- val r = new scala.util.Random()
- val text = "This is a sample text file with a random number "
- while(true) {
- val number = r.nextInt()
- val file = new Path(directory, number.toString)
- val fos = fs.create(file)
- fos.writeChars(text + number)
- fos.close()
- println("Created text file " + file)
- Thread.sleep(1000)
- }
- }
- }
- fileWritingThread.start()
- }
-
-}
diff --git a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala b/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala
deleted file mode 100644
index 6cb2b4c042..0000000000
--- a/streaming/src/main/scala/spark/streaming/examples/GrepRaw.scala
+++ /dev/null
@@ -1,33 +0,0 @@
-package spark.streaming.examples
-
-import spark.util.IntParam
-import spark.storage.StorageLevel
-
-import spark.streaming._
-import spark.streaming.StreamingContext._
-import spark.streaming.util.RawTextHelper._
-
-object GrepRaw {
- def main(args: Array[String]) {
- if (args.length != 5) {
- System.err.println("Usage: GrepRaw <master> <numStreams> <host> <port> <batchMillis>")
- System.exit(1)
- }
-
- val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args
-
- // Create the context
- val ssc = new StreamingContext(master, "GrepRaw", Milliseconds(batchMillis))
-
- // Warm up the JVMs on master and slave for JIT compilation to kick in
- warmUp(ssc.sc)
-
-
- val rawStreams = (1 to numStreams).map(_ =>
- ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray
- val union = new UnionDStream(rawStreams)
- union.filter(_.contains("Alice")).count().foreachRDD(r =>
- println("Grep count: " + r.collect().mkString))
- ssc.start()
- }
-}
diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala
deleted file mode 100644
index fe4c2bf155..0000000000
--- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala
+++ /dev/null
@@ -1,49 +0,0 @@
-package spark.streaming.examples
-
-import spark.storage.StorageLevel
-import spark.util.IntParam
-
-import spark.streaming._
-import spark.streaming.StreamingContext._
-import spark.streaming.util.RawTextHelper._
-
-import java.util.UUID
-
-object TopKWordCountRaw {
-
- def main(args: Array[String]) {
- if (args.length != 4) {
- System.err.println("Usage: WordCountRaw <master> <# streams> <port> <HDFS checkpoint directory> ")
- System.exit(1)
- }
-
- val Array(master, IntParam(numStreams), IntParam(port), checkpointDir) = args
- val k = 10
-
- // Create the context, and set the checkpoint directory.
- // Checkpoint directory is necessary for achieving fault-tolerance, by saving counts
- // periodically to HDFS
- val ssc = new StreamingContext(master, "TopKWordCountRaw", Seconds(1))
- ssc.checkpoint(checkpointDir + "/" + UUID.randomUUID.toString, Seconds(1))
-
- // Warm up the JVMs on master and slave for JIT compilation to kick in
- /*warmUp(ssc.sc)*/
-
- // Set up the raw network streams that will connect to localhost:port to raw test
- // senders on the slaves and generate top K words of last 30 seconds
- val lines = (1 to numStreams).map(_ => {
- ssc.rawNetworkStream[String]("localhost", port, StorageLevel.MEMORY_ONLY_SER_2)
- })
- val union = new UnionDStream(lines.toArray)
- val counts = union.mapPartitions(splitAndCountPartitions)
- val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10)
- val partialTopKWindowedCounts = windowedCounts.mapPartitions(topK(_, k))
- partialTopKWindowedCounts.foreachRDD(rdd => {
- val collectedCounts = rdd.collect
- println("Collected " + collectedCounts.size + " words from partial top words")
- println("Top " + k + " words are " + topK(collectedCounts.toIterator, k).mkString(","))
- })
-
- ssc.start()
- }
-}
diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala
deleted file mode 100644
index 867a8f42c4..0000000000
--- a/streaming/src/main/scala/spark/streaming/examples/WordCountHdfs.scala
+++ /dev/null
@@ -1,25 +0,0 @@
-package spark.streaming.examples
-
-import spark.streaming.{Seconds, StreamingContext}
-import spark.streaming.StreamingContext._
-
-object WordCountHdfs {
- def main(args: Array[String]) {
- if (args.length < 2) {
- System.err.println("Usage: WordCountHdfs <master> <directory>")
- System.exit(1)
- }
-
- // Create the context
- val ssc = new StreamingContext(args(0), "WordCountHdfs", Seconds(2))
-
- // Create the FileInputDStream on the directory and use the
- // stream to count words in new files created
- val lines = ssc.textFileStream(args(1))
- val words = lines.flatMap(_.split(" "))
- val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
- wordCounts.print()
- ssc.start()
- }
-}
-
diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala
deleted file mode 100644
index a29c81d437..0000000000
--- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala
+++ /dev/null
@@ -1,43 +0,0 @@
-package spark.streaming.examples
-
-import spark.storage.StorageLevel
-import spark.util.IntParam
-
-import spark.streaming._
-import spark.streaming.StreamingContext._
-import spark.streaming.util.RawTextHelper._
-
-import java.util.UUID
-
-object WordCountRaw {
-
- def main(args: Array[String]) {
- if (args.length != 4) {
- System.err.println("Usage: WordCountRaw <master> <# streams> <port> <HDFS checkpoint directory> ")
- System.exit(1)
- }
-
- val Array(master, IntParam(numStreams), IntParam(port), checkpointDir) = args
-
- // Create the context, and set the checkpoint directory.
- // Checkpoint directory is necessary for achieving fault-tolerance, by saving counts
- // periodically to HDFS
- val ssc = new StreamingContext(master, "WordCountRaw", Seconds(1))
- ssc.checkpoint(checkpointDir + "/" + UUID.randomUUID.toString, Seconds(1))
-
- // Warm up the JVMs on master and slave for JIT compilation to kick in
- warmUp(ssc.sc)
-
- // Set up the raw network streams that will connect to localhost:port to raw test
- // senders on the slaves and generate count of words of last 30 seconds
- val lines = (1 to numStreams).map(_ => {
- ssc.rawNetworkStream[String]("localhost", port, StorageLevel.MEMORY_ONLY_SER_2)
- })
- val union = new UnionDStream(lines.toArray)
- val counts = union.mapPartitions(splitAndCountPartitions)
- val windowedCounts = counts.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Seconds(1), 10)
- windowedCounts.foreachRDD(r => println("# unique words = " + r.count()))
-
- ssc.start()
- }
-}
diff --git a/streaming/src/main/scala/spark/streaming/util/Clock.scala b/streaming/src/main/scala/spark/streaming/util/Clock.scala
index ed087e4ea8..974651f9f6 100644
--- a/streaming/src/main/scala/spark/streaming/util/Clock.scala
+++ b/streaming/src/main/scala/spark/streaming/util/Clock.scala
@@ -1,13 +1,12 @@
package spark.streaming.util
-import spark.streaming._
-
-trait Clock {
+private[streaming]
+trait Clock {
def currentTime(): Long
def waitTillTime(targetTime: Long): Long
}
-
+private[streaming]
class SystemClock() extends Clock {
val minPollTime = 25L
@@ -54,6 +53,7 @@ class SystemClock() extends Clock {
}
}
+private[streaming]
class ManualClock() extends Clock {
var time = 0L
diff --git a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala
index dc55fd902b..db715cc295 100644
--- a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala
+++ b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala
@@ -1,5 +1,6 @@
package spark.streaming.util
+private[streaming]
class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => Unit) {
val minPollTime = 25L
@@ -53,6 +54,7 @@ class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) =>
}
}
+private[streaming]
object RecurringTimer {
def main(args: Array[String]) {
diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties
index 02fe16866e..edfa1243fa 100644
--- a/streaming/src/test/resources/log4j.properties
+++ b/streaming/src/test/resources/log4j.properties
@@ -1,8 +1,11 @@
-# Set everything to be logged to the console
-log4j.rootCategory=WARN, console
-log4j.appender.console=org.apache.log4j.ConsoleAppender
-log4j.appender.console.layout=org.apache.log4j.PatternLayout
-log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
+# Set everything to be logged to the file streaming/target/unit-tests.log
+log4j.rootCategory=INFO, file
+log4j.appender.file=org.apache.log4j.FileAppender
+log4j.appender.file.append=false
+log4j.appender.file.file=streaming/target/unit-tests.log
+log4j.appender.file.layout=org.apache.log4j.PatternLayout
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.eclipse.jetty=WARN
+
diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
index 0d82b2f1ea..920388bba9 100644
--- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
@@ -42,7 +42,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
val stateStreamCheckpointInterval = Seconds(1)
// this ensures checkpointing occurs at least once
- val firstNumBatches = (stateStreamCheckpointInterval.millis / batchDuration.millis) * 2
+ val firstNumBatches = (stateStreamCheckpointInterval / batchDuration) * 2
val secondNumBatches = firstNumBatches
// Setup the streams
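Dividing the Time values directly instead of their raw millisecond fields (here and in WindowOperationsSuite below) implies that Time now carries its own arithmetic, and the FailureSuite change below shows millis being renamed to milliseconds. A hedged reconstruction of the operators these tests appear to rely on (the real spark.streaming.Time may well differ):

```scala
// Hedged sketch, not the actual spark.streaming.Time: just enough arithmetic
// for expressions like (slideTime / batchDuration).toInt and
// batchDuration.milliseconds * n to behave the way the tests use them.
case class Time(milliseconds: Long) {
  def + (that: Time): Time = Time(milliseconds + that.milliseconds)
  def - (that: Time): Time = Time(milliseconds - that.milliseconds)
  def * (times: Int): Time = Time(milliseconds * times)
  def / (that: Time): Long = milliseconds / that.milliseconds  // ratio of two durations
}
```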
diff --git a/streaming/src/test/scala/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/spark/streaming/FailureSuite.scala
index 5b414117fc..4aa428bf64 100644
--- a/streaming/src/test/scala/spark/streaming/FailureSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/FailureSuite.scala
@@ -133,7 +133,7 @@ class FailureSuite extends TestSuiteBase with BeforeAndAfter {
// Get the output buffer
val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]]
val output = outputStream.output
- val waitTime = (batchDuration.millis * (numBatches.toDouble + 0.5)).toLong
+ val waitTime = (batchDuration.milliseconds * (numBatches.toDouble + 0.5)).toLong
val startTime = System.currentTimeMillis()
try {
diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
index ed9a659092..e71ba6ddc1 100644
--- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
@@ -1,5 +1,6 @@
package spark.streaming
+import dstream.SparkFlumeEvent
import java.net.{InetSocketAddress, SocketException, Socket, ServerSocket}
import java.io.{File, BufferedWriter, OutputStreamWriter}
import java.util.concurrent.{TimeUnit, ArrayBlockingQueue}
@@ -317,7 +318,7 @@ class TestServer(port: Int) extends Logging {
}
}
} catch {
- case e: SocketException => println(e)
+ case e: SocketException => logError("TestServer error", e)
} finally {
logInfo("Connection closed")
if (!clientSocket.isClosed) clientSocket.close()
diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala
index 8cc2f8ccfc..28bdd53c3c 100644
--- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala
@@ -1,12 +1,16 @@
package spark.streaming
+import spark.streaming.dstream.{InputDStream, ForEachDStream}
+import spark.streaming.util.ManualClock
+
import spark.{RDD, Logging}
-import util.ManualClock
+
import collection.mutable.ArrayBuffer
-import org.scalatest.FunSuite
import collection.mutable.SynchronizedBuffer
+
import java.io.{ObjectInputStream, IOException}
+import org.scalatest.FunSuite
/**
* This is an input stream just for the test suites. This is equivalent to a checkpointable,
@@ -35,7 +39,7 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[
* ArrayBuffer. This buffer is wiped clean on being restored from checkpoint.
*/
class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]])
- extends PerRDDForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
+ extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
val collected = rdd.collect()
output += collected
}) {
@@ -70,6 +74,10 @@ trait TestSuiteBase extends FunSuite with Logging {
def actuallyWait = false
+ /**
+ * Set up required DStreams to test the DStream operation using the sequence
+ * of input collections.
+ */
def setupStreams[U: ClassManifest, V: ClassManifest](
input: Seq[Seq[U]],
operation: DStream[U] => DStream[V]
@@ -90,6 +98,10 @@ trait TestSuiteBase extends FunSuite with Logging {
ssc
}
+ /**
+ * Set up required DStreams to test the binary operation using the two sequences
+ * of input collections.
+ */
def setupStreams[U: ClassManifest, V: ClassManifest, W: ClassManifest](
input1: Seq[Seq[U]],
input2: Seq[Seq[V]],
@@ -173,6 +185,11 @@ trait TestSuiteBase extends FunSuite with Logging {
output
}
+ /**
+ * Verify whether the output values after running a DStream operation
+ * are the same as the expected output values, by comparing the output
+ * collections either as lists (order matters) or sets (order does not matter)
+ */
def verifyOutput[V: ClassManifest](
output: Seq[Seq[V]],
expectedOutput: Seq[Seq[V]],
@@ -199,6 +216,10 @@ trait TestSuiteBase extends FunSuite with Logging {
logInfo("Output verified successfully")
}
+ /**
+ * Test unary DStream operation with a list of inputs, with the number of
+ * batches to run equal to the number of expected output values
+ */
def testOperation[U: ClassManifest, V: ClassManifest](
input: Seq[Seq[U]],
operation: DStream[U] => DStream[V],
@@ -208,6 +229,15 @@ trait TestSuiteBase extends FunSuite with Logging {
testOperation[U, V](input, operation, expectedOutput, -1, useSet)
}
+ /**
+ * Test unary DStream operation with a list of inputs
+ * @param input Sequence of input collections
+ * @param operation Unary DStream operation to be applied to the input
+ * @param expectedOutput Sequence of expected output collections
+ * @param numBatches Number of batches to run the operation for
+ * @param useSet Compare the output values with the expected output values
+ * as sets (order does not matter) or as lists (order matters)
+ */
def testOperation[U: ClassManifest, V: ClassManifest](
input: Seq[Seq[U]],
operation: DStream[U] => DStream[V],
@@ -221,6 +251,10 @@ trait TestSuiteBase extends FunSuite with Logging {
verifyOutput[V](output, expectedOutput, useSet)
}
+ /**
+ * Test binary DStream operation with two lists of inputs, with the number of
+ * batches to run equal to the number of expected output values
+ */
def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest](
input1: Seq[Seq[U]],
input2: Seq[Seq[V]],
@@ -231,6 +265,16 @@ trait TestSuiteBase extends FunSuite with Logging {
testOperation[U, V, W](input1, input2, operation, expectedOutput, -1, useSet)
}
+ /**
+ * Test binary DStream operation with two lists of inputs
+ * @param input1 First sequence of input collections
+ * @param input2 Second sequence of input collections
+ * @param operation Binary DStream operation to be applied to the 2 inputs
+ * @param expectedOutput Sequence of expected output collections
+ * @param numBatches Number of batches to run the operation for
+ * @param useSet Compare the output values with the expected output values
+ * as sets (order does not matter) or as lists (order matters)
+ */
def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest](
input1: Seq[Seq[U]],
input2: Seq[Seq[V]],
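The new scaladoc spells out the testOperation contract: feed per-batch input collections through a DStream operation and compare the result against per-batch expected output, either order-sensitively (lists) or not (sets). A hypothetical test written against that contract (only TestSuiteBase and testOperation come from this file; the suite and operation are invented for illustration):

```scala
// Hypothetical suite using the documented testOperation contract.
import spark.streaming.{DStream, TestSuiteBase}

class WordLengthSuite extends TestSuiteBase {
  test("map to word lengths") {
    val input          = Seq(Seq("a", "bb"), Seq("ccc"))  // one inner Seq per batch
    val expectedOutput = Seq(Seq(1, 2), Seq(3))           // expected output per batch
    val operation = (s: DStream[String]) => s.map(_.length)
    testOperation(input, operation, expectedOutput, true)  // true: compare each batch as a set
  }
}
```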
diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
index 3e20e16708..4bc5229465 100644
--- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
@@ -209,7 +209,7 @@ class WindowOperationsSuite extends TestSuiteBase {
val expectedOutput = bigGroupByOutput.map(_.map(x => (x._1, x._2.toSet)))
val windowTime = Seconds(2)
val slideTime = Seconds(1)
- val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt
+ val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt
val operation = (s: DStream[(String, Int)]) => {
s.groupByKeyAndWindow(windowTime, slideTime)
.map(x => (x._1, x._2.toSet))
@@ -223,7 +223,7 @@ class WindowOperationsSuite extends TestSuiteBase {
val expectedOutput = Seq( Seq(1), Seq(2), Seq(3), Seq(3), Seq(1), Seq(0))
val windowTime = Seconds(2)
val slideTime = Seconds(1)
- val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt
+ val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt
val operation = (s: DStream[Int]) => s.countByWindow(windowTime, slideTime)
testOperation(input, operation, expectedOutput, numBatches, true)
}
@@ -233,7 +233,7 @@ class WindowOperationsSuite extends TestSuiteBase {
val expectedOutput = Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 2)), Seq(("a", 1), ("b", 3)))
val windowTime = Seconds(2)
val slideTime = Seconds(1)
- val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt
+ val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt
val operation = (s: DStream[(String, Int)]) => {
s.countByKeyAndWindow(windowTime, slideTime).map(x => (x._1, x._2.toInt))
}
@@ -251,7 +251,7 @@ class WindowOperationsSuite extends TestSuiteBase {
slideTime: Time = Seconds(1)
) {
test("window - " + name) {
- val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt
+ val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt
val operation = (s: DStream[Int]) => s.window(windowTime, slideTime)
testOperation(input, operation, expectedOutput, numBatches, true)
}
@@ -265,7 +265,7 @@ class WindowOperationsSuite extends TestSuiteBase {
slideTime: Time = Seconds(1)
) {
test("reduceByKeyAndWindow - " + name) {
- val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt
+ val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt
val operation = (s: DStream[(String, Int)]) => {
s.reduceByKeyAndWindow(_ + _, windowTime, slideTime).persist()
}
@@ -281,7 +281,7 @@ class WindowOperationsSuite extends TestSuiteBase {
slideTime: Time = Seconds(1)
) {
test("reduceByKeyAndWindowInv - " + name) {
- val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt
+ val numBatches = expectedOutput.size * (slideTime / batchDuration).toInt
val operation = (s: DStream[(String, Int)]) => {
s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime)
.persist()