Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/rdd.py    | 15
-rw-r--r--  python/pyspark/tests.py  |  6
2 files changed, 7 insertions(+), 14 deletions(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index c9ac95d117..93e658eded 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1197,7 +1197,7 @@ class RDD(object):
         [91, 92, 93]
         """
         items = []
-        totalParts = self._jrdd.partitions().size()
+        totalParts = self.getNumPartitions()
         partsScanned = 0
 
         while len(items) < num and partsScanned < totalParts:
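take() previously counted partitions by reaching through the Py4J gateway with self._jrdd.partitions().size(); on a PipelinedRDD, touching self._jrdd forces the closure to be pickled first (see the PipelinedRDD hunk below). Dispatching through getNumPartitions() lets the subclass answer cheaply. For reference, the base-class method this now calls looks roughly like this (a sketch of rdd.py from this era, not part of the patch):

    def getNumPartitions(self):
        """Return the number of partitions in this RDD."""
        return self._jrdd.partitions().size()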
@@ -1260,7 +1260,7 @@ class RDD(object):
         >>> sc.parallelize([1]).isEmpty()
         False
         """
-        return self._jrdd.partitions().size() == 0 or len(self.take(1)) == 0
+        return self.getNumPartitions() == 0 or len(self.take(1)) == 0
 
     def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None):
         """
@@ -2235,11 +2235,9 @@ def _prepare_for_python_RDD(sc, command, obj=None):
     ser = CloudPickleSerializer()
     pickled_command = ser.dumps((command, sys.version_info[:2]))
     if len(pickled_command) > (1 << 20):  # 1M
+        # The broadcast will have the same life cycle as the created PythonRDD
         broadcast = sc.broadcast(pickled_command)
         pickled_command = ser.dumps(broadcast)
-        # tracking the life cycle by obj
-        if obj is not None:
-            obj._broadcast = broadcast
     broadcast_vars = ListConverter().convert(
         [x._jbroadcast for x in sc._pickled_broadcast_vars],
         sc._gateway._gateway_client)
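This branch fires only when the pickled closure exceeds 1 << 20 bytes, in which case it is shipped once as a broadcast variable rather than embedded in every task. The removed lines tied that broadcast's lifetime to obj (the PipelinedRDD) and relied on __del__ to unpersist it; after this patch it simply lives as long as the PythonRDD created from it. A hedged illustration of how the branch is reached (variable names invented; assumes a running SparkContext sc):

    data = [float(i) for i in range(300000)]  # captured by the lambda below
    rdd = sc.parallelize(range(1), 1).map(lambda x: len(data))
    # cloudpickle serializes the lambda together with `data`; the result is
    # well over 1 MB, so the pickled command is replaced by a broadcast.
    rdd.first()  # 300000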
@@ -2294,12 +2292,9 @@ class PipelinedRDD(RDD):
         self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
         self.partitioner = prev.partitioner if self.preservesPartitioning else None
-        self._broadcast = None
-
-    def __del__(self):
-        if self._broadcast:
-            self._broadcast.unpersist()
-            self._broadcast = None
+
+    def getNumPartitions(self):
+        return self._prev_jrdd.partitions().size()
 
     @property
     def _jrdd(self):
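This is the heart of the SPARK-6886 fix: the deleted __del__ could unpersist the closure broadcast as soon as the Python-side PipelinedRDD was garbage collected, even though a shuffled child RDD still needed it on the executors. The new override also pairs with the take()/isEmpty() change above: a pipelined transformation never changes partitioning, so the parent JavaRDD can report the count without materializing self._jrdd (which would pickle the closure). A usage sketch (assumes a running SparkContext sc):

    rdd = sc.parallelize(range(100), 4).map(lambda x: x + 1)  # a PipelinedRDD
    # Answered by the parent JavaRDD; self._jrdd is never built here.
    rdd.getNumPartitions()  # 4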
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index b938b9ce12..ee67e80d53 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -550,10 +550,8 @@ class RDDTests(ReusedPySparkTestCase):
         data = [float(i) for i in xrange(N)]
         rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data))
         self.assertEquals(N, rdd.first())
-        self.assertTrue(rdd._broadcast is not None)
-        rdd = self.sc.parallelize(range(1), 1).map(lambda x: 1)
-        self.assertEqual(1, rdd.first())
-        self.assertTrue(rdd._broadcast is None)
+        # regression test for SPARK-6886
+        self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count())
 
     def test_zip_with_different_serializers(self):
         a = self.sc.parallelize(range(5))
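The new assertion reuses the big-closure rdd built a few lines above, so the regression scenario is: a closure large enough to take the broadcast path, followed by a shuffle. A standalone sketch of the failure this guards against (assumes a running SparkContext sc; sizes invented):

    data = [float(i) for i in range(300000)]  # big enough for the broadcast path
    rdd = sc.parallelize(range(1), 1).map(lambda x: len(data))
    # Before this patch, the intermediate PipelinedRDD could be collected once
    # the shuffled RDD existed, and its __del__ unpersisted the closure
    # broadcast while the shuffle stage still referenced it.
    rdd.map(lambda x: (x, 1)).groupByKey().count()  # expected: 1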