 core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 34
 python/pyspark/rdd.py                                            | 12
 python/pyspark/sql.py                                            |  2
 python/pyspark/tests.py                                          |  8
 4 files changed, 37 insertions(+), 19 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index f9ff4ea6ca..9241414753 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -339,26 +339,34 @@ private[spark] object PythonRDD extends Logging {
   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
   JavaRDD[Array[Byte]] = {
     val file = new DataInputStream(new FileInputStream(filename))
-    val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
     try {
-      while (true) {
-        val length = file.readInt()
-        val obj = new Array[Byte](length)
-        file.readFully(obj)
-        objs.append(obj)
+      val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
+      try {
+        while (true) {
+          val length = file.readInt()
+          val obj = new Array[Byte](length)
+          file.readFully(obj)
+          objs.append(obj)
+        }
+      } catch {
+        case eof: EOFException => {}
       }
-    } catch {
-      case eof: EOFException => {}
+      JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
+    } finally {
+      file.close()
     }
-    JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
   def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = {
     val file = new DataInputStream(new FileInputStream(filename))
-    val length = file.readInt()
-    val obj = new Array[Byte](length)
-    file.readFully(obj)
-    sc.broadcast(obj)
+    try {
+      val length = file.readInt()
+      val obj = new Array[Byte](length)
+      file.readFully(obj)
+      sc.broadcast(obj)
+    } finally {
+      file.close()
+    }
   }
 
   def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
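
Note: the Scala change above wraps both file readers in try/finally so the input stream is closed even if an error is raised before EOF is reached. As a rough illustration of the same length-prefixed record format with guaranteed cleanup, here is a minimal Python sketch (read_records and the 4-byte big-endian length framing are assumptions for illustration, not PySpark API):

    import struct

    def read_records(path):
        """Read [4-byte big-endian length][payload] records, closing the
        file even if a record is truncated (the `with` block plays the
        role of the try/finally added above)."""
        records = []
        with open(path, "rb") as f:
            while True:
                header = f.read(4)
                if len(header) < 4:      # EOF (or truncated header)
                    break
                (length,) = struct.unpack(">i", header)
                records.append(f.read(length))
        return records
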
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 8ed89e2f97..dc6497772e 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2073,6 +2073,12 @@ class PipelinedRDD(RDD):
         self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
         self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None
+        self._broadcast = None
+
+    def __del__(self):
+        if self._broadcast:
+            self._broadcast.unpersist()
+            self._broadcast = None
 
     @property
     def _jrdd(self):
@@ -2087,9 +2093,9 @@ class PipelinedRDD(RDD):
         # the serialized command will be compressed by broadcast
         ser = CloudPickleSerializer()
         pickled_command = ser.dumps(command)
-        if pickled_command > (1 << 20):  # 1M
-            broadcast = self.ctx.broadcast(pickled_command)
-            pickled_command = ser.dumps(broadcast)
+        if len(pickled_command) > (1 << 20):  # 1M
+            self._broadcast = self.ctx.broadcast(pickled_command)
+            pickled_command = ser.dumps(self._broadcast)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
             self.ctx._gateway._gateway_client)
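
Note: the rdd.py change does two things. First, the size check now measures len(pickled_command) instead of comparing the pickled bytes object itself to an integer. Second, the broadcast handle is stored on the RDD so __del__ can unpersist it when the RDD is garbage-collected. A rough sketch of the size-gated broadcast pattern (prepare_command is a hypothetical helper, not PySpark API):

    _BROADCAST_THRESHOLD = 1 << 20  # 1 MiB

    def prepare_command(ctx, ser, command):
        # Serialize the closure; ship it through a broadcast variable only
        # when it is large, and hand back the broadcast so the caller can
        # unpersist it later (as PipelinedRDD.__del__ now does).
        pickled = ser.dumps(command)
        broadcast = None
        if len(pickled) > _BROADCAST_THRESHOLD:
            broadcast = ctx.broadcast(pickled)  # big payload shipped once
            pickled = ser.dumps(broadcast)      # workers get only the handle
        return pickled, broadcast
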
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index d8bdf22355..974b5e287b 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -965,7 +965,7 @@ class SQLContext(object):
                            BatchedSerializer(PickleSerializer(), 1024))
         ser = CloudPickleSerializer()
         pickled_command = ser.dumps(command)
-        if pickled_command > (1 << 20):  # 1M
+        if len(pickled_command) > (1 << 20):  # 1M
             broadcast = self._sc.broadcast(pickled_command)
             pickled_command = ser.dumps(broadcast)
         broadcast_vars = ListConverter().convert(
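
Note: the sql.py hunk is the same one-line fix. Under Python 2, comparing a byte string to an int falls back to comparing type names rather than sizes, and str sorts after int, so the old check pickled_command > (1 << 20) was effectively always true and the command was broadcast regardless of its size. A quick illustration (Python 2 semantics):

    print("abc" > (1 << 20))       # True  -- mixed-type comparison, not a size test
    print(len("abc") > (1 << 20))  # False -- the intended size check
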
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 7e2bbc9cb6..6fb6bc998c 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -467,8 +467,12 @@ class TestRDDFunctions(PySparkTestCase):
     def test_large_closure(self):
         N = 1000000
         data = [float(i) for i in xrange(N)]
-        m = self.sc.parallelize(range(1), 1).map(lambda x: len(data)).sum()
-        self.assertEquals(N, m)
+        rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data))
+        self.assertEquals(N, rdd.first())
+        self.assertTrue(rdd._broadcast is not None)
+        rdd = self.sc.parallelize(range(1), 1).map(lambda x: 1)
+        self.assertEqual(1, rdd.first())
+        self.assertTrue(rdd._broadcast is None)
 
     def test_zip_with_different_serializers(self):
         a = self.sc.parallelize(range(5))
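
Note: the updated test covers both paths: a closure that captures about a million floats should be shipped through a broadcast (rdd._broadcast is set), while a small closure should not. An equivalent interactive check, assuming an existing SparkContext named sc:

    data = [float(i) for i in xrange(1000000)]   # large captured object

    big = sc.parallelize(range(1), 1).map(lambda x: len(data))
    assert big.first() == len(data)
    assert big._broadcast is not None   # pickled command > 1 MiB, so broadcast used

    small = sc.parallelize(range(1), 1).map(lambda x: 1)
    assert small.first() == 1
    assert small._broadcast is None     # small command shipped inline

    del big   # PipelinedRDD.__del__ unpersists the command broadcast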