aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/profiler.py
blob: d18daaabfcb3c4b6a62f4c379151c0c22697ff03 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import cProfile
import pstats
import os
import atexit

from pyspark.accumulators import AccumulatorParam


class ProfilerCollector(object):
    """
    Registry of the profilers created for a job, one entry per RDD id.

    It also acts as the factory for new profilers (via `profiler_cls`) and
    arranges for the collected profiles to be shown or dumped when the
    interpreter exits.
    """

    def __init__(self, profiler_cls, dump_path=None):
        # Each element of `profilers` is a mutable triple:
        #   [rdd id, profiler instance, already-shown flag]
        self.profilers = []
        self.profile_dump_path = dump_path
        self.profiler_cls = profiler_cls

    def new_profiler(self, ctx):
        """ Instantiate `profiler_cls` with `ctx` and return the instance """
        factory = self.profiler_cls
        return factory(ctx)

    def add_profiler(self, id, profiler):
        """ Record `profiler` as the profiler of RDD `id` """
        is_first = not self.profilers
        if is_first:
            # The exit hook is installed whenever the list is empty on entry:
            # dump to the configured directory if there is one, otherwise
            # print to stdout.
            dump_dir = self.profile_dump_path
            if dump_dir:
                atexit.register(self.dump_profiles, dump_dir)
            else:
                atexit.register(self.show_profiles)

        self.profilers.append([id, profiler, False])

    def dump_profiles(self, path):
        """ Ask every recorded profiler to dump into `path`, then forget them """
        for rdd_id, prof, _shown in self.profilers:
            prof.dump(rdd_id, path)
        self.profilers = []

    def show_profiles(self):
        """ Print every profile that has not been printed yet to stdout """
        for position, (rdd_id, prof, already_shown) in enumerate(self.profilers):
            if already_shown or not prof:
                continue
            prof.show(rdd_id)
            # remember that this entry was printed so it is not repeated
            self.profilers[position][2] = True


class Profiler(object):
    """
    .. note:: DeveloperApi

    PySpark supports custom profilers, this is to allow for different profilers to
    be used as well as outputting to different formats than what is provided in the
    BasicProfiler.

    A custom profiler has to define or inherit the following methods:
        profile - will produce a system profile of some sort.
        stats - return the collected stats.
        show - prints the collected stats to stdout
        dump - dumps the profiles to a path

    The profiler class is chosen when creating a SparkContext

    >>> from pyspark import SparkConf, SparkContext
    >>> from pyspark import BasicProfiler
    >>> class MyCustomProfiler(BasicProfiler):
    ...     def show(self, id):
    ...         print("My custom profiles for RDD:%s" % id)
    ...
    >>> conf = SparkConf().set("spark.python.profile", "true")
    >>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler)
    >>> sc.parallelize(range(1000)).map(lambda x: 2 * x).take(10)
    [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
    >>> sc.show_profiles()
    My custom profiles for RDD:1
    My custom profiles for RDD:2
    >>> sc.stop()
    """

    def __init__(self, ctx):
        """ The base class does not use `ctx`; subclasses may """
        pass

    def profile(self, func):
        """
        Do profiling on the function `func`

        Must be overridden; the base implementation raises NotImplementedError.
        """
        # NotImplemented is a comparison sentinel, not an exception class;
        # raising it produces a TypeError instead of the intended error.
        raise NotImplementedError

    def stats(self):
        """
        Return the collected profiling stats (pstats.Stats)

        Must be overridden; the base implementation raises NotImplementedError.
        """
        raise NotImplementedError

    def show(self, id):
        """ Print the profile stats to stdout, id is the RDD id """
        stats = self.stats()
        # nothing is printed when no stats have been collected
        if stats:
            print("=" * 60)
            print("Profile of RDD<id=%d>" % id)
            print("=" * 60)
            stats.sort_stats("time", "cumulative").print_stats()

    def dump(self, id, path):
        """
        Dump the profile into path, id is the RDD id

        The directory `path` is created when missing (even if there are no
        stats to write); the stats go to `<path>/rdd_<id>.pstats`.
        """
        if not os.path.exists(path):
            os.makedirs(path)
        stats = self.stats()
        if stats:
            p = os.path.join(path, "rdd_%d.pstats" % id)
            stats.dump_stats(p)


class PStatsParam(AccumulatorParam):
    """Accumulator parameter whose values are pstats.Stats objects (or None)"""

    @staticmethod
    def zero(value):
        # None stands for "no stats collected yet"
        return None

    @staticmethod
    def addInPlace(value1, value2):
        # With something already accumulated, fold the new stats into it in
        # place; otherwise the incoming value simply becomes the result.
        if value1 is not None:
            value1.add(value2)
            return value1
        return value2


class BasicProfiler(Profiler):
    """
    The default profiler. Functions are run under cProfile and the resulting
    pstats.Stats are combined through an Accumulator.
    """
    def __init__(self, ctx):
        Profiler.__init__(self, ctx)
        # The accumulator merges the profiles of the different partitions of
        # a stage; it starts out empty (None) and combines with PStatsParam.
        self._accumulator = ctx.accumulator(None, PStatsParam)

    def profile(self, func):
        """ Run `func` under cProfile and add its stats to the accumulator. """
        recorder = cProfile.Profile()
        recorder.runcall(func)

        collected = pstats.Stats(recorder)
        # drop the output stream so that the stats object can be pickled
        collected.stream = None
        collected.strip_dirs()

        # merge this run into whatever has been accumulated so far
        self._accumulator.add(collected)

    def stats(self):
        """ Return the current value of the accumulator. """
        return self._accumulator.value


if __name__ == "__main__":
    # Executing this module directly runs the doctests in its docstrings.
    from doctest import testmod
    testmod()