diff options
-rw-r--r-- | docs/configuration.md | 3 | ||||
-rw-r--r-- | python/pyspark/__init__.py | 2 | ||||
-rw-r--r-- | python/pyspark/accumulators.py | 15 | ||||
-rw-r--r-- | python/pyspark/context.py | 46 | ||||
-rw-r--r-- | python/pyspark/profiler.py | 172 | ||||
-rw-r--r-- | python/pyspark/rdd.py | 15 | ||||
-rw-r--r-- | python/pyspark/tests.py | 40 | ||||
-rw-r--r-- | python/pyspark/worker.py | 12 | ||||
-rwxr-xr-x | python/run-tests | 1 |
9 files changed, 235 insertions, 71 deletions
diff --git a/docs/configuration.md b/docs/configuration.md index 7c5b6d011c..e4e4b8d516 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -311,6 +311,9 @@ Apart from these, the following properties are also available, and may be useful or it will be displayed before the driver exiting. It also can be dumped into disk by `sc.dump_profiles(path)`. If some of the profile results had been displayed maually, they will not be displayed automatically before driver exiting. + + By default the `pyspark.profiler.BasicProfiler` will be used, but this can be overridden by + passing a profiler class in as a parameter to the `SparkContext` constructor. </td> </tr> <tr> diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 9556e4718e..d3efcdf221 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -45,6 +45,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.accumulators import Accumulator, AccumulatorParam from pyspark.broadcast import Broadcast from pyspark.serializers import MarshalSerializer, PickleSerializer +from pyspark.profiler import Profiler, BasicProfiler # for back compatibility from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row @@ -52,4 +53,5 @@ from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row __all__ = [ "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast", "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer", + "Profiler", "BasicProfiler", ] diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index b8cdbbe3cf..ccbca67656 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -215,21 +215,6 @@ FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0) COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) -class PStatsParam(AccumulatorParam): - """PStatsParam is used to merge pstats.Stats""" - - @staticmethod - def zero(value): - return None - - @staticmethod - def addInPlace(value1, value2): - if value1 is None: - return value2 - value1.add(value2) - return value1 - - class _UpdateRequestHandler(SocketServer.StreamRequestHandler): """ diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 568e21f380..c0dec16ac1 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -20,7 +20,6 @@ import shutil import sys from threading import Lock from tempfile import NamedTemporaryFile -import atexit from pyspark import accumulators from pyspark.accumulators import Accumulator @@ -33,6 +32,7 @@ from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deseria from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call +from pyspark.profiler import ProfilerCollector, BasicProfiler from py4j.java_collections import ListConverter @@ -66,7 +66,7 @@ class SparkContext(object): def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, environment=None, batchSize=0, serializer=PickleSerializer(), conf=None, - gateway=None, jsc=None): + gateway=None, jsc=None, profiler_cls=BasicProfiler): """ Create a new SparkContext. At least the master and app name should be set, either through the named parameters here or through C{conf}. @@ -88,6 +88,9 @@ class SparkContext(object): :param conf: A L{SparkConf} object setting Spark properties. :param gateway: Use an existing gateway and JVM, otherwise a new JVM will be instantiated. + :param jsc: The JavaSparkContext instance (optional). + :param profiler_cls: A class of custom Profiler used to do profiling + (default is pyspark.profiler.BasicProfiler). >>> from pyspark.context import SparkContext @@ -102,14 +105,14 @@ class SparkContext(object): SparkContext._ensure_initialized(self, gateway=gateway) try: self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer, - conf, jsc) + conf, jsc, profiler_cls) except: # If an error occurs, clean up in order to allow future SparkContext creation: self.stop() raise def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer, - conf, jsc): + conf, jsc, profiler_cls): self.environment = environment or {} self._conf = conf or SparkConf(_jvm=self._jvm) self._batchSize = batchSize # -1 represents an unlimited batch size @@ -192,7 +195,11 @@ class SparkContext(object): self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath() # profiling stats collected for each PythonRDD - self._profile_stats = [] + if self._conf.get("spark.python.profile", "false") == "true": + dump_path = self._conf.get("spark.python.profile.dump", None) + self.profiler_collector = ProfilerCollector(profiler_cls, dump_path) + else: + self.profiler_collector = None def _initialize_context(self, jconf): """ @@ -826,39 +833,14 @@ class SparkContext(object): it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal) return list(mappedRDD._collect_iterator_through_file(it)) - def _add_profile(self, id, profileAcc): - if not self._profile_stats: - dump_path = self._conf.get("spark.python.profile.dump") - if dump_path: - atexit.register(self.dump_profiles, dump_path) - else: - atexit.register(self.show_profiles) - - self._profile_stats.append([id, profileAcc, False]) - def show_profiles(self): """ Print the profile stats to stdout """ - for i, (id, acc, showed) in enumerate(self._profile_stats): - stats = acc.value - if not showed and stats: - print "=" * 60 - print "Profile of RDD<id=%d>" % id - print "=" * 60 - stats.sort_stats("time", "cumulative").print_stats() - # mark it as showed - self._profile_stats[i][2] = True + self.profiler_collector.show_profiles() def dump_profiles(self, path): """ Dump the profile stats into directory `path` """ - if not os.path.exists(path): - os.makedirs(path) - for id, acc, _ in self._profile_stats: - stats = acc.value - if stats: - p = os.path.join(path, "rdd_%d.pstats" % id) - stats.dump_stats(p) - self._profile_stats = [] + self.profiler_collector.dump_profiles(path) def _test(): diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py new file mode 100644 index 0000000000..4408996db0 --- /dev/null +++ b/python/pyspark/profiler.py @@ -0,0 +1,172 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import cProfile +import pstats +import os +import atexit + +from pyspark.accumulators import AccumulatorParam + + +class ProfilerCollector(object): + """ + This class keeps track of different profilers on a per + stage basis. Also this is used to create new profilers for + the different stages. + """ + + def __init__(self, profiler_cls, dump_path=None): + self.profiler_cls = profiler_cls + self.profile_dump_path = dump_path + self.profilers = [] + + def new_profiler(self, ctx): + """ Create a new profiler using class `profiler_cls` """ + return self.profiler_cls(ctx) + + def add_profiler(self, id, profiler): + """ Add a profiler for RDD `id` """ + if not self.profilers: + if self.profile_dump_path: + atexit.register(self.dump_profiles, self.profile_dump_path) + else: + atexit.register(self.show_profiles) + + self.profilers.append([id, profiler, False]) + + def dump_profiles(self, path): + """ Dump the profile stats into directory `path` """ + for id, profiler, _ in self.profilers: + profiler.dump(id, path) + self.profilers = [] + + def show_profiles(self): + """ Print the profile stats to stdout """ + for i, (id, profiler, showed) in enumerate(self.profilers): + if not showed and profiler: + profiler.show(id) + # mark it as showed + self.profilers[i][2] = True + + +class Profiler(object): + """ + .. note:: DeveloperApi + + PySpark supports custom profilers, this is to allow for different profilers to + be used as well as outputting to different formats than what is provided in the + BasicProfiler. + + A custom profiler has to define or inherit the following methods: + profile - will produce a system profile of some sort. + stats - return the collected stats. + dump - dumps the profiles to a path + add - adds a profile to the existing accumulated profile + + The profiler class is chosen when creating a SparkContext + + >>> from pyspark import SparkConf, SparkContext + >>> from pyspark import BasicProfiler + >>> class MyCustomProfiler(BasicProfiler): + ... def show(self, id): + ... print "My custom profiles for RDD:%s" % id + ... + >>> conf = SparkConf().set("spark.python.profile", "true") + >>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler) + >>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10) + [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] + >>> sc.show_profiles() + My custom profiles for RDD:1 + My custom profiles for RDD:2 + >>> sc.stop() + """ + + def __init__(self, ctx): + pass + + def profile(self, func): + """ Do profiling on the function `func`""" + raise NotImplemented + + def stats(self): + """ Return the collected profiling stats (pstats.Stats)""" + raise NotImplemented + + def show(self, id): + """ Print the profile stats to stdout, id is the RDD id """ + stats = self.stats() + if stats: + print "=" * 60 + print "Profile of RDD<id=%d>" % id + print "=" * 60 + stats.sort_stats("time", "cumulative").print_stats() + + def dump(self, id, path): + """ Dump the profile into path, id is the RDD id """ + if not os.path.exists(path): + os.makedirs(path) + stats = self.stats() + if stats: + p = os.path.join(path, "rdd_%d.pstats" % id) + stats.dump_stats(p) + + +class PStatsParam(AccumulatorParam): + """PStatsParam is used to merge pstats.Stats""" + + @staticmethod + def zero(value): + return None + + @staticmethod + def addInPlace(value1, value2): + if value1 is None: + return value2 + value1.add(value2) + return value1 + + +class BasicProfiler(Profiler): + """ + BasicProfiler is the default profiler, which is implemented based on + cProfile and Accumulator + """ + def __init__(self, ctx): + Profiler.__init__(self, ctx) + # Creates a new accumulator for combining the profiles of different + # partitions of a stage + self._accumulator = ctx.accumulator(None, PStatsParam) + + def profile(self, func): + """ Runs and profiles the method to_profile passed in. A profile object is returned. """ + pr = cProfile.Profile() + pr.runcall(func) + st = pstats.Stats(pr) + st.stream = None # make it picklable + st.strip_dirs() + + # Adds a new profile to the existing accumulated value + self._accumulator.add(st) + + def stats(self): + return self._accumulator.value + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 014c0aa889..b6dd5a3bf0 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -31,7 +31,6 @@ import bisect import random from math import sqrt, log, isinf, isnan -from pyspark.accumulators import PStatsParam from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ PickleSerializer, pack_long, AutoBatchedSerializer @@ -2132,9 +2131,13 @@ class PipelinedRDD(RDD): return self._jrdd_val if self._bypass_serializer: self._jrdd_deserializer = NoOpSerializer() - enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true" - profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None - command = (self.func, profileStats, self._prev_jrdd_deserializer, + + if self.ctx.profiler_collector: + profiler = self.ctx.profiler_collector.new_profiler(self.ctx) + else: + profiler = None + + command = (self.func, profiler, self._prev_jrdd_deserializer, self._jrdd_deserializer) # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() @@ -2157,9 +2160,9 @@ class PipelinedRDD(RDD): broadcast_vars, self.ctx._javaAccumulator) self._jrdd_val = python_rdd.asJavaRDD() - if enable_profile: + if profiler: self._id = self._jrdd_val.id() - self.ctx._add_profile(self._id, profileStats) + self.ctx.profiler_collector.add_profiler(self._id, profiler) return self._jrdd_val def id(self): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index e694ffcff5..081a77fbb0 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -53,6 +53,7 @@ from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, External from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ UserDefinedType, DoubleType from pyspark import shuffle +from pyspark.profiler import BasicProfiler _have_scipy = False _have_numpy = False @@ -743,16 +744,12 @@ class ProfilerTests(PySparkTestCase): self.sc = SparkContext('local[4]', class_name, conf=conf) def test_profiler(self): + self.do_computation() - def heavy_foo(x): - for i in range(1 << 20): - x = 1 - rdd = self.sc.parallelize(range(100)) - rdd.foreach(heavy_foo) - profiles = self.sc._profile_stats - self.assertEqual(1, len(profiles)) - id, acc, _ = profiles[0] - stats = acc.value + profilers = self.sc.profiler_collector.profilers + self.assertEqual(1, len(profilers)) + id, profiler, _ = profilers[0] + stats = profiler.stats() self.assertTrue(stats is not None) width, stat_list = stats.get_print_list([]) func_names = [func_name for fname, n, func_name in stat_list] @@ -763,6 +760,31 @@ class ProfilerTests(PySparkTestCase): self.sc.dump_profiles(d) self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) + def test_custom_profiler(self): + class TestCustomProfiler(BasicProfiler): + def show(self, id): + self.result = "Custom formatting" + + self.sc.profiler_collector.profiler_cls = TestCustomProfiler + + self.do_computation() + + profilers = self.sc.profiler_collector.profilers + self.assertEqual(1, len(profilers)) + _, profiler, _ = profilers[0] + self.assertTrue(isinstance(profiler, TestCustomProfiler)) + + self.sc.show_profiles() + self.assertEqual("Custom formatting", profiler.result) + + def do_computation(self): + def heavy_foo(x): + for i in range(1 << 20): + x = 1 + + rdd = self.sc.parallelize(range(100)) + rdd.foreach(heavy_foo) + class ExamplePointUDT(UserDefinedType): """ diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 7e5343c973..8a93c320ec 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,8 +23,6 @@ import sys import time import socket import traceback -import cProfile -import pstats from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry @@ -90,19 +88,15 @@ def main(infile, outfile): command = pickleSer._read_with_length(infile) if isinstance(command, Broadcast): command = pickleSer.loads(command.value) - (func, stats, deserializer, serializer) = command + (func, profiler, deserializer, serializer) = command init_time = time.time() def process(): iterator = deserializer.load_stream(infile) serializer.dump_stream(func(split_index, iterator), outfile) - if stats: - p = cProfile.Profile() - p.runcall(process) - st = pstats.Stats(p) - st.stream = None # make it picklable - stats.add(st.strip_dirs()) + if profiler: + profiler.profile(process) else: process() except Exception: diff --git a/python/run-tests b/python/run-tests index 9ee19ed6e6..53c34557d9 100755 --- a/python/run-tests +++ b/python/run-tests @@ -57,6 +57,7 @@ function run_core_tests() { PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py" run_test "pyspark/serializers.py" + run_test "pyspark/profiler.py" run_test "pyspark/shuffle.py" run_test "pyspark/tests.py" } |