Diffstat (limited to 'python/pyspark/tests.py')
 python/pyspark/tests.py | 65 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 65 insertions(+), 0 deletions(-)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index fe314c54a1..c383d9ab67 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -69,6 +69,7 @@ from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer,
 from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter
 from pyspark import shuffle
 from pyspark.profiler import BasicProfiler
+from pyspark.taskcontext import TaskContext
 _have_scipy = False
 _have_numpy = False
@@ -478,6 +479,70 @@ class AddFileTests(PySparkTestCase):
         self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect())
+class TaskContextTests(PySparkTestCase):
+
+    def setUp(self):
+        self._old_sys_path = list(sys.path)
+        class_name = self.__class__.__name__
+        # Allow retries even though they are normally disabled in local mode
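+        # ('local[4, 2]' runs 4 worker threads with maxFailures set to 2)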
+        self.sc = SparkContext('local[4, 2]', class_name)
+
+    def test_stage_id(self):
+        """Test that stage ids are available and increment as expected."""
+        rdd = self.sc.parallelize(range(10))
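+        # Each take() below triggers a separate job, so each runs in a new stage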
+        stage1 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
+        stage2 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
+        # Test using the constructor directly rather than TaskContext.get()
+        stage3 = rdd.map(lambda x: TaskContext().stageId()).take(1)[0]
+        self.assertEqual(stage1 + 1, stage2)
+        self.assertEqual(stage1 + 2, stage3)
+        self.assertEqual(stage2 + 1, stage3)
+
+    def test_partition_id(self):
+        """Test the partition id."""
+        rdd1 = self.sc.parallelize(range(10), 1)
+        rdd2 = self.sc.parallelize(range(10), 2)
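+        # partitionId() is the index of the partition the task is computing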
+        pids1 = rdd1.map(lambda x: TaskContext.get().partitionId()).collect()
+        pids2 = rdd2.map(lambda x: TaskContext.get().partitionId()).collect()
+        self.assertEqual(0, pids1[0])
+        self.assertEqual(0, pids1[9])
+        self.assertEqual(0, pids2[0])
+        self.assertEqual(1, pids2[9])
+
+    def test_attempt_number(self):
+        """Verify the attempt numbers are correctly reported."""
+        rdd = self.sc.parallelize(range(10))
+        # Verify a simple job with no failures
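+        # attemptNumber() starts at 0 and only increases when a task is re-run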
+        attempt_numbers = rdd.map(lambda x: TaskContext.get().attemptNumber()).collect()
+        # Use an explicit loop: a bare map() is lazy on Python 3 and would skip the asserts
+        for attempt_number in attempt_numbers:
+            self.assertEqual(0, attempt_number)
+
+        def fail_on_first(x):
+            """Fail on the first attempt so we get a positive attempt number."""
+            tc = TaskContext.get()
+            attempt_number = tc.attemptNumber()
+            partition_id = tc.partitionId()
+            attempt_id = tc.taskAttemptId()
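+            # taskAttemptId() is unique per task attempt within the SparkContext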
+            if attempt_number == 0 and partition_id == 0:
+                raise Exception("Failing on first attempt")
+            else:
+                return [x, partition_id, attempt_number, attempt_id]
+        result = rdd.map(fail_on_first).collect()
+        # The first partition fails once and is re-submitted, so it reports
+        # attempt 1; every other partition should still be attempt 0
+        self.assertEqual([0, 0, 1], result[0][0:3])
+        self.assertEqual([9, 3, 0], result[9][0:3])
+        first_partition = [x for x in result if x[1] == 0]
+        for x in first_partition:
+            self.assertEqual(1, x[2])
+        other_partitions = [x for x in result if x[1] != 0]
+        for x in other_partitions:
+            self.assertEqual(0, x[2])
+        # The task attempt ids should differ between the two tasks
+        self.assertNotEqual(result[0][3], result[9][3])
+
+    def test_tc_on_driver(self):
+        """Verify that getting the TaskContext on the driver returns None."""
+        tc = TaskContext.get()
+        self.assertIsNone(tc)
+
+
 class RDDTests(ReusedPySparkTestCase):
     def test_range(self):