author     Davies Liu <davies@databricks.com>    2015-08-12 20:02:55 -0700
committer  Reynold Xin <rxin@databricks.com>     2015-08-12 20:02:55 -0700
commit     7c35746c916cf0019367850e75a080d7e739dba0 (patch)
tree       7ae755ca944cb4278981be1f250d08622a6d3abc /sql
parent     7b13ed27c1296cf76d0946e400f3449c335c8471 (diff)
[SPARK-9827] [SQL] fix fd leak in UnsafeRowSerializer
Currently, UnsafeRowSerializer does not close its InputStream, which causes an fd leak whenever the InputStream holds an open fd.

TODO: the fd can still be leaked if the items in the stream are not all consumed; in that case we currently rely on GC to close the fd.

cc JoshRosen

Author: Davies Liu <davies@databricks.com>

Closes #8116 from davies/fd_leak.
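The pattern the fix introduces can be summarized with a minimal, hypothetical sketch (not the actual UnsafeRowSerializer code; the class and method names below are invented for illustration): an iterator that reads records terminated by an EOF sentinel and closes the underlying stream as soon as it sees the sentinel, instead of waiting for GC to finalize the stream.

    // Hypothetical sketch of the close-on-EOF pattern; record payloads are
    // simplified to plain ints, whereas the real serializer reads UnsafeRow bytes.
    import java.io.{DataInputStream, InputStream}

    class CloseOnEofReader(in: InputStream) {
      private val EOF = -1
      private val dIn = new DataInputStream(in)

      def asIterator: Iterator[Int] = new Iterator[Int] {
        private var nextValue: Int = dIn.readInt()
        // An empty stream holds only the EOF marker: release the fd immediately.
        if (nextValue == EOF) dIn.close()

        override def hasNext: Boolean = nextValue != EOF

        override def next(): Int = {
          val value = nextValue
          nextValue = dIn.readInt()
          if (nextValue == EOF) dIn.close() // last record: close now, not at GC time
          value
        }
      }
    }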
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala        2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala  31
2 files changed, 30 insertions(+), 3 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
index 3860c4bba9..5c18558f9b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -108,6 +108,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
       override def asKeyValueIterator: Iterator[(Int, UnsafeRow)] = {
         new Iterator[(Int, UnsafeRow)] {
           private[this] var rowSize: Int = dIn.readInt()
+          if (rowSize == EOF) dIn.close()

           override def hasNext: Boolean = rowSize != EOF
@@ -119,6 +120,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
             row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize)
             rowSize = dIn.readInt() // read the next row's size
             if (rowSize == EOF) { // We are returning the last row in this stream
+              dIn.close()
               val _rowTuple = rowTuple
               // Null these out so that the byte array can be garbage collected once the entire
               // iterator has been consumed
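The TODO in the commit message still applies after this hunk: both close() calls sit on the EOF path, so if a caller abandons the iterator before reaching the last row, neither is executed and the descriptor stays open until the stream object is garbage collected. A hypothetical illustration, reusing the CloseOnEofReader sketch above:

    // Hypothetical demonstration of the remaining leak: the consumer stops early,
    // the EOF sentinel is never read, and the FileInputStream's fd stays open
    // until it is garbage collected or closed explicitly.
    import java.io.{DataOutputStream, FileInputStream, FileOutputStream}

    object PartialConsumptionDemo {
      def main(args: Array[String]): Unit = {
        val path = "/tmp/records.bin" // scratch file for the demo
        val out = new DataOutputStream(new FileOutputStream(path))
        Seq(1, 2, 3).foreach(out.writeInt)
        out.writeInt(-1) // EOF sentinel
        out.close()

        val in = new FileInputStream(path) // holds an open OS file descriptor
        val iter = new CloseOnEofReader(in).asIterator
        println(iter.next()) // reads only the first record; the EOF path never runs
        // `in` is still open here; only GC (or an explicit in.close()) releases it.
      }
    }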
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index 40b47ae18d..bd02c73a26 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.sql.execution

-import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import java.io.{DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream}

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
@@ -25,6 +25,18 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.types._
+
+/**
+ * used to test close InputStream in UnsafeRowSerializer
+ */
+class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStream(buf) {
+  var closed: Boolean = false
+  override def close(): Unit = {
+    closed = true
+    super.close()
+  }
+}
+
 class UnsafeRowSerializerSuite extends SparkFunSuite {

   private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = {
@@ -52,8 +64,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite {
       serializerStream.writeValue(unsafeRow)
     }
     serializerStream.close()
-    val deserializerIter = serializer.deserializeStream(
-      new ByteArrayInputStream(baos.toByteArray)).asKeyValueIterator
+    val input = new ClosableByteArrayInputStream(baos.toByteArray)
+    val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator
     for (expectedRow <- unsafeRows) {
       val actualRow = deserializerIter.next().asInstanceOf[(Integer, UnsafeRow)]._2
       assert(expectedRow.getSizeInBytes === actualRow.getSizeInBytes)
@@ -61,5 +73,18 @@ class UnsafeRowSerializerSuite extends SparkFunSuite {
       assert(expectedRow.getInt(1) === actualRow.getInt(1))
     }
     assert(!deserializerIter.hasNext)
+    assert(input.closed)
+  }
+
+  test("close empty input stream") {
+    val baos = new ByteArrayOutputStream()
+    val dout = new DataOutputStream(baos)
+    dout.writeInt(-1)  // EOF
+    dout.flush()
+    val input = new ClosableByteArrayInputStream(baos.toByteArray)
+    val serializer = new UnsafeRowSerializer(numFields = 2).newInstance()
+    val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator
+    assert(!deserializerIter.hasNext)
+    assert(input.closed)
+  }
 }
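For readers who want to try the close-tracking idea outside the Spark test harness, here is a hypothetical stand-alone check mirroring the "close empty input stream" test, again written against the CloseOnEofReader sketch rather than the real UnsafeRowSerializer:

    // Hypothetical stand-alone version of the "close empty input stream" check:
    // a stream containing only the EOF marker should be closed as soon as the
    // iterator is constructed.
    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataOutputStream}

    object EmptyStreamCheck {
      // Records whether close() was called, like ClosableByteArrayInputStream above.
      class TrackingInputStream(buf: Array[Byte]) extends ByteArrayInputStream(buf) {
        var closed: Boolean = false
        override def close(): Unit = { closed = true; super.close() }
      }

      def main(args: Array[String]): Unit = {
        val baos = new ByteArrayOutputStream()
        val dout = new DataOutputStream(baos)
        dout.writeInt(-1) // only the EOF sentinel
        dout.flush()

        val input = new TrackingInputStream(baos.toByteArray)
        val iter = new CloseOnEofReader(input).asIterator
        assert(!iter.hasNext)
        assert(input.closed) // the reader closed the stream upon seeing EOF
      }
    }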