From 67e085ef6dd62774095f3187844c091db1a6a72c Mon Sep 17 00:00:00 2001
From: Ryan Blue
Date: Fri, 8 Jul 2016 12:37:26 -0700
Subject: [SPARK-16420] Ensure compression streams are closed.

## What changes were proposed in this pull request?

This uses the try/finally pattern to ensure streams are closed after use.

`UnsafeShuffleWriter` wasn't closing compression streams, causing them to leak resources until garbage collected. This was causing a problem with codecs that use off-heap memory.

## How was this patch tested?

Current tests are sufficient. This should not change behavior.

Author: Ryan Blue

Closes #14093 from rdblue/SPARK-16420-unsafe-shuffle-writer-leak.
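
As a rough sketch of the pattern this change applies (illustrative only, not code from the diff below; the class and method names are made up, while `CompressionCodec`, `ByteStreams`, and `Closeables` are the real Spark and Guava APIs), the Java-side fix amounts to closing the codec stream in a `finally` block and letting Guava's `Closeables.close` decide whether a failed `close()` may be swallowed:

```java
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

import com.google.common.io.ByteStreams;
import com.google.common.io.Closeables;

import org.apache.spark.io.CompressionCodec;

class CopyCompressedSketch {
  // Copy one compressed block to `out`, closing the codec stream deterministically
  // instead of leaving it to the garbage collector.
  static void copyCompressed(CompressionCodec codec, InputStream raw, OutputStream out)
      throws IOException {
    InputStream in = null;
    boolean threwException = true;
    try {
      in = codec.compressedInputStream(raw); // may hold off-heap buffers until closed
      ByteStreams.copy(in, out);
      threwException = false;
    } finally {
      // Swallow an IOException from close() only if the body already threw,
      // so a failing close() cannot mask the original error.
      Closeables.close(in, threwException);
    }
  }
}
```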

---
 .../spark/network/util/LimitedInputStream.java     | 23 ++++++++++++++++++++++
 .../spark/shuffle/sort/UnsafeShuffleWriter.java    | 17 +++++++++++-----
 .../apache/spark/broadcast/TorrentBroadcast.scala  | 13 +++++++++---
 .../spark/serializer/GenericAvroSerializer.scala   | 15 +++++++++++---
 4 files changed, 57 insertions(+), 11 deletions(-)

diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java b/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java
index 922c37a10e..e79eef0325 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java
@@ -48,11 +48,27 @@ import com.google.common.base.Preconditions;
  * use this functionality in both a Guava 11 environment and a Guava >14 environment.
  */
 public final class LimitedInputStream extends FilterInputStream {
+  private final boolean closeWrappedStream;
   private long left;
   private long mark = -1;
 
   public LimitedInputStream(InputStream in, long limit) {
+    this(in, limit, true);
+  }
+
+  /**
+   * Create a LimitedInputStream that will read {@code limit} bytes from {@code in}.
+   * <p>
+   * If {@code closeWrappedStream} is true, this will close {@code in} when it is closed.
+   * Otherwise, the stream is left open for reading its remaining content.
+   *
+   * @param in a {@link InputStream} to read from
+   * @param limit the number of bytes to read
+   * @param closeWrappedStream whether to close {@code in} when {@link #close} is called
+   */
+  public LimitedInputStream(InputStream in, long limit, boolean closeWrappedStream) {
     super(in);
+    this.closeWrappedStream = closeWrappedStream;
     Preconditions.checkNotNull(in);
     Preconditions.checkArgument(limit >= 0, "limit must be non-negative");
     left = limit;
@@ -102,4 +118,11 @@ public final class LimitedInputStream extends FilterInputStream {
     left -= skipped;
     return skipped;
   }
+
+  @Override
+  public void close() throws IOException {
+    if (closeWrappedStream) {
+      super.close();
+    }
+  }
 }
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 05fa04c44d..08fb887bbd 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -349,12 +349,19 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
         for (int i = 0; i < spills.length; i++) {
           final long partitionLengthInSpill = spills[i].partitionLengths[partition];
           if (partitionLengthInSpill > 0) {
-            InputStream partitionInputStream =
-              new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill);
-            if (compressionCodec != null) {
-              partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
+            InputStream partitionInputStream = null;
+            boolean innerThrewException = true;
+            try {
+              partitionInputStream =
+                new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false);
+              if (compressionCodec != null) {
+                partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
+              }
+              ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
+              innerThrewException = false;
+            } finally {
+              Closeables.close(partitionInputStream, innerThrewException);
             }
-            ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
           }
         }
         mergedFileOutputStream.flush();
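
The `false` passed to the new `LimitedInputStream` constructor above matters because each spill input stream is shared by every partition read from it: closing the per-partition wrapper, or the compression stream layered on top of it, must not close the underlying spill stream. Below is a minimal usage sketch of that behavior; it is not taken from the patch, the class name is made up, and the in-memory byte array merely stands in for a spill file stream:

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;

import com.google.common.io.ByteStreams;

import org.apache.spark.network.util.LimitedInputStream;

class SharedStreamSliceSketch {
  // Read two fixed-length slices from one shared stream. Each LimitedInputStream is
  // created with closeWrappedStream = false, so closing a slice leaves the shared
  // stream open and positioned at the start of the next slice.
  public static void main(String[] args) throws IOException {
    InputStream shared = new ByteArrayInputStream(new byte[] {1, 2, 3, 4, 5, 6});
    long[] sliceLengths = {2, 4};
    for (long length : sliceLengths) {
      InputStream slice = new LimitedInputStream(shared, length, false);
      try {
        byte[] bytes = ByteStreams.toByteArray(slice);
        System.out.println("read " + bytes.length + " bytes");
      } finally {
        slice.close(); // closes only the wrapper; `shared` stays open
      }
    }
    shared.close();
  }
}
```
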
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 632b0ae9c2..e8d6d587b4 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -232,7 +232,11 @@ private object TorrentBroadcast extends Logging {
     val out = compressionCodec.map(c => c.compressedOutputStream(cbbos)).getOrElse(cbbos)
     val ser = serializer.newInstance()
     val serOut = ser.serializeStream(out)
-    serOut.writeObject[T](obj).close()
+    Utils.tryWithSafeFinally {
+      serOut.writeObject[T](obj)
+    } {
+      serOut.close()
+    }
     cbbos.toChunkedByteBuffer.getChunks()
   }
 
@@ -246,8 +250,11 @@ private object TorrentBroadcast extends Logging {
     val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
     val ser = serializer.newInstance()
     val serIn = ser.deserializeStream(in)
-    val obj = serIn.readObject[T]()
-    serIn.close()
+    val obj = Utils.tryWithSafeFinally {
+      serIn.readObject[T]()
+    } {
+      serIn.close()
+    }
     obj
   }
 
diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
index d17a7894fd..f0ed41f690 100644
--- a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
@@ -32,6 +32,7 @@ import org.apache.commons.io.IOUtils
 
 import org.apache.spark.{SparkEnv, SparkException}
 import org.apache.spark.io.CompressionCodec
+import org.apache.spark.util.Utils
 
 /**
  * Custom serializer used for generic Avro records. If the user registers the schemas
@@ -72,8 +73,11 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String])
   def compress(schema: Schema): Array[Byte] = compressCache.getOrElseUpdate(schema, {
     val bos = new ByteArrayOutputStream()
     val out = codec.compressedOutputStream(bos)
-    out.write(schema.toString.getBytes(StandardCharsets.UTF_8))
-    out.close()
+    Utils.tryWithSafeFinally {
+      out.write(schema.toString.getBytes(StandardCharsets.UTF_8))
+    } {
+      out.close()
+    }
     bos.toByteArray
   })
 
@@ -86,7 +90,12 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String])
       schemaBytes.array(),
       schemaBytes.arrayOffset() + schemaBytes.position(),
       schemaBytes.remaining())
-    val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis))
+    val in = codec.compressedInputStream(bis)
+    val bytes = Utils.tryWithSafeFinally {
+      IOUtils.toByteArray(in)
+    } {
+      in.close()
+    }
     new Schema.Parser().parse(new String(bytes, StandardCharsets.UTF_8))
   })
-- 
cgit v1.2.3
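
The two Scala files above route their cleanup through `Utils.tryWithSafeFinally`, Spark's helper for running a finally block without letting a failure there mask an exception thrown by the body. As a rough Java approximation of that idea (an illustration only, not Spark's implementation, which is written in Scala and handles arbitrary throwables; the class and method names are made up), the same effect looks like this:

```java
import java.io.IOException;
import java.io.OutputStream;

class SafeFinallySketch {
  // Write and then always close `out`. If both the write and the close fail, the
  // write failure propagates and the close failure is attached as a suppressed
  // exception rather than replacing it.
  static void writeAndClose(OutputStream out, byte[] data) throws IOException {
    IOException primary = null;
    try {
      out.write(data);
    } catch (IOException e) {
      primary = e;
      throw e;
    } finally {
      try {
        out.close();
      } catch (IOException closeFailure) {
        if (primary != null) {
          primary.addSuppressed(closeFailure);
        } else {
          throw closeFailure;
        }
      }
    }
  }
}
```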