author    Davies Liu <davies.liu@gmail.com>    2014-10-01 11:21:34 -0700
committer Josh Rosen <joshrosen@apache.org>    2014-10-01 11:21:34 -0700
commit    abf588f47a26d0066f0b75d52b200a87bb085064 (patch)
tree      d3e9b8a57688805262b7d00c99f596327912ea23
parent    0bfd3afb00936b0f46ba613be0982e38bc7032b5 (diff)
download  spark-abf588f47a26d0066f0b75d52b200a87bb085064.tar.gz
          spark-abf588f47a26d0066f0b75d52b200a87bb085064.tar.bz2
          spark-abf588f47a26d0066f0b75d52b200a87bb085064.zip
[SPARK-3749] [PySpark] fix bugs in broadcast large closure of RDD
1. broadcast is triggered unexpectedly
2. fd is leaked in the JVM (also leaked in parallelize())
3. broadcast is not unpersisted in the JVM after the RDD is no longer used

cc JoshRosen, sorry for these stupid bugs.

Author: Davies Liu <davies.liu@gmail.com>

Closes #2603 from davies/fix_broadcast and squashes the following commits:

080a743 [Davies Liu] fix bugs in broadcast large closure of RDD
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 34
-rw-r--r--  python/pyspark/rdd.py                                            | 12
-rw-r--r--  python/pyspark/sql.py                                            |  2
-rw-r--r--  python/pyspark/tests.py                                          |  8
4 files changed, 37 insertions, 19 deletions
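
For context, the behavior this patch restores can be seen from the PySpark side. The following is a minimal sketch, assuming an interactive PySpark shell with an active SparkContext named sc; it mirrors the test_large_closure test added below. Closures whose pickled command exceeds 1 MB are shipped to executors through a broadcast variable, small closures stay inline, and the closure broadcast is unpersisted once the RDD is garbage-collected.

# Sketch only: assumes a PySpark shell where `sc` is an active SparkContext.
N = 1000000
data = [float(i) for i in xrange(N)]          # captured by the closure below

# The pickled closure exceeds the 1 MB threshold, so PipelinedRDD._jrdd ships
# it via a broadcast variable instead of embedding it in the command.
big = sc.parallelize(range(1), 1).map(lambda x: len(data))
assert big.first() == N
assert big._broadcast is not None             # broadcast created for the closure

# A small closure stays inline; no broadcast is created.
small = sc.parallelize(range(1), 1).map(lambda x: 1)
assert small._broadcast is None

# Dropping the last reference runs PipelinedRDD.__del__, which unpersists the
# closure broadcast on the JVM side so its blocks can be released.
del big
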
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index f9ff4ea6ca..9241414753 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -339,26 +339,34 @@ private[spark] object PythonRDD extends Logging {
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
- val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
try {
- while (true) {
- val length = file.readInt()
- val obj = new Array[Byte](length)
- file.readFully(obj)
- objs.append(obj)
+ val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
+ try {
+ while (true) {
+ val length = file.readInt()
+ val obj = new Array[Byte](length)
+ file.readFully(obj)
+ objs.append(obj)
+ }
+ } catch {
+ case eof: EOFException => {}
}
- } catch {
- case eof: EOFException => {}
+ JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
+ } finally {
+ file.close()
}
- JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
}
def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
- val length = file.readInt()
- val obj = new Array[Byte](length)
- file.readFully(obj)
- sc.broadcast(obj)
+ try {
+ val length = file.readInt()
+ val obj = new Array[Byte](length)
+ file.readFully(obj)
+ sc.broadcast(obj)
+ } finally {
+ file.close()
+ }
}
def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
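
The PythonRDD.scala changes above address the fd leak (item 2 in the commit message): both readRDDFromFile and readBroadcastFromFile now close the DataInputStream in a finally block, so the temp file's descriptor is released even if parallelize() or broadcast() throws. A rough Python sketch of the same length-prefixed read loop with the close-in-finally pattern (illustrative only, not Spark code):

# Illustrative sketch of the pattern applied above, not Spark code.
import struct

def read_length_prefixed(filename):
    """Read big-endian length-prefixed byte records, closing the file on any exit."""
    records = []
    f = open(filename, "rb")
    try:
        while True:
            header = f.read(4)
            if len(header) < 4:               # EOF: no more records
                break
            (length,) = struct.unpack(">i", header)
            records.append(f.read(length))
    finally:
        f.close()                             # descriptor released even on error
    return records
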
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 8ed89e2f97..dc6497772e 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2073,6 +2073,12 @@ class PipelinedRDD(RDD):
self._jrdd_deserializer = self.ctx.serializer
self._bypass_serializer = False
self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None
+ self._broadcast = None
+
+ def __del__(self):
+ if self._broadcast:
+ self._broadcast.unpersist()
+ self._broadcast = None
@property
def _jrdd(self):
@@ -2087,9 +2093,9 @@ class PipelinedRDD(RDD):
# the serialized command will be compressed by broadcast
ser = CloudPickleSerializer()
pickled_command = ser.dumps(command)
- if pickled_command > (1 << 20): # 1M
- broadcast = self.ctx.broadcast(pickled_command)
- pickled_command = ser.dumps(broadcast)
+ if len(pickled_command) > (1 << 20): # 1M
+ self._broadcast = self.ctx.broadcast(pickled_command)
+ pickled_command = ser.dumps(self._broadcast)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
self.ctx._gateway._gateway_client)
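
The rdd.py hunk above also fixes item 1 from the commit message: in CPython 2, comparing a byte string to an int never raises, and any str orders above any number, so the old check pickled_command > (1 << 20) was true for every closure and every job created a broadcast. Comparing the length instead applies the intended 1 MB threshold. A quick illustration, assuming Python 2 comparison semantics:

# Python 2 interactive session (sketch): mixed-type comparison never fails.
>>> pickled_command = "abc"                    # a tiny pickled command
>>> pickled_command > (1 << 20)                # old check: compares str to int
True
>>> len(pickled_command) > (1 << 20)           # fixed check: compares sizes
False
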
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index d8bdf22355..974b5e287b 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -965,7 +965,7 @@ class SQLContext(object):
BatchedSerializer(PickleSerializer(), 1024))
ser = CloudPickleSerializer()
pickled_command = ser.dumps(command)
- if pickled_command > (1 << 20): # 1M
+ if len(pickled_command) > (1 << 20): # 1M
broadcast = self._sc.broadcast(pickled_command)
pickled_command = ser.dumps(broadcast)
broadcast_vars = ListConverter().convert(
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 7e2bbc9cb6..6fb6bc998c 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -467,8 +467,12 @@ class TestRDDFunctions(PySparkTestCase):
def test_large_closure(self):
N = 1000000
data = [float(i) for i in xrange(N)]
- m = self.sc.parallelize(range(1), 1).map(lambda x: len(data)).sum()
- self.assertEquals(N, m)
+ rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data))
+ self.assertEquals(N, rdd.first())
+ self.assertTrue(rdd._broadcast is not None)
+ rdd = self.sc.parallelize(range(1), 1).map(lambda x: 1)
+ self.assertEqual(1, rdd.first())
+ self.assertTrue(rdd._broadcast is None)
def test_zip_with_different_serializers(self):
a = self.sc.parallelize(range(5))