diff options
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r-- | python/pyspark/tests.py | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index b5e28c4980..d6afc1cdaa 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1550,6 +1550,37 @@ class ContextTests(unittest.TestCase): sc.stop() self.assertEqual(SparkContext._active_spark_context, None) + def test_progress_api(self): + with SparkContext() as sc: + sc.setJobGroup('test_progress_api', '', True) + + rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100)) + t = threading.Thread(target=rdd.collect) + t.daemon = True + t.start() + # wait for scheduler to start + time.sleep(1) + + tracker = sc.statusTracker() + jobIds = tracker.getJobIdsForGroup('test_progress_api') + self.assertEqual(1, len(jobIds)) + job = tracker.getJobInfo(jobIds[0]) + self.assertEqual(1, len(job.stageIds)) + stage = tracker.getStageInfo(job.stageIds[0]) + self.assertEqual(rdd.getNumPartitions(), stage.numTasks) + + sc.cancelAllJobs() + t.join() + # wait for event listener to update the status + time.sleep(1) + + job = tracker.getJobInfo(jobIds[0]) + self.assertEqual('FAILED', job.status) + self.assertEqual([], tracker.getActiveJobsIds()) + self.assertEqual([], tracker.getActiveStageIds()) + + sc.stop() + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): |