-rw-r--r--  core/src/main/scala/org/apache/spark/CacheManager.scala                 | 28
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala    | 29
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala         | 87
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockStore.scala           |  5
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/DiskStore.scala            | 14
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/MemoryStore.scala          | 31
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala | 74
-rw-r--r--  docs/configuration.md                                                    | 11
8 files changed, 226 insertions(+), 53 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index 1daabecf23..872e892c04 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -71,10 +71,30 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
val computedValues = rdd.computeOrReadCheckpoint(split, context)
// Persist the result, so long as the task is not running locally
if (context.runningLocally) { return computedValues }
- val elements = new ArrayBuffer[Any]
- elements ++= computedValues
- blockManager.put(key, elements, storageLevel, tellMaster = true)
- elements.iterator.asInstanceOf[Iterator[T]]
+ if (storageLevel.useDisk && !storageLevel.useMemory) {
+ // In the case that this RDD is to be persisted using DISK_ONLY
+ // the iterator will be passed directly to the blockManager (rather than
+ // caching it to an ArrayBuffer first), then the resulting block data iterator
+ // will be passed back to the user. If the iterator generates a lot of data,
+ // this means that it doesn't all have to be held in memory at one time.
+ // This could also apply to MEMORY_ONLY_SER storage, but we need to make sure
+ // blocks aren't dropped by the block store before enabling that.
+ blockManager.put(key, computedValues, storageLevel, tellMaster = true)
+ return blockManager.get(key) match {
+ case Some(values) =>
+ return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
+ case None =>
+ logInfo("Failure to store %s".format(key))
+ throw new Exception("Block manager failed to return persisted values")
+ }
+ } else {
+ // In this case the RDD is cached to an array buffer. Materializing the buffer preserves
+ // the results even when the underlying iterator can only be traversed once.
+ val elements = new ArrayBuffer[Any]
+ elements ++= computedValues
+ blockManager.put(key, elements, storageLevel, tellMaster = true)
+ return elements.iterator.asInstanceOf[Iterator[T]]
+ }
} finally {
loading.synchronized {
loading.remove(key)
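
The DISK_ONLY branch above streams the computed iterator straight into the block store instead of unrolling it first. A minimal, self-contained sketch of that decision, with hypothetical store/fetch helpers standing in for blockManager.put/get (they are not Spark's API):

```scala
import scala.collection.mutable.ArrayBuffer

object CachePathSketch {
  // Hypothetical stand-ins for BlockManager.put/get, keyed by block id.
  private val backingStore = scala.collection.mutable.Map[String, Vector[Any]]()

  private def store(key: String, values: Iterator[Any]): Unit =
    backingStore(key) = values.toVector   // consuming the iterator here is the point

  private def fetch(key: String): Option[Iterator[Any]] =
    backingStore.get(key).map(_.iterator)

  /** Mirrors the CacheManager branch: stream straight to the store for disk-only
   *  persistence, otherwise unroll into an ArrayBuffer and return its iterator. */
  def cacheAndIterate(key: String, computed: Iterator[Any],
                      useDisk: Boolean, useMemory: Boolean): Iterator[Any] = {
    if (useDisk && !useMemory) {
      store(key, computed)                // the caller never buffers the whole data set
      fetch(key).getOrElse(sys.error(s"failed to read back $key"))
    } else {
      val elements = new ArrayBuffer[Any]
      elements ++= computed               // unrolled fully into memory
      store(key, elements.iterator)
      elements.iterator
    }
  }

  def main(args: Array[String]): Unit = {
    val it = cacheAndIterate("rdd_0_0", (1 to 5).iterator, useDisk = true, useMemory = false)
    println(it.size)  // 5
  }
}
```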
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index 33c1705ad7..bfa647f7f0 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -23,9 +23,28 @@ import java.nio.ByteBuffer
import org.apache.spark.SparkConf
import org.apache.spark.util.ByteBufferInputStream
-private[spark] class JavaSerializationStream(out: OutputStream) extends SerializationStream {
+private[spark] class JavaSerializationStream(out: OutputStream, conf: SparkConf)
+ extends SerializationStream {
val objOut = new ObjectOutputStream(out)
- def writeObject[T](t: T): SerializationStream = { objOut.writeObject(t); this }
+ var counter = 0
+ val counterReset = conf.getInt("spark.serializer.objectStreamReset", 10000)
+
+ /**
+ * Call reset to avoid a memory leak:
+ * http://stackoverflow.com/questions/1281549/memory-leak-traps-in-the-java-standard-api
+ * But only call it every 10,000th time to avoid bloated serialization streams (when
+ * the stream 'resets', object class descriptions have to be re-written).
+ */
+ def writeObject[T](t: T): SerializationStream = {
+ objOut.writeObject(t)
+ if (counterReset > 0 && counter >= counterReset) {
+ objOut.reset()
+ counter = 0
+ } else {
+ counter += 1
+ }
+ this
+ }
def flush() { objOut.flush() }
def close() { objOut.close() }
}
@@ -41,7 +60,7 @@ extends DeserializationStream {
def close() { objIn.close() }
}
-private[spark] class JavaSerializerInstance extends SerializerInstance {
+private[spark] class JavaSerializerInstance(conf: SparkConf) extends SerializerInstance {
def serialize[T](t: T): ByteBuffer = {
val bos = new ByteArrayOutputStream()
val out = serializeStream(bos)
@@ -63,7 +82,7 @@ private[spark] class JavaSerializerInstance extends SerializerInstance {
}
def serializeStream(s: OutputStream): SerializationStream = {
- new JavaSerializationStream(s)
+ new JavaSerializationStream(s, conf)
}
def deserializeStream(s: InputStream): DeserializationStream = {
@@ -79,5 +98,5 @@ private[spark] class JavaSerializerInstance extends SerializerInstance {
* A Spark serializer that uses Java's built-in serialization.
*/
class JavaSerializer(conf: SparkConf) extends Serializer {
- def newInstance(): SerializerInstance = new JavaSerializerInstance
+ def newInstance(): SerializerInstance = new JavaSerializerInstance(conf)
}
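
The reset behaviour added above can be reproduced with a plain java.io.ObjectOutputStream. A small sketch follows; the counter logic and the 10,000 default mirror the patch, while the value of 10 and the surrounding harness are purely illustrative:

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

object ResetSketch {
  def main(args: Array[String]): Unit = {
    val counterReset = 10   // e.g. what spark.serializer.objectStreamReset = 10 would give
    val bos = new ByteArrayOutputStream()
    val objOut = new ObjectOutputStream(bos)
    var counter = 0

    for (i <- 1 to 100) {
      objOut.writeObject(s"record $i")
      // Without reset(), the stream keeps back-references to every object written,
      // pinning them in memory; with it, old objects become eligible for GC.
      if (counterReset > 0 && counter >= counterReset) {
        objOut.reset()
        counter = 0
      } else {
        counter += 1
      }
    }
    objOut.close()
    println(s"wrote ${bos.size()} bytes")
  }
}
```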
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index a734ddc1ef..977c24687c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -35,6 +35,12 @@ import org.apache.spark.network._
import org.apache.spark.serializer.Serializer
import org.apache.spark.util._
+sealed trait Values
+
+case class ByteBufferValues(buffer: ByteBuffer) extends Values
+case class IteratorValues(iterator: Iterator[Any]) extends Values
+case class ArrayBufferValues(buffer: ArrayBuffer[Any]) extends Values
+
private[spark] class BlockManager(
executorId: String,
actorSystem: ActorSystem,
@@ -455,9 +461,7 @@ private[spark] class BlockManager(
def put(blockId: BlockId, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
: Long = {
- val elements = new ArrayBuffer[Any]
- elements ++= values
- put(blockId, elements, level, tellMaster)
+ doPut(blockId, IteratorValues(values), level, tellMaster)
}
/**
@@ -479,7 +483,7 @@ private[spark] class BlockManager(
def put(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
tellMaster: Boolean = true) : Long = {
require(values != null, "Values is null")
- doPut(blockId, Left(values), level, tellMaster)
+ doPut(blockId, ArrayBufferValues(values), level, tellMaster)
}
/**
@@ -488,10 +492,11 @@ private[spark] class BlockManager(
def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel,
tellMaster: Boolean = true) {
require(bytes != null, "Bytes is null")
- doPut(blockId, Right(bytes), level, tellMaster)
+ doPut(blockId, ByteBufferValues(bytes), level, tellMaster)
}
- private def doPut(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer],
+ private def doPut(blockId: BlockId,
+ data: Values,
level: StorageLevel, tellMaster: Boolean = true): Long = {
require(blockId != null, "BlockId is null")
require(level != null && level.isValid, "StorageLevel is null or invalid")
@@ -534,8 +539,9 @@ private[spark] class BlockManager(
// If we're storing bytes, then initiate the replication before storing them locally.
// This is faster as data is already serialized and ready to send.
- val replicationFuture = if (data.isRight && level.replication > 1) {
- val bufferView = data.right.get.duplicate() // Doesn't copy the bytes, just creates a wrapper
+ val replicationFuture = if (data.isInstanceOf[ByteBufferValues] && level.replication > 1) {
+ // Duplicate doesn't copy the bytes, just creates a wrapper
+ val bufferView = data.asInstanceOf[ByteBufferValues].buffer.duplicate()
Future {
replicate(blockId, bufferView, level)
}
@@ -549,34 +555,43 @@ private[spark] class BlockManager(
var marked = false
try {
- data match {
- case Left(values) => {
- if (level.useMemory) {
- // Save it just to memory first, even if it also has useDisk set to true; we will
- // drop it to disk later if the memory store can't hold it.
- val res = memoryStore.putValues(blockId, values, level, true)
- size = res.size
- res.data match {
- case Right(newBytes) => bytesAfterPut = newBytes
- case Left(newIterator) => valuesAfterPut = newIterator
- }
- } else {
- // Save directly to disk.
- // Don't get back the bytes unless we replicate them.
- val askForBytes = level.replication > 1
- val res = diskStore.putValues(blockId, values, level, askForBytes)
- size = res.size
- res.data match {
- case Right(newBytes) => bytesAfterPut = newBytes
- case _ =>
- }
+ if (level.useMemory) {
+ // Save it just to memory first, even if it also has useDisk set to true; we will
+ // drop it to disk later if the memory store can't hold it.
+ val res = data match {
+ case IteratorValues(iterator) =>
+ memoryStore.putValues(blockId, iterator, level, true)
+ case ArrayBufferValues(array) =>
+ memoryStore.putValues(blockId, array, level, true)
+ case ByteBufferValues(bytes) => {
+ bytes.rewind();
+ memoryStore.putBytes(blockId, bytes, level)
+ }
+ }
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case Left(newIterator) => valuesAfterPut = newIterator
+ }
+ } else {
+ // Save directly to disk.
+ // Don't get back the bytes unless we replicate them.
+ val askForBytes = level.replication > 1
+
+ val res = data match {
+ case IteratorValues(iterator) =>
+ diskStore.putValues(blockId, iterator, level, askForBytes)
+ case ArrayBufferValues(array) =>
+ diskStore.putValues(blockId, array, level, askForBytes)
+ case ByteBufferValues(bytes) => {
+ bytes.rewind();
+ diskStore.putBytes(blockId, bytes, level)
}
}
- case Right(bytes) => {
- bytes.rewind()
- // Store it only in memory at first, even if useDisk is also set to true
- (if (level.useMemory) memoryStore else diskStore).putBytes(blockId, bytes, level)
- size = bytes.limit
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case _ =>
}
}
@@ -605,8 +620,8 @@ private[spark] class BlockManager(
// values and need to serialize and replicate them now:
if (level.replication > 1) {
data match {
- case Right(bytes) => Await.ready(replicationFuture, Duration.Inf)
- case Left(values) => {
+ case ByteBufferValues(bytes) => Await.ready(replicationFuture, Duration.Inf)
+ case _ => {
val remoteStartTime = System.currentTimeMillis
// Serialize the block if not already done
if (bytesAfterPut == null) {
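
Replacing Either[ArrayBuffer[Any], ByteBuffer] with the sealed Values trait lets doPut dispatch with an exhaustive match over the three put paths. A stripped-down sketch of the same ADT pattern (the class names mirror the patch, but the handling logic here is only illustrative):

```scala
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer

sealed trait Values
case class ByteBufferValues(buffer: ByteBuffer) extends Values
case class IteratorValues(iterator: Iterator[Any]) extends Values
case class ArrayBufferValues(buffer: ArrayBuffer[Any]) extends Values

object ValuesSketch {
  // Being sealed, the compiler warns if a new Values subtype is added but not handled here.
  def describe(data: Values): String = data match {
    case ByteBufferValues(bytes) => s"already serialized: ${bytes.remaining()} bytes"
    case IteratorValues(it)      => "lazy values: can be streamed without buffering"
    case ArrayBufferValues(buf)  => s"materialized values: ${buf.size} elements"
  }

  def main(args: Array[String]): Unit = {
    println(describe(ByteBufferValues(ByteBuffer.allocate(64))))
    println(describe(IteratorValues(Iterator(1, 2, 3))))
    println(describe(ArrayBufferValues(ArrayBuffer[Any]("a", "b"))))
  }
}
```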
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
index b047644b88..9a9be047c7 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
@@ -28,7 +28,7 @@ import org.apache.spark.Logging
*/
private[spark]
abstract class BlockStore(val blockManager: BlockManager) extends Logging {
- def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel)
+ def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) : PutResult
/**
* Put in a block and, possibly, also return its content as either bytes or another Iterator.
@@ -37,6 +37,9 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging {
* @return a PutResult that contains the size of the data, as well as the values put if
* returnValues is true (if not, the result's data field can be null)
*/
+ def putValues(blockId: BlockId, values: Iterator[Any], level: StorageLevel,
+ returnValues: Boolean) : PutResult
+
def putValues(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
returnValues: Boolean) : PutResult
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index d1f07ddb24..36ee4bcc41 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -37,7 +37,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage
diskManager.getBlockLocation(blockId).length
}
- override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) {
+ override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) : PutResult = {
// So that we do not modify the input offsets !
// duplicate does not copy buffer, so inexpensive
val bytes = _bytes.duplicate()
@@ -52,6 +52,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage
val finishTime = System.currentTimeMillis
logDebug("Block %s stored as %s file on disk in %d ms".format(
file.getName, Utils.bytesToString(bytes.limit), (finishTime - startTime)))
+ return PutResult(bytes.limit(), Right(bytes.duplicate()))
}
override def putValues(
@@ -59,13 +60,22 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage
values: ArrayBuffer[Any],
level: StorageLevel,
returnValues: Boolean)
+ : PutResult = {
+ return putValues(blockId, values.toIterator, level, returnValues)
+ }
+
+ override def putValues(
+ blockId: BlockId,
+ values: Iterator[Any],
+ level: StorageLevel,
+ returnValues: Boolean)
: PutResult = {
logDebug("Attempting to write values for block " + blockId)
val startTime = System.currentTimeMillis
val file = diskManager.getFile(blockId)
val outputStream = new FileOutputStream(file)
- blockManager.dataSerializeStream(blockId, outputStream, values.iterator)
+ blockManager.dataSerializeStream(blockId, outputStream, values)
val length = file.length
val timeTaken = System.currentTimeMillis - startTime
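
The new Iterator overload above hands the iterator to dataSerializeStream, so data is written to the file as it is produced rather than buffered first. A stand-alone sketch of that streaming write, using plain java.io serialization in place of Spark's serializer:

```scala
import java.io.{File, FileOutputStream, ObjectOutputStream}

object DiskStreamSketch {
  /** Serialize an iterator straight to disk; nothing is buffered beyond one element. */
  def writeToDisk(file: File, values: Iterator[Any]): Long = {
    val out = new ObjectOutputStream(new FileOutputStream(file))
    try {
      values.foreach(v => out.writeObject(v.asInstanceOf[AnyRef]))
    } finally {
      out.close()
    }
    file.length  // the same size bookkeeping the DiskStore does after the write
  }

  def main(args: Array[String]): Unit = {
    val tmp = File.createTempFile("disk-stream-sketch", ".bin")
    tmp.deleteOnExit()
    val bytes = writeToDisk(tmp, (1 to 10000).iterator.map(i => s"record $i"))
    println(s"wrote $bytes bytes")
  }
}
```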
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index 1814175651..b89212eaab 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -49,7 +49,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
- override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) {
+ override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) : PutResult = {
// Work on a duplicate - since the original input might be used elsewhere.
val bytes = _bytes.duplicate()
bytes.rewind()
@@ -59,8 +59,10 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
elements ++= values
val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
tryToPut(blockId, elements, sizeEstimate, true)
+ PutResult(sizeEstimate, Left(values.toIterator))
} else {
tryToPut(blockId, bytes, bytes.limit, false)
+ PutResult(bytes.limit(), Right(bytes.duplicate()))
}
}
@@ -69,14 +71,33 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
values: ArrayBuffer[Any],
level: StorageLevel,
returnValues: Boolean)
- : PutResult = {
-
+ : PutResult = {
if (level.deserialized) {
val sizeEstimate = SizeEstimator.estimate(values.asInstanceOf[AnyRef])
tryToPut(blockId, values, sizeEstimate, true)
- PutResult(sizeEstimate, Left(values.iterator))
+ PutResult(sizeEstimate, Left(values.toIterator))
+ } else {
+ val bytes = blockManager.dataSerialize(blockId, values.toIterator)
+ tryToPut(blockId, bytes, bytes.limit, false)
+ PutResult(bytes.limit(), Right(bytes.duplicate()))
+ }
+ }
+
+ override def putValues(
+ blockId: BlockId,
+ values: Iterator[Any],
+ level: StorageLevel,
+ returnValues: Boolean)
+ : PutResult = {
+
+ if (level.deserialized) {
+ val valueEntries = new ArrayBuffer[Any]()
+ valueEntries ++= values
+ val sizeEstimate = SizeEstimator.estimate(valueEntries.asInstanceOf[AnyRef])
+ tryToPut(blockId, valueEntries, sizeEstimate, true)
+ PutResult(sizeEstimate, Left(valueEntries.toIterator))
} else {
- val bytes = blockManager.dataSerialize(blockId, values.iterator)
+ val bytes = blockManager.dataSerialize(blockId, values)
tryToPut(blockId, bytes, bytes.limit, false)
PutResult(bytes.limit(), Right(bytes.duplicate()))
}
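
For deserialized (MEMORY_ONLY) storage the iterator still has to be unrolled, because the store needs a size estimate before admitting the block, whereas serialized storage can consume the iterator directly. A sketch of that trade-off with stand-in helpers (estimateSize and serialize below are illustrative, not Spark's SizeEstimator or dataSerialize):

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import scala.collection.mutable.ArrayBuffer

object MemoryPutSketch {
  // Crude stand-in for a size estimator.
  private def estimateSize(values: ArrayBuffer[Any]): Long = values.size * 16L

  // Stand-in for serializing a block: Java-serialize the iterator's contents.
  private def serialize(values: Iterator[Any]): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bos)
    values.foreach(v => out.writeObject(v.asInstanceOf[AnyRef]))
    out.close()
    bos.toByteArray
  }

  /** Returns the block's size, mirroring the two branches of putValues(Iterator). */
  def putIterator(values: Iterator[Any], deserialized: Boolean): Long = {
    if (deserialized) {
      val valueEntries = new ArrayBuffer[Any]()
      valueEntries ++= values               // must unroll to know how much memory it needs
      estimateSize(valueEntries)
    } else {
      serialize(values).length.toLong       // iterator is consumed while serializing
    }
  }

  def main(args: Array[String]): Unit = {
    println(putIterator((1 to 1000).iterator, deserialized = true))
    println(putIterator((1 to 1000).iterator, deserialized = false))
  }
}
```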
diff --git a/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala
new file mode 100644
index 0000000000..b843b4c629
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.storage
+
+import org.scalatest.FunSuite
+import org.apache.spark.{SharedSparkContext, SparkConf, LocalSparkContext, SparkContext}
+
+
+class FlatmapIteratorSuite extends FunSuite with LocalSparkContext {
+ /* Tests the ability of Spark to deal with user-provided iterators from flatMap
+ * calls, which may generate more data than fits in available memory. For any
+ * memory-based persistence, Spark will unroll the iterator into an ArrayBuffer
+ * for caching; however, when the user specifies DISK_ONLY persistence,
+ * the iterator will be fed directly to the serializer and written to disk.
+ *
+ * This also tests the ObjectOutputStream reset rate. When serializing using the
+ * Java serialization system, the serializer caches objects to prevent writing redundant
+ * data, but that prevents GC of those objects. Calling 'reset' flushes that
+ * info from the serializer and allows the old objects to be GC'd.
+ */
+ test("Flatmap Iterator to Disk") {
+ val sconf = new SparkConf().setMaster("local-cluster[1,1,512]")
+ .setAppName("iterator_to_disk_test")
+ sc = new SparkContext(sconf)
+ val expand_size = 100
+ val data = sc.parallelize((1 to 5).toSeq).
+ flatMap( x => Stream.range(0, expand_size))
+ var persisted = data.persist(StorageLevel.DISK_ONLY)
+ println(persisted.count())
+ assert(persisted.count()===500)
+ assert(persisted.filter(_==1).count()===5)
+ }
+
+ test("Flatmap Iterator to Memory") {
+ val sconf = new SparkConf().setMaster("local-cluster[1,1,512]")
+ .setAppName("iterator_to_disk_test")
+ sc = new SparkContext(sconf)
+ val expand_size = 100
+ val data = sc.parallelize((1 to 5).toSeq).
+ flatMap(x => Stream.range(0, expand_size))
+ var persisted = data.persist(StorageLevel.MEMORY_ONLY)
+ println(persisted.count())
+ assert(persisted.count()===500)
+ assert(persisted.filter(_==1).count()===5)
+ }
+
+ test("Serializer Reset") {
+ val sconf = new SparkConf().setMaster("local-cluster[1,1,512]")
+ .setAppName("serializer_reset_test")
+ .set("spark.serializer.objectStreamReset", "10")
+ sc = new SparkContext(sconf)
+ val expand_size = 500
+ val data = sc.parallelize(Seq(1,2)).
+ flatMap(x => Stream.range(1, expand_size).
+ map(y => "%d: string test %d".format(y,x)))
+ var persisted = data.persist(StorageLevel.MEMORY_ONLY_SER)
+ assert(persisted.filter(_.startsWith("1:")).count()===2)
+ }
+
+}
diff --git a/docs/configuration.md b/docs/configuration.md
index dc5553f3da..017d509854 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -245,6 +245,17 @@ Apart from these, the following properties are also available, and may be useful
</td>
</tr>
<tr>
+ <td>spark.serializer.objectStreamReset</td>
+ <td>10000</td>
+ <td>
+ When serializing using org.apache.spark.serializer.JavaSerializer, the serializer caches
+ objects to prevent writing redundant data; however, this prevents those objects from
+ being garbage collected. Calling 'reset' flushes that information from the serializer
+ and allows the old objects to be collected. To turn off this periodic reset, set the
+ value to <= 0.
+ By default the serializer is reset every 10,000 objects.
+</tr>
+<tr>
<td>spark.broadcast.factory</td>
<td>org.apache.spark.broadcast.<br />HttpBroadcastFactory</td>
<td>