author | Holden Karau <holden@us.ibm.com> | 2016-12-20 15:51:21 -0800
---|---|---
committer | Reynold Xin <rxin@databricks.com> | 2016-12-20 15:51:21 -0800
commit | 047a9d92caa1f0af2305e4afeba8339abf32518b |
tree | ae8588410abee1171dc0cf18ffcbd71a99f62399 |
parent | caed89321fdabe83e46451ca4e968f86481ad500 |
# [SPARK-18576][PYTHON] Add basic TaskContext information to PySpark
## What changes were proposed in this pull request?
Adds basic TaskContext information to PySpark.
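As a quick illustration of the new API (not part of the patch itself; the app name, master, and partition count below are arbitrary choices for this sketch, which assumes a build containing this change):

```python
from pyspark import SparkContext, TaskContext

sc = SparkContext("local[4]", "taskcontext-demo")  # hypothetical demo app

def describe(x):
    # TaskContext.get() returns the worker-local singleton; on the driver it
    # returns None because no task is running there.
    tc = TaskContext.get()
    return (x, tc.stageId(), tc.partitionId(), tc.attemptNumber(), tc.taskAttemptId())

# Each element reports the stage/partition/attempt of the task that computed it.
print(sc.parallelize(range(4), 4).map(describe).collect())
sc.stop()
```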
## How was this patch tested?
New unit tests added to `tests.py`, plus the existing unit tests.
Author: Holden Karau <holden@us.ibm.com>
Closes #16211 from holdenk/SPARK-18576-pyspark-taskcontext.
-rw-r--r-- | core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 5
-rw-r--r-- | python/pyspark/__init__.py | 5
-rw-r--r-- | python/pyspark/taskcontext.py | 90
-rw-r--r-- | python/pyspark/tests.py | 65
-rw-r--r-- | python/pyspark/worker.py | 6
5 files changed, 170 insertions(+), 1 deletion(-)
```diff
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 0ca91b9bf8..04ae97ed3c 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -275,6 +275,11 @@ private[spark] class PythonRunner(
         dataOut.writeInt(partitionIndex)
         // Python version of driver
         PythonRDD.writeUTF(pythonVer, dataOut)
+        // Write out the TaskContextInfo
+        dataOut.writeInt(context.stageId())
+        dataOut.writeInt(context.partitionId())
+        dataOut.writeInt(context.attemptNumber())
+        dataOut.writeLong(context.taskAttemptId())
         // sparkFilesDir
         PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
         // Python includes (*.zip and *.egg files)
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 5f93586a48..9331e74eed 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -34,6 +34,8 @@ Public classes:
         Access files shipped with jobs.
     - :class:`StorageLevel`:
         Finer-grained cache persistence levels.
+    - :class:`TaskContext`:
+        Information about the current running task, available on the workers and experimental.
 
 """
 
@@ -49,6 +51,7 @@ from pyspark.accumulators import Accumulator, AccumulatorParam
 from pyspark.broadcast import Broadcast
 from pyspark.serializers import MarshalSerializer, PickleSerializer
 from pyspark.status import *
+from pyspark.taskcontext import TaskContext
 from pyspark.profiler import Profiler, BasicProfiler
 from pyspark.version import __version__
 
@@ -106,5 +109,5 @@ from pyspark.sql import SQLContext, HiveContext, Row
 __all__ = [
     "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
     "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
-    "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler",
+    "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler", "TaskContext",
 ]
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
new file mode 100644
index 0000000000..e5218d9e75
--- /dev/null
+++ b/python/pyspark/taskcontext.py
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+
+class TaskContext(object):
+
+    """
+    .. note:: Experimental
+
+    Contextual information about a task which can be read or mutated during
+    execution. To access the TaskContext for a running task, use
+    L{TaskContext.get()}.
+    """
+
+    _taskContext = None
+
+    _attemptNumber = None
+    _partitionId = None
+    _stageId = None
+    _taskAttemptId = None
+
+    def __new__(cls):
+        """Even if users construct TaskContext instead of using get, give them the singleton."""
+        taskContext = cls._taskContext
+        if taskContext is not None:
+            return taskContext
+        cls._taskContext = taskContext = object.__new__(cls)
+        return taskContext
+
+    def __init__(self):
+        """Construct a TaskContext; use get() instead."""
+        pass
+
+    @classmethod
+    def _getOrCreate(cls):
+        """Internal function to get or create the global TaskContext."""
+        if cls._taskContext is None:
+            cls._taskContext = TaskContext()
+        return cls._taskContext
+
+    @classmethod
+    def get(cls):
+        """
+        Return the currently active TaskContext. This can be called inside of
+        user functions to access contextual information about running tasks.
+
+        .. note:: Must be called on the worker, not the driver. Returns None if not initialized.
+        """
+        return cls._taskContext
+
+    def stageId(self):
+        """The ID of the stage that this task belongs to."""
+        return self._stageId
+
+    def partitionId(self):
+        """
+        The ID of the RDD partition that is computed by this task.
+        """
+        return self._partitionId
+
+    def attemptNumber(self):
+        """
+        How many times this task has been attempted. The first task attempt will be assigned
+        attemptNumber = 0, and subsequent attempts will have increasing attempt numbers.
+        """
+        return self._attemptNumber
+
+    def taskAttemptId(self):
+        """
+        An ID that is unique to this task attempt (within the same SparkContext, no two task
+        attempts will share the same attempt ID). This is roughly equivalent to Hadoop's
+        TaskAttemptID.
+        """
+        return self._taskAttemptId
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index fe314c54a1..c383d9ab67 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -69,6 +69,7 @@ from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer,
 from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter
 from pyspark import shuffle
 from pyspark.profiler import BasicProfiler
+from pyspark.taskcontext import TaskContext
 
 _have_scipy = False
 _have_numpy = False
@@ -478,6 +479,70 @@ class AddFileTests(PySparkTestCase):
         self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect())
 
 
+class TaskContextTests(PySparkTestCase):
+
+    def setUp(self):
+        self._old_sys_path = list(sys.path)
+        class_name = self.__class__.__name__
+        # Allow retries even though they are normally disabled in local mode
+        self.sc = SparkContext('local[4, 2]', class_name)
+
+    def test_stage_id(self):
+        """Test the stage ids are available and incrementing as expected."""
+        rdd = self.sc.parallelize(range(10))
+        stage1 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
+        stage2 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
+        # Test using the constructor directly rather than the get()
+        stage3 = rdd.map(lambda x: TaskContext().stageId()).take(1)[0]
+        self.assertEqual(stage1 + 1, stage2)
+        self.assertEqual(stage1 + 2, stage3)
+        self.assertEqual(stage2 + 1, stage3)
+
+    def test_partition_id(self):
+        """Test the partition id."""
+        rdd1 = self.sc.parallelize(range(10), 1)
+        rdd2 = self.sc.parallelize(range(10), 2)
+        pids1 = rdd1.map(lambda x: TaskContext.get().partitionId()).collect()
+        pids2 = rdd2.map(lambda x: TaskContext.get().partitionId()).collect()
+        self.assertEqual(0, pids1[0])
+        self.assertEqual(0, pids1[9])
+        self.assertEqual(0, pids2[0])
+        self.assertEqual(1, pids2[9])
+
+    def test_attempt_number(self):
+        """Verify the attempt numbers are correctly reported."""
+        rdd = self.sc.parallelize(range(10))
+        # Verify a simple job with no failures
+        attempt_numbers = rdd.map(lambda x: TaskContext.get().attemptNumber()).collect()
+        map(lambda attempt: self.assertEqual(0, attempt), attempt_numbers)
+
+        def fail_on_first(x):
+            """Fail on the first attempt so we get a positive attempt number"""
+            tc = TaskContext.get()
+            attempt_number = tc.attemptNumber()
+            partition_id = tc.partitionId()
+            attempt_id = tc.taskAttemptId()
+            if attempt_number == 0 and partition_id == 0:
+                raise Exception("Failing on first attempt")
+            else:
+                return [x, partition_id, attempt_number, attempt_id]
+        result = rdd.map(fail_on_first).collect()
+        # The first partition should be re-submitted; other partitions should stay at attempt 0
+        self.assertEqual([0, 0, 1], result[0][0:3])
+        self.assertEqual([9, 3, 0], result[9][0:3])
+        first_partition = filter(lambda x: x[1] == 0, result)
+        map(lambda x: self.assertEqual(1, x[2]), first_partition)
+        other_partitions = filter(lambda x: x[1] != 0, result)
+        map(lambda x: self.assertEqual(0, x[2]), other_partitions)
+        # The task attempt id should be different
+        self.assertTrue(result[0][3] != result[9][3])
+
+    def test_tc_on_driver(self):
+        """Verify that getting the TaskContext on the driver returns None."""
+        tc = TaskContext.get()
+        self.assertTrue(tc is None)
+
+
 class RDDTests(ReusedPySparkTestCase):
 
     def test_range(self):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 0918282953..25ee475c7f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -27,6 +27,7 @@ import traceback
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
+from pyspark.taskcontext import TaskContext
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer
@@ -125,6 +126,11 @@ def main(infile, outfile):
                         ("%d.%d" % sys.version_info[:2], version))
 
         # initialize global state
+        taskContext = TaskContext._getOrCreate()
+        taskContext._stageId = read_int(infile)
+        taskContext._partitionId = read_int(infile)
+        taskContext._attemptNumber = read_int(infile)
+        taskContext._taskAttemptId = read_long(infile)
         shuffle.MemoryBytesSpilled = 0
         shuffle.DiskBytesSpilled = 0
         _accumulatorRegistry.clear()
```
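A note on the handshake above: `PythonRunner` writes three 4-byte ints and one 8-byte long immediately after the Python version string, and `worker.py` consumes them in exactly the same order with `read_int`/`read_long`. The sketch below is illustrative only, not Spark code: the `write_int`/`write_long` helpers are assumed stand-ins that mimic the big-endian framing produced by Java's `DataOutputStream` and decoded by `pyspark.serializers`.

```python
import io
import struct

# Stand-in framing helpers (assumed for this sketch): Java's DataOutputStream
# writes big-endian values, which is what read_int/read_long expect.
def write_int(value, stream):
    stream.write(struct.pack("!i", value))  # 4-byte big-endian signed int

def write_long(value, stream):
    stream.write(struct.pack("!q", value))  # 8-byte big-endian signed long

def read_int(stream):
    return struct.unpack("!i", stream.read(4))[0]

def read_long(stream):
    return struct.unpack("!q", stream.read(8))[0]

buf = io.BytesIO()
# "Driver" side: mirrors the order of writes added to PythonRunner.
write_int(3, buf)    # stageId
write_int(0, buf)    # partitionId
write_int(1, buf)    # attemptNumber
write_long(42, buf)  # taskAttemptId

buf.seek(0)
# "Worker" side: mirrors the reads added to worker.py's main().
print(read_int(buf), read_int(buf), read_int(buf), read_long(buf))  # 3 0 1 42
```

Reading the fields out of order, or with the wrong width (for example, reading `taskAttemptId` as an int), would misalign every subsequently framed value in the worker's input stream.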