author    Davies Liu <davies@databricks.com>    2015-02-17 13:36:43 -0800
committer Josh Rosen <joshrosen@databricks.com> 2015-02-17 13:36:43 -0800
commit    445a755b884885b88c1778fd56a3151045b0b0ed
tree      e36607b0aedc8040fa1946f364ceba85aadbcf68
parent    de4836f8f12c36c1b350cef288a75b5e59155735
[SPARK-4172] [PySpark] Progress API in Python
This patch brings the pull-based progress API into Python, along with a Python example.

Author: Davies Liu <davies@databricks.com>

Closes #3027 from davies/progress_api and squashes the following commits:

b1ba984 [Davies Liu] fix style
d3b9253 [Davies Liu] add tests, mute the exception after stop
4297327 [Davies Liu] Merge branch 'master' of github.com:apache/spark into progress_api
969fa9d [Davies Liu] Merge branch 'master' of github.com:apache/spark into progress_api
25590c9 [Davies Liu] update with Java API
360de2d [Davies Liu] Merge branch 'master' of github.com:apache/spark into progress_api
c0f1021 [Davies Liu] Merge branch 'master' of github.com:apache/spark into progress_api
023afb3 [Davies Liu] add Python API and example for progress API
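The API is pull based: a driver-side StatusTracker is polled for job and stage progress rather than pushing callbacks. A minimal sketch of the intended usage (the app name is illustrative; the methods are the ones added in this patch):

    from pyspark import SparkContext

    sc = SparkContext(appName="ProgressSketch")
    tracker = sc.statusTracker()

    # Poll whenever convenient; results cover only recently retained jobs/stages.
    print "active jobs:  ", tracker.getActiveJobsIds()
    print "active stages:", tracker.getActiveStageIds()

A fuller polling loop, driven from a background thread, is added below as examples/src/main/python/status_api_demo.py.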
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala |  40
-rw-r--r--  examples/src/main/python/status_api_demo.py                            |  67
-rw-r--r--  python/pyspark/__init__.py                                             |  15
-rw-r--r--  python/pyspark/context.py                                              |   7
-rw-r--r--  python/pyspark/status.py                                               |  96
-rw-r--r--  python/pyspark/tests.py                                                |  31
6 files changed, 232 insertions, 24 deletions
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 774f3d8cdb..3938580aee 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler
import java.nio.ByteBuffer
+import java.util.concurrent.RejectedExecutionException
import scala.language.existentials
import scala.util.control.NonFatal
@@ -95,25 +96,30 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState,
serializedData: ByteBuffer) {
var reason : TaskEndReason = UnknownReason
- getTaskResultExecutor.execute(new Runnable {
- override def run(): Unit = Utils.logUncaughtExceptions {
- try {
- if (serializedData != null && serializedData.limit() > 0) {
- reason = serializer.get().deserialize[TaskEndReason](
- serializedData, Utils.getSparkClassLoader)
+ try {
+ getTaskResultExecutor.execute(new Runnable {
+ override def run(): Unit = Utils.logUncaughtExceptions {
+ try {
+ if (serializedData != null && serializedData.limit() > 0) {
+ reason = serializer.get().deserialize[TaskEndReason](
+ serializedData, Utils.getSparkClassLoader)
+ }
+ } catch {
+ case cnd: ClassNotFoundException =>
+ // Log an error but keep going here -- the task failed, so not catastrophic
+ // if we can't deserialize the reason.
+ val loader = Utils.getContextOrSparkClassLoader
+ logError(
+ "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
+ case ex: Exception => {}
}
- } catch {
- case cnd: ClassNotFoundException =>
- // Log an error but keep going here -- the task failed, so not catastrophic if we can't
- // deserialize the reason.
- val loader = Utils.getContextOrSparkClassLoader
- logError(
- "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
- case ex: Exception => {}
+ scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
}
- scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
- }
- })
+ })
+ } catch {
+ case e: RejectedExecutionException if sparkEnv.isStopped =>
+ // ignore it
+ }
}
def stop() {
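The try/catch added around getTaskResultExecutor.execute guards against a RejectedExecutionException that is raised when a failed task's result arrives after the SparkEnv has been stopped and the result-getter pool is already shut down ("mute the exception after stop" in the squashed commits). The situation is easy to reach from Python; a hedged sketch, mirroring the new test at the bottom of this patch:

    import threading
    import time

    from pyspark import SparkContext

    sc = SparkContext(appName="StopWhileRunning")
    rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100))

    t = threading.Thread(target=rdd.collect)   # long-running job in the background
    t.daemon = True
    t.start()
    time.sleep(1)                              # let the scheduler start the job

    sc.cancelAllJobs()                         # tasks fail; failure results get enqueued
    sc.stop()                                  # pool may already be shut down when they do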
diff --git a/examples/src/main/python/status_api_demo.py b/examples/src/main/python/status_api_demo.py
new file mode 100644
index 0000000000..a33bdc475a
--- /dev/null
+++ b/examples/src/main/python/status_api_demo.py
@@ -0,0 +1,67 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+import threading
+import Queue
+
+from pyspark import SparkConf, SparkContext
+
+
+def delayed(seconds):
+ def f(x):
+ time.sleep(seconds)
+ return x
+ return f
+
+
+def call_in_background(f, *args):
+ result = Queue.Queue(1)
+ t = threading.Thread(target=lambda: result.put(f(*args)))
+ t.daemon = True
+ t.start()
+ return result
+
+
+def main():
+ conf = SparkConf().set("spark.ui.showConsoleProgress", "false")
+ sc = SparkContext(appName="PythonStatusAPIDemo", conf=conf)
+
+ def run():
+ rdd = sc.parallelize(range(10), 10).map(delayed(2))
+ reduced = rdd.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y)
+ return reduced.map(delayed(2)).collect()
+
+ result = call_in_background(run)
+ status = sc.statusTracker()
+ while result.empty():
+ ids = status.getJobIdsForGroup()
+ for id in ids:
+ job = status.getJobInfo(id)
+ print "Job", id, "status: ", job.status
+ for sid in job.stageIds:
+ info = status.getStageInfo(sid)
+ if info:
+ print "Stage %d: %d tasks total (%d active, %d complete)" % \
+ (sid, info.numTasks, info.numActiveTasks, info.numCompletedTasks)
+ time.sleep(1)
+
+ print "Job results are:", result.get()
+ sc.stop()
+
+if __name__ == "__main__":
+ main()
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index d3efcdf221..5f70ac6ed8 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -22,17 +22,17 @@ Public classes:
- :class:`SparkContext`:
Main entry point for Spark functionality.
- - L{RDD}
+ - :class:`RDD`:
A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
- - L{Broadcast}
+ - :class:`Broadcast`:
A broadcast variable that gets reused across tasks.
- - L{Accumulator}
+ - :class:`Accumulator`:
An "add-only" shared variable that tasks can only add values to.
- - L{SparkConf}
+ - :class:`SparkConf`:
For configuring Spark.
- - L{SparkFiles}
+ - :class:`SparkFiles`:
Access files shipped with jobs.
- - L{StorageLevel}
+ - :class:`StorageLevel`:
Finer-grained cache persistence levels.
"""
@@ -45,6 +45,7 @@ from pyspark.storagelevel import StorageLevel
from pyspark.accumulators import Accumulator, AccumulatorParam
from pyspark.broadcast import Broadcast
from pyspark.serializers import MarshalSerializer, PickleSerializer
+from pyspark.status import *
from pyspark.profiler import Profiler, BasicProfiler
# for back compatibility
@@ -53,5 +54,5 @@ from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row
__all__ = [
"SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
"Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
- "Profiler", "BasicProfiler",
+ "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler",
]
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 40b3152b23..6011caf9f1 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -32,6 +32,7 @@ from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deseria
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
from pyspark.traceback_utils import CallSite, first_spark_call
+from pyspark.status import StatusTracker
from pyspark.profiler import ProfilerCollector, BasicProfiler
from py4j.java_collections import ListConverter
@@ -810,6 +811,12 @@ class SparkContext(object):
"""
self._jsc.sc().cancelAllJobs()
+ def statusTracker(self):
+ """
+ Return :class:`StatusTracker` object
+ """
+ return StatusTracker(self._jsc.statusTracker())
+
def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
"""
Executes the given partitionFunc on the specified set of partitions,
diff --git a/python/pyspark/status.py b/python/pyspark/status.py
new file mode 100644
index 0000000000..a6fa7dd314
--- /dev/null
+++ b/python/pyspark/status.py
@@ -0,0 +1,96 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from collections import namedtuple
+
+__all__ = ["SparkJobInfo", "SparkStageInfo", "StatusTracker"]
+
+
+class SparkJobInfo(namedtuple("SparkJobInfo", "jobId stageIds status")):
+ """
+ Exposes information about Spark Jobs.
+ """
+
+
+class SparkStageInfo(namedtuple("SparkStageInfo",
+ "stageId currentAttemptId name numTasks numActiveTasks "
+ "numCompletedTasks numFailedTasks")):
+ """
+ Exposes information about Spark Stages.
+ """
+
+
+class StatusTracker(object):
+ """
+ Low-level status reporting APIs for monitoring job and stage progress.
+
+ These APIs intentionally provide very weak consistency semantics;
+ consumers of these APIs should be prepared to handle empty / missing
+ information. For example, a job's stage ids may be known but the status
+ API may not have any information about the details of those stages, so
+ `getStageInfo` could potentially return `None` for a valid stage id.
+
+ To limit memory usage, these APIs only provide information on recent
+ jobs / stages. These APIs will provide information for the last
+ `spark.ui.retainedStages` stages and `spark.ui.retainedJobs` jobs.
+ """
+ def __init__(self, jtracker):
+ self._jtracker = jtracker
+
+ def getJobIdsForGroup(self, jobGroup=None):
+ """
+ Return a list of all known jobs in a particular job group. If
+ `jobGroup` is None, then returns all known jobs that are not
+ associated with a job group.
+
+ The returned list may contain running, failed, and completed jobs,
+ and may vary across invocations of this method. This method does
+ not guarantee the order of the elements in its result.
+ """
+ return list(self._jtracker.getJobIdsForGroup(jobGroup))
+
+ def getActiveStageIds(self):
+ """
+ Returns an array containing the ids of all active stages.
+ """
+ return sorted(list(self._jtracker.getActiveStageIds()))
+
+ def getActiveJobsIds(self):
+ """
+ Returns an array containing the ids of all active jobs.
+ """
+ return sorted((list(self._jtracker.getActiveJobIds())))
+
+ def getJobInfo(self, jobId):
+ """
+ Returns a :class:`SparkJobInfo` object, or None if the job info
+ could not be found or was garbage collected.
+ """
+ job = self._jtracker.getJobInfo(jobId)
+ if job is not None:
+ return SparkJobInfo(jobId, job.stageIds(), str(job.status()))
+
+ def getStageInfo(self, stageId):
+ """
+ Returns a :class:`SparkStageInfo` object, or None if the stage
+ info could not be found or was garbage collected.
+ """
+ stage = self._jtracker.getStageInfo(stageId)
+ if stage is not None:
+ # TODO: fetch them in batch for better performance
+ attrs = [getattr(stage, f)() for f in SparkStageInfo._fields[1:]]
+ return SparkStageInfo(stageId, *attrs)
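Given the weak consistency described in the StatusTracker docstring, getJobInfo and getStageInfo can legitimately return None even for ids that were just reported, so callers should check for that. A small defensive sketch (the job group name is illustrative):

    from pyspark import SparkContext

    sc = SparkContext(appName="DefensivePolling")
    tracker = sc.statusTracker()

    for job_id in tracker.getJobIdsForGroup("my-group"):
        job = tracker.getJobInfo(job_id)
        if job is None:          # job info already dropped
            continue
        for stage_id in job.stageIds:
            stage = tracker.getStageInfo(stage_id)
            if stage is None:    # valid stage id, but details not retained
                continue
            print "stage %d (%s): %d/%d tasks done" % (
                stage_id, stage.name, stage.numCompletedTasks, stage.numTasks)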
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index b5e28c4980..d6afc1cdaa 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1550,6 +1550,37 @@ class ContextTests(unittest.TestCase):
sc.stop()
self.assertEqual(SparkContext._active_spark_context, None)
+ def test_progress_api(self):
+ with SparkContext() as sc:
+ sc.setJobGroup('test_progress_api', '', True)
+
+ rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100))
+ t = threading.Thread(target=rdd.collect)
+ t.daemon = True
+ t.start()
+ # wait for scheduler to start
+ time.sleep(1)
+
+ tracker = sc.statusTracker()
+ jobIds = tracker.getJobIdsForGroup('test_progress_api')
+ self.assertEqual(1, len(jobIds))
+ job = tracker.getJobInfo(jobIds[0])
+ self.assertEqual(1, len(job.stageIds))
+ stage = tracker.getStageInfo(job.stageIds[0])
+ self.assertEqual(rdd.getNumPartitions(), stage.numTasks)
+
+ sc.cancelAllJobs()
+ t.join()
+ # wait for event listener to update the status
+ time.sleep(1)
+
+ job = tracker.getJobInfo(jobIds[0])
+ self.assertEqual('FAILED', job.status)
+ self.assertEqual([], tracker.getActiveJobsIds())
+ self.assertEqual([], tracker.getActiveStageIds())
+
+ sc.stop()
+
@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):