diff options
author | Davies Liu <davies@databricks.com> | 2015-08-12 20:02:55 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-08-12 20:02:55 -0700 |
commit | 7c35746c916cf0019367850e75a080d7e739dba0 (patch) | |
tree | 7ae755ca944cb4278981be1f250d08622a6d3abc | |
parent | 7b13ed27c1296cf76d0946e400f3449c335c8471 (diff) | |
download | spark-7c35746c916cf0019367850e75a080d7e739dba0.tar.gz spark-7c35746c916cf0019367850e75a080d7e739dba0.tar.bz2 spark-7c35746c916cf0019367850e75a080d7e739dba0.zip |
[SPARK-9827] [SQL] fix fd leak in UnsafeRowSerializer
Currently, UnsafeRowSerializer does not close the InputStream, will cause fd leak if the InputStream has an open fd in it.
TODO: the fd could still be leaked, if any items in the stream is not consumed. Currently it replies on GC to close the fd in this case.
cc JoshRosen
Author: Davies Liu <davies@databricks.com>
Closes #8116 from davies/fd_leak.
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala | 2 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala | 31 |
2 files changed, 30 insertions, 3 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 3860c4bba9..5c18558f9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -108,6 +108,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def asKeyValueIterator: Iterator[(Int, UnsafeRow)] = { new Iterator[(Int, UnsafeRow)] { private[this] var rowSize: Int = dIn.readInt() + if (rowSize == EOF) dIn.close() override def hasNext: Boolean = rowSize != EOF @@ -119,6 +120,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize) rowSize = dIn.readInt() // read the next row's size if (rowSize == EOF) { // We are returning the last row in this stream + dIn.close() val _rowTuple = rowTuple // Null these out so that the byte array can be garbage collected once the entire // iterator has been consumed diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 40b47ae18d..bd02c73a26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.io.{DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row @@ -25,6 +25,18 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ + +/** + * used to test close InputStream in UnsafeRowSerializer + */ +class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStream(buf) { + var closed: Boolean = false + override def close(): Unit = { + closed = true + super.close() + } +} + class UnsafeRowSerializerSuite extends SparkFunSuite { private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = { @@ -52,8 +64,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite { serializerStream.writeValue(unsafeRow) } serializerStream.close() - val deserializerIter = serializer.deserializeStream( - new ByteArrayInputStream(baos.toByteArray)).asKeyValueIterator + val input = new ClosableByteArrayInputStream(baos.toByteArray) + val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator for (expectedRow <- unsafeRows) { val actualRow = deserializerIter.next().asInstanceOf[(Integer, UnsafeRow)]._2 assert(expectedRow.getSizeInBytes === actualRow.getSizeInBytes) @@ -61,5 +73,18 @@ class UnsafeRowSerializerSuite extends SparkFunSuite { assert(expectedRow.getInt(1) === actualRow.getInt(1)) } assert(!deserializerIter.hasNext) + assert(input.closed) + } + + test("close empty input stream") { + val baos = new ByteArrayOutputStream() + val dout = new DataOutputStream(baos) + dout.writeInt(-1) // EOF + dout.flush() + val input = new ClosableByteArrayInputStream(baos.toByteArray) + val serializer = new UnsafeRowSerializer(numFields = 2).newInstance() + val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator + assert(!deserializerIter.hasNext) + assert(input.closed) } } |