author: Holden Karau <holden@us.ibm.com> 2016-12-20 15:51:21 -0800
committer: Reynold Xin <rxin@databricks.com> 2016-12-20 15:51:21 -0800
commit: 047a9d92caa1f0af2305e4afeba8339abf32518b (patch)
tree: ae8588410abee1171dc0cf18ffcbd71a99f62399 /python
parent: caed89321fdabe83e46451ca4e968f86481ad500 (diff)
[SPARK-18576][PYTHON] Add basic TaskContext information to PySpark
## What changes were proposed in this pull request?

Adds basic TaskContext information to PySpark.

## How was this patch tested?

New unit tests added to `tests.py`, plus the existing unit tests.

Author: Holden Karau <holden@us.ibm.com>

Closes #16211 from holdenk/SPARK-18576-pyspark-taskcontext.
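For reference, a minimal usage sketch (not part of this patch) of how user code could read the new task metadata from inside an RDD operation; it assumes an existing SparkContext named `sc`:

from pyspark import TaskContext

def describe(x):
    # TaskContext.get() returns None on the driver and the populated
    # singleton on the workers.
    tc = TaskContext.get()
    return (x, tc.stageId(), tc.partitionId(), tc.attemptNumber(), tc.taskAttemptId())

info = sc.parallelize(range(4), 2).map(describe).collect()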
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/__init__.py    |  5
-rw-r--r--  python/pyspark/taskcontext.py | 90
-rw-r--r--  python/pyspark/tests.py       | 65
-rw-r--r--  python/pyspark/worker.py      |  6
4 files changed, 165 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 5f93586a48..9331e74eed 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -34,6 +34,8 @@ Public classes:
Access files shipped with jobs.
- :class:`StorageLevel`:
Finer-grained cache persistence levels.
+ - :class:`TaskContext`:
+ Information about the currently running task, available on the workers and experimental.
"""
@@ -49,6 +51,7 @@ from pyspark.accumulators import Accumulator, AccumulatorParam
from pyspark.broadcast import Broadcast
from pyspark.serializers import MarshalSerializer, PickleSerializer
from pyspark.status import *
+from pyspark.taskcontext import TaskContext
from pyspark.profiler import Profiler, BasicProfiler
from pyspark.version import __version__
@@ -106,5 +109,5 @@ from pyspark.sql import SQLContext, HiveContext, Row
__all__ = [
"SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
"Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
- "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler",
+ "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler", "TaskContext",
]
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
new file mode 100644
index 0000000000..e5218d9e75
--- /dev/null
+++ b/python/pyspark/taskcontext.py
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+
+class TaskContext(object):
+
+ """
+ .. note:: Experimental
+
+ Contextual information about a task which can be read or mutated during
+ execution. To access the TaskContext for a running task, use:
+ L{TaskContext.get()}.
+ """
+
+ _taskContext = None
+
+ _attemptNumber = None
+ _partitionId = None
+ _stageId = None
+ _taskAttemptId = None
+
+ def __new__(cls):
+ """Even if users construct TaskContext instead of using get, give them the singleton."""
+ taskContext = cls._taskContext
+ if taskContext is not None:
+ return taskContext
+ cls._taskContext = taskContext = object.__new__(cls)
+ return taskContext
+
+ def __init__(self):
+ """Construct a TaskContext, use get instead"""
+ pass
+
+ @classmethod
+ def _getOrCreate(cls):
+ """Internal function to get or create global TaskContext."""
+ if cls._taskContext is None:
+ cls._taskContext = TaskContext()
+ return cls._taskContext
+
+ @classmethod
+ def get(cls):
+ """
+ Return the currently active TaskContext. This can be called inside of
+ user functions to access contextual information about running tasks.
+
+ .. note:: Must be called on the worker, not the driver. Returns None if not initialized.
+ """
+ return cls._taskContext
+
+ def stageId(self):
+ """The ID of the stage that this task belong to."""
+ return self._stageId
+
+ def partitionId(self):
+ """
+ The ID of the RDD partition that is computed by this task.
+ """
+ return self._partitionId
+
+ def attemptNumber(self):
+ """"
+ How many times this task has been attempted. The first task attempt will be assigned
+ attemptNumber = 0, and subsequent attempts will have increasing attempt numbers.
+ """
+ return self._attemptNumber
+
+ def taskAttemptId(self):
+ """
+ An ID that is unique to this task attempt (within the same SparkContext, no two task
+ attempts will share the same attempt ID). This is roughly equivalent to Hadoop's
+ TaskAttemptID.
+ """
+ return self._taskAttemptId
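The `__new__` override above is a plain class-level singleton. A standalone sketch of the same idiom (independent of PySpark; the class name is illustrative) shows why constructing the class directly still hands back the shared instance:

class _SingletonExample(object):
    _instance = None

    def __new__(cls):
        # Reuse the class-level instance if one has already been created.
        if cls._instance is not None:
            return cls._instance
        cls._instance = instance = object.__new__(cls)
        return instance

a = _SingletonExample()
b = _SingletonExample()
assert a is b  # direct construction still yields the shared instance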
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index fe314c54a1..c383d9ab67 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -69,6 +69,7 @@ from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer,
from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter
from pyspark import shuffle
from pyspark.profiler import BasicProfiler
+from pyspark.taskcontext import TaskContext
_have_scipy = False
_have_numpy = False
@@ -478,6 +479,70 @@ class AddFileTests(PySparkTestCase):
self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect())
+class TaskContextTests(PySparkTestCase):
+
+ def setUp(self):
+ self._old_sys_path = list(sys.path)
+ class_name = self.__class__.__name__
+ # Allow retries even though they are normally disabled in local mode
+ self.sc = SparkContext('local[4, 2]', class_name)
+
+ def test_stage_id(self):
+ """Test the stage ids are available and incrementing as expected."""
+ rdd = self.sc.parallelize(range(10))
+ stage1 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
+ stage2 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
+ # Test using the constructor directly rather than get()
+ stage3 = rdd.map(lambda x: TaskContext().stageId()).take(1)[0]
+ self.assertEqual(stage1 + 1, stage2)
+ self.assertEqual(stage1 + 2, stage3)
+ self.assertEqual(stage2 + 1, stage3)
+
+ def test_partition_id(self):
+ """Test the partition id."""
+ rdd1 = self.sc.parallelize(range(10), 1)
+ rdd2 = self.sc.parallelize(range(10), 2)
+ pids1 = rdd1.map(lambda x: TaskContext.get().partitionId()).collect()
+ pids2 = rdd2.map(lambda x: TaskContext.get().partitionId()).collect()
+ self.assertEqual(0, pids1[0])
+ self.assertEqual(0, pids1[9])
+ self.assertEqual(0, pids2[0])
+ self.assertEqual(1, pids2[9])
+
+ def test_attempt_number(self):
+ """Verify the attempt numbers are correctly reported."""
+ rdd = self.sc.parallelize(range(10))
+ # Verify a simple job with no failures
+ attempt_numbers = rdd.map(lambda x: TaskContext.get().attemptNumber()).collect()
+ # Use an explicit loop so the assertions actually run under Python 3, where map() is lazy
+ for attempt in attempt_numbers:
+ self.assertEqual(0, attempt)
+
+ def fail_on_first(x):
+ """Fail on the first attempt so we get a positive attempt number"""
+ tc = TaskContext.get()
+ attempt_number = tc.attemptNumber()
+ partition_id = tc.partitionId()
+ attempt_id = tc.taskAttemptId()
+ if attempt_number == 0 and partition_id == 0:
+ raise Exception("Failing on first attempt")
+ else:
+ return [x, partition_id, attempt_number, attempt_id]
+ result = rdd.map(fail_on_first).collect()
+ # The first partition should be re-submitted (attempt 1); the other partitions should stay at attempt 0
+ self.assertEqual([0, 0, 1], result[0][0:3])
+ self.assertEqual([9, 3, 0], result[9][0:3])
+ # Use list comprehensions and loops so the checks run under Python 3, where filter()/map() are lazy
+ first_partition = [x for x in result if x[1] == 0]
+ for x in first_partition:
+ self.assertEqual(1, x[2])
+ other_partitions = [x for x in result if x[1] != 0]
+ for x in other_partitions:
+ self.assertEqual(0, x[2])
+ # The task attempt id should be different
+ self.assertTrue(result[0][3] != result[9][3])
+
+ def test_tc_on_driver(self):
+ """Verify that getting the TaskContext on the driver returns None."""
+ tc = TaskContext.get()
+ self.assertTrue(tc is None)
+
+
class RDDTests(ReusedPySparkTestCase):
def test_range(self):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 0918282953..25ee475c7f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -27,6 +27,7 @@ import traceback
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
+from pyspark.taskcontext import TaskContext
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer
@@ -125,6 +126,11 @@ def main(infile, outfile):
("%d.%d" % sys.version_info[:2], version))
# initialize global state
+ taskContext = TaskContext._getOrCreate()
+ taskContext._stageId = read_int(infile)
+ taskContext._partitionId = read_int(infile)
+ taskContext._attemptNumber = read_int(infile)
+ taskContext._taskAttemptId = read_long(infile)
shuffle.MemoryBytesSpilled = 0
shuffle.DiskBytesSpilled = 0
_accumulatorRegistry.clear()
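For intuition, a minimal local sketch (not Spark code) of the framing the worker expects here: three 4-byte ints followed by one 8-byte long, in the order stageId, partitionId, attemptNumber, taskAttemptId. The big-endian layout mirrors what `read_int`/`read_long` assume and is stated here as an assumption:

import io
import struct

# Write the four header fields in the order worker.py reads them.
buf = io.BytesIO()
buf.write(struct.pack("!i", 3))    # stageId
buf.write(struct.pack("!i", 0))    # partitionId
buf.write(struct.pack("!i", 1))    # attemptNumber
buf.write(struct.pack("!q", 42))   # taskAttemptId
buf.seek(0)

# Read them back the same way the worker populates TaskContext.
stage_id, partition_id, attempt_number = struct.unpack("!iii", buf.read(12))
(task_attempt_id,) = struct.unpack("!q", buf.read(8))
print(stage_id, partition_id, attempt_number, task_attempt_id)  # 3 0 1 42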