about summary refs log tree commit diff
path: root/python/pyspark/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--    python/pyspark/tests.py    31
1 file changed, 31 insertions(+), 0 deletions(-)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index b5e28c4980..d6afc1cdaa 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1550,6 +1550,37 @@ class ContextTests(unittest.TestCase):
sc.stop()
self.assertEqual(SparkContext._active_spark_context, None)
+ def test_progress_api(self):
+ with SparkContext() as sc:
+ sc.setJobGroup('test_progress_api', '', True)
+
+ rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100))
+ t = threading.Thread(target=rdd.collect)
+ t.daemon = True
+ t.start()
+ # wait for scheduler to start
+ time.sleep(1)
+
+ tracker = sc.statusTracker()
+ jobIds = tracker.getJobIdsForGroup('test_progress_api')
+ self.assertEqual(1, len(jobIds))
+ job = tracker.getJobInfo(jobIds[0])
+ self.assertEqual(1, len(job.stageIds))
+ stage = tracker.getStageInfo(job.stageIds[0])
+ self.assertEqual(rdd.getNumPartitions(), stage.numTasks)
+
+ sc.cancelAllJobs()
+ t.join()
+ # wait for event listener to update the status
+ time.sleep(1)
+
+ job = tracker.getJobInfo(jobIds[0])
+ self.assertEqual('FAILED', job.status)
+ self.assertEqual([], tracker.getActiveJobsIds())
+ self.assertEqual([], tracker.getActiveStageIds())
+
+ sc.stop()
+
@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):