From abf588f47a26d0066f0b75d52b200a87bb085064 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 1 Oct 2014 11:21:34 -0700 Subject: [SPARK-3749] [PySpark] fix bugs in broadcast large closure of RDD 1. broadcast is triggle unexpected 2. fd is leaked in JVM (also leak in parallelize()) 3. broadcast is not unpersisted in JVM after RDD is not be used any more. cc JoshRosen , sorry for these stupid bugs. Author: Davies Liu Closes #2603 from davies/fix_broadcast and squashes the following commits: 080a743 [Davies Liu] fix bugs in broadcast large closure of RDD --- .../org/apache/spark/api/python/PythonRDD.scala | 34 +++++++++++++--------- 1 file changed, 21 insertions(+), 13 deletions(-) (limited to 'core/src/main/scala') diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index f9ff4ea6ca..9241414753 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -339,26 +339,34 @@ private[spark] object PythonRDD extends Logging { def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): JavaRDD[Array[Byte]] = { val file = new DataInputStream(new FileInputStream(filename)) - val objs = new collection.mutable.ArrayBuffer[Array[Byte]] try { - while (true) { - val length = file.readInt() - val obj = new Array[Byte](length) - file.readFully(obj) - objs.append(obj) + val objs = new collection.mutable.ArrayBuffer[Array[Byte]] + try { + while (true) { + val length = file.readInt() + val obj = new Array[Byte](length) + file.readFully(obj) + objs.append(obj) + } + } catch { + case eof: EOFException => {} } - } catch { - case eof: EOFException => {} + JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) + } finally { + file.close() } - JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = { val file = new DataInputStream(new FileInputStream(filename)) - val length = file.readInt() - val obj = new Array[Byte](length) - file.readFully(obj) - sc.broadcast(obj) + try { + val length = file.readInt() + val obj = new Array[Byte](length) + file.readFully(obj) + sc.broadcast(obj) + } finally { + file.close() + } } def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) { -- cgit v1.2.3