-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala                            45
-rw-r--r--  docs/configuration.md                                                               7
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/Time.scala                      2
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala          16
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala     76
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala        13
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala      1
7 files changed, 135 insertions, 25 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
index e6c4a6d379..c64da8804d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -19,24 +19,30 @@ package org.apache.spark.rdd
import scala.reflect.ClassTag
-import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext}
+import org.apache.spark._
import org.apache.spark.storage.{BlockId, BlockManager}
+import scala.Some
private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends Partition {
val index = idx
}
private[spark]
-class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[BlockId])
+class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds: Array[BlockId])
extends RDD[T](sc, Nil) {
@transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get)
+ @volatile private var _isValid = true
- override def getPartitions: Array[Partition] = (0 until blockIds.size).map(i => {
- new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition]
- }).toArray
+ override def getPartitions: Array[Partition] = {
+ assertValid()
+ (0 until blockIds.size).map(i => {
+ new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition]
+ }).toArray
+ }
override def compute(split: Partition, context: TaskContext): Iterator[T] = {
+ assertValid()
val blockManager = SparkEnv.get.blockManager
val blockId = split.asInstanceOf[BlockRDDPartition].blockId
blockManager.get(blockId) match {
@@ -47,7 +53,36 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[BlockId
}
override def getPreferredLocations(split: Partition): Seq[String] = {
+ assertValid()
locations_(split.asInstanceOf[BlockRDDPartition].blockId)
}
+
+ /**
+ * Remove the data blocks that this BlockRDD is made from. NOTE: This is an
+ * irreversible operation, as the data in the blocks cannot be recovered back
+ * once removed. Use it with caution.
+ */
+ private[spark] def removeBlocks() {
+ blockIds.foreach { blockId =>
+ sc.env.blockManager.master.removeBlock(blockId)
+ }
+ _isValid = false
+ }
+
+ /**
+ * Whether this BlockRDD is actually usable. This will be false if the data blocks have been
+ * removed using `this.removeBlocks`.
+ */
+ private[spark] def isValid: Boolean = {
+ _isValid
+ }
+
+ /** Check if this BlockRDD is valid. If not valid, exception is thrown. */
+ private[spark] def assertValid() {
+ if (!_isValid) {
+ throw new SparkException(
+ "Attempted to use %s after its blocks have been removed!".format(toString))
+ }
+ }
}
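
To illustrate the new lifecycle, a minimal sketch (not part of this commit; it only compiles inside Spark's own packages since BlockRDD and the new methods are private[spark], and it assumes a SparkContext `sc` and stream blocks already stored in the BlockManager):

    import org.apache.spark.SparkException
    import org.apache.spark.rdd.BlockRDD
    import org.apache.spark.storage.{BlockId, StreamBlockId}

    val blockIds: Array[BlockId] = Array(StreamBlockId(streamId = 0, uniqueId = 0L))
    val rdd = new BlockRDD[String](sc, blockIds)
    assert(rdd.isValid)        // blocks are still present, so the RDD is usable
    rdd.removeBlocks()         // irreversibly deletes the underlying blocks
    assert(!rdd.isValid)
    // Any further use now fails fast via assertValid()
    try rdd.collect() catch {
      case e: SparkException => // "Attempted to use ... after its blocks have been removed!"
    }
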
diff --git a/docs/configuration.md b/docs/configuration.md
index e7e1dd56cf..8d3442625b 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -469,10 +469,13 @@ Apart from these, the following properties are also available, and may be useful
</tr>
<tr>
<td>spark.streaming.unpersist</td>
- <td>false</td>
+ <td>true</td>
<td>
Force RDDs generated and persisted by Spark Streaming to be automatically unpersisted from
- Spark's memory. Setting this to true is likely to reduce Spark's RDD memory usage.
+ Spark's memory. The raw input data received by Spark Streaming is also automatically cleared.
+ Setting this to false will allow the raw data and persisted RDDs to be accessible outside the
+ streaming application as they will not be cleared automatically. But it comes at the cost of
+ higher memory usage in Spark.
</td>
</tr>
<tr>
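
For reference, a minimal sketch (not part of this commit) of opting back out of the new default by setting the property on the application's SparkConf:

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    val conf = new SparkConf()
      .setAppName("streaming-unpersist-example")
      .set("spark.streaming.unpersist", "false")   // keep raw blocks and persisted RDDs around
    val ssc = new StreamingContext(conf, Seconds(1))
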
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Time.scala b/streaming/src/main/scala/org/apache/spark/streaming/Time.scala
index 6a6b00a778..37b3b28fa0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Time.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Time.scala
@@ -68,5 +68,5 @@ case class Time(private val millis: Long) {
}
object Time {
- val ordering = Ordering.by((time: Time) => time.millis)
+ implicit val ordering = Ordering.by((time: Time) => time.millis)
}
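
Making the ordering implicit lets standard collection methods order Time values without an explicit Ordering argument, which is what the new test below relies on. A minimal sketch (not part of this commit):

    import org.apache.spark.streaming.Time
    import Time._                         // brings the implicit Ordering[Time] into scope

    val times = Set(Time(1000), Time(3000), Time(2000))
    val latest = times.max                // resolves via the implicit ordering
    val ordered = times.toSeq.sorted      // Seq(Time(1000), Time(2000), Time(3000))
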
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index d393cc03cb..f69f69e0c4 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -25,7 +25,7 @@ import scala.reflect.ClassTag
import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
import org.apache.spark.Logging
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{BlockRDD, RDD}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.MetadataCleaner
import org.apache.spark.streaming._
@@ -340,13 +340,23 @@ abstract class DStream[T: ClassTag] (
* this to clear their own metadata along with the generated RDDs.
*/
private[streaming] def clearMetadata(time: Time) {
+ val unpersistData = ssc.conf.getBoolean("spark.streaming.unpersist", true)
val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration))
logDebug("Clearing references to old RDDs: [" +
oldRDDs.map(x => s"${x._1} -> ${x._2.id}").mkString(", ") + "]")
generatedRDDs --= oldRDDs.keys
- if (ssc.conf.getBoolean("spark.streaming.unpersist", false)) {
+ if (unpersistData) {
logDebug("Unpersisting old RDDs: " + oldRDDs.values.map(_.id).mkString(", "))
- oldRDDs.values.foreach(_.unpersist(false))
+ oldRDDs.values.foreach { rdd =>
+ rdd.unpersist(false)
+ // Explicitly remove blocks of BlockRDD
+ rdd match {
+ case b: BlockRDD[_] =>
+ logInfo("Removing blocks of RDD " + b + " of time " + time)
+ b.removeBlocks()
+ case _ =>
+ }
+ }
}
logDebug("Cleared " + oldRDDs.size + " RDDs that were older than " +
(time - rememberDuration) + ": " + oldRDDs.keys.mkString(", "))
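
Applications that deliberately reuse generated RDDs outside the streaming job can either disable this cleanup via spark.streaming.unpersist (shown above) or, more selectively, widen the remember window so clearMetadata keeps recent RDDs longer. A minimal sketch (not part of this commit), assuming an existing StreamingContext `ssc` and DStream `lines`:

    import org.apache.spark.streaming.Minutes

    lines.persist()          // cache the generated RDDs
    ssc.remember(Minutes(5)) // RDDs newer than (batch time - 5 minutes) are not cleared
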
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index 8aec27e394..4792ca1f8a 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.streaming
import org.apache.spark.streaming.StreamingContext._
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{BlockRDD, RDD}
import org.apache.spark.SparkContext._
import util.ManualClock
@@ -27,6 +27,8 @@ import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.streaming.dstream.{WindowedDStream, DStream}
import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer}
import scala.reflect.ClassTag
+import org.apache.spark.storage.StorageLevel
+import scala.collection.mutable
class BasicOperationsSuite extends TestSuiteBase {
test("map") {
@@ -450,6 +452,78 @@ class BasicOperationsSuite extends TestSuiteBase {
assert(!stateStream.generatedRDDs.contains(Time(4000)))
}
+ test("rdd cleanup - input blocks and persisted RDDs") {
+ // Actually receive data over through receiver to create BlockRDDs
+
+ // Start the server
+ val testServer = new TestServer()
+ testServer.start()
+
+ // Set up the streaming context and input streams
+ val ssc = new StreamingContext(conf, batchDuration)
+ val networkStream = ssc.socketTextStream("localhost", testServer.port, StorageLevel.MEMORY_AND_DISK)
+ val mappedStream = networkStream.map(_ + ".").persist()
+ val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
+ val outputStream = new TestOutputStream(mappedStream, outputBuffer)
+
+ outputStream.register()
+ ssc.start()
+
+ // Feed data to the server to send to the network receiver
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ val input = Seq(1, 2, 3, 4, 5, 6)
+
+ val blockRdds = new mutable.HashMap[Time, BlockRDD[_]]
+ val persistentRddIds = new mutable.HashMap[Time, Int]
+
+ def collectRddInfo() { // get all RDD info required for verification
+ networkStream.generatedRDDs.foreach { case (time, rdd) =>
+ blockRdds(time) = rdd.asInstanceOf[BlockRDD[_]]
+ }
+ mappedStream.generatedRDDs.foreach { case (time, rdd) =>
+ persistentRddIds(time) = rdd.id
+ }
+ }
+
+ Thread.sleep(200)
+ for (i <- 0 until input.size) {
+ testServer.send(input(i).toString + "\n")
+ Thread.sleep(200)
+ clock.addToTime(batchDuration.milliseconds)
+ collectRddInfo()
+ }
+
+ Thread.sleep(200)
+ collectRddInfo()
+ logInfo("Stopping server")
+ testServer.stop()
+ logInfo("Stopping context")
+
+ // verify data has been received
+ assert(outputBuffer.size > 0)
+ assert(blockRdds.size > 0)
+ assert(persistentRddIds.size > 0)
+
+ import Time._
+
+ val latestPersistedRddId = persistentRddIds(persistentRddIds.keySet.max)
+ val earliestPersistedRddId = persistentRddIds(persistentRddIds.keySet.min)
+ val latestBlockRdd = blockRdds(blockRdds.keySet.max)
+ val earliestBlockRdd = blockRdds(blockRdds.keySet.min)
+ // verify that the latest mapped RDD is persisted but the earliest one has been unpersisted
+ assert(ssc.sparkContext.persistentRdds.contains(latestPersistedRddId))
+ assert(!ssc.sparkContext.persistentRdds.contains(earliestPersistedRddId))
+
+ // verify that the latest input blocks are present but the earliest blocks have been removed
+ assert(latestBlockRdd.isValid)
+ assert(latestBlockRdd.collect != null)
+ assert(!earliestBlockRdd.isValid)
+ earliestBlockRdd.blockIds.foreach { blockId =>
+ assert(!ssc.sparkContext.env.blockManager.master.contains(blockId))
+ }
+ ssc.stop()
+ }
+
/** Test cleanup of RDDs in DStream metadata */
def runCleanupTest[T: ClassTag](
conf2: SparkConf,
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index 3bad871b5c..b55b7834c9 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -42,8 +42,6 @@ import org.apache.spark.streaming.receiver.{ActorHelper, Receiver}
class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
- val testPort = 9999
-
test("socket input stream") {
// Start the server
val testServer = new TestServer()
@@ -288,17 +286,6 @@ class TestServer(portToBind: Int = 0) extends Logging {
def port = serverSocket.getLocalPort
}
-object TestServer {
- def main(args: Array[String]) {
- val s = new TestServer()
- s.start()
- while(true) {
- Thread.sleep(1000)
- s.send("hello")
- }
- }
-}
-
/** This is an actor for testing actor input stream */
class TestActor(port: Int) extends Actor with ActorHelper {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala
index 45304c76b0..ff3619a590 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala
@@ -29,6 +29,7 @@ import org.scalatest.FunSuite
import org.scalatest.concurrent.Timeouts
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
+import scala.language.postfixOps
/** Testsuite for testing the network receiver behavior */
class NetworkReceiverSuite extends FunSuite with Timeouts {
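
The scala.language.postfixOps import silences the feature warning for ScalaTest's SpanSugar syntax used with Eventually in this suite. A minimal sketch of that style (not the suite's actual assertions):

    import scala.language.postfixOps
    import org.scalatest.concurrent.Eventually._
    import org.scalatest.time.SpanSugar._

    eventually(timeout(10 seconds), interval(100 millis)) {
      // condition polled until it holds or the timeout expires
      assert(1 + 1 == 2)
    }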