author    Davies Liu <davies@databricks.com>    2015-04-15 12:58:02 -0700
committer Josh Rosen <joshrosen@databricks.com>    2015-04-15 12:58:02 -0700
commit    f11288d5272bc18585b8cad4ee3bd59eade7c296 (patch)
tree      0016cd5ccb7c7025d00ce79c3aefc3b9799420c5 /python/pyspark/tests.py
parent    6c5ed8a6d552abd967d27cdb94b68d46ccb57221 (diff)
[SPARK-6886] [PySpark] fix big closure with shuffle
Currently, the created broadcast object has the same life cycle as the RDD in Python. For multi-stage jobs, a PythonRDD is created in the JVM while the RDD in Python may be GCed, so the broadcast can be destroyed in the JVM before the PythonRDD is done with it. This PR changes the code to use the PythonRDD to track the lifecycle of the broadcast object. It also refactors getNumPartitions() to avoid the unnecessary creation of a PythonRDD, which could be heavy.

cc JoshRosen

Author: Davies Liu <davies@databricks.com>

Closes #5496 from davies/big_closure and squashes the following commits:

9a0ea4c [Davies Liu] fix big closure with shuffle
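The lifecycle bug described above can be sketched in plain Python, without Spark. The class names below (`Broadcast`, `PythonWrapperRDD`, `BackingPythonRDD`) are hypothetical stand-ins, not the real PySpark API: before the fix, the broadcast is destroyed when the short-lived Python wrapper is garbage-collected, even though the longer-lived JVM-side object still needs it; after the fix, the longer-lived object owns the broadcast.

```python
# Minimal sketch of the lifecycle bug; all class names are hypothetical
# stand-ins, not actual PySpark classes.
import gc


class Broadcast:
    """Stands in for a JVM-side broadcast variable."""
    def __init__(self):
        self.destroyed = False

    def destroy(self):
        self.destroyed = True


class PythonWrapperRDD:
    """Before the fix: the broadcast dies with the Python-side wrapper."""
    def __init__(self, broadcast):
        self._broadcast = broadcast

    def __del__(self):
        # Wrapper GCed -> broadcast destroyed, even if a job still needs it.
        self._broadcast.destroy()


class BackingPythonRDD:
    """After the fix: the long-lived PythonRDD tracks the broadcast,
    so it survives as long as the job that reads it."""
    def __init__(self, broadcast):
        self._broadcast = broadcast


def run_job_buggy():
    bc = Broadcast()
    backing = BackingPythonRDD(bc)   # JVM side still running the job
    wrapper = PythonWrapperRDD(bc)   # Python wrapper may be GCed early
    del wrapper                      # e.g. user code drops the reference
    gc.collect()
    return backing._broadcast.destroyed  # True: job now sees a dead broadcast


def run_job_fixed():
    bc = Broadcast()
    backing = BackingPythonRDD(bc)   # broadcast owned by the PythonRDD
    gc.collect()
    return backing._broadcast.destroyed  # False: broadcast survives
```

This mirrors why the patch moves ownership of the broadcast from the Python RDD object to the JVM-side PythonRDD: the owner must be the object whose lifetime matches the job's.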
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--  python/pyspark/tests.py | 6
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index b938b9ce12..ee67e80d53 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -550,10 +550,8 @@ class RDDTests(ReusedPySparkTestCase):
data = [float(i) for i in xrange(N)]
rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data))
self.assertEquals(N, rdd.first())
- self.assertTrue(rdd._broadcast is not None)
- rdd = self.sc.parallelize(range(1), 1).map(lambda x: 1)
- self.assertEqual(1, rdd.first())
- self.assertTrue(rdd._broadcast is None)
+ # regression test for SPARK-6886
+ self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count())
def test_zip_with_different_serializers(self):
a = self.sc.parallelize(range(5))