aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--docs/configuration.md3
-rw-r--r--python/pyspark/__init__.py2
-rw-r--r--python/pyspark/accumulators.py15
-rw-r--r--python/pyspark/context.py46
-rw-r--r--python/pyspark/profiler.py172
-rw-r--r--python/pyspark/rdd.py15
-rw-r--r--python/pyspark/tests.py40
-rw-r--r--python/pyspark/worker.py12
-rwxr-xr-xpython/run-tests1
9 files changed, 235 insertions, 71 deletions
diff --git a/docs/configuration.md b/docs/configuration.md
index 7c5b6d011c..e4e4b8d516 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -311,6 +311,9 @@ Apart from these, the following properties are also available, and may be useful
or it will be displayed before the driver exiting. It also can be dumped into disk by
`sc.dump_profiles(path)`. If some of the profile results had been displayed maually,
they will not be displayed automatically before driver exiting.
+
+ By default the `pyspark.profiler.BasicProfiler` will be used, but this can be overridden by
+ passing a profiler class in as a parameter to the `SparkContext` constructor.
</td>
</tr>
<tr>
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 9556e4718e..d3efcdf221 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -45,6 +45,7 @@ from pyspark.storagelevel import StorageLevel
from pyspark.accumulators import Accumulator, AccumulatorParam
from pyspark.broadcast import Broadcast
from pyspark.serializers import MarshalSerializer, PickleSerializer
+from pyspark.profiler import Profiler, BasicProfiler
# for back compatibility
from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row
@@ -52,4 +53,5 @@ from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row
__all__ = [
"SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
"Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
+ "Profiler", "BasicProfiler",
]
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index b8cdbbe3cf..ccbca67656 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -215,21 +215,6 @@ FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
-class PStatsParam(AccumulatorParam):
- """PStatsParam is used to merge pstats.Stats"""
-
- @staticmethod
- def zero(value):
- return None
-
- @staticmethod
- def addInPlace(value1, value2):
- if value1 is None:
- return value2
- value1.add(value2)
- return value1
-
-
class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
"""
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 568e21f380..c0dec16ac1 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -20,7 +20,6 @@ import shutil
import sys
from threading import Lock
from tempfile import NamedTemporaryFile
-import atexit
from pyspark import accumulators
from pyspark.accumulators import Accumulator
@@ -33,6 +32,7 @@ from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deseria
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
from pyspark.traceback_utils import CallSite, first_spark_call
+from pyspark.profiler import ProfilerCollector, BasicProfiler
from py4j.java_collections import ListConverter
@@ -66,7 +66,7 @@ class SparkContext(object):
def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
- gateway=None, jsc=None):
+ gateway=None, jsc=None, profiler_cls=BasicProfiler):
"""
Create a new SparkContext. At least the master and app name should be set,
either through the named parameters here or through C{conf}.
@@ -88,6 +88,9 @@ class SparkContext(object):
:param conf: A L{SparkConf} object setting Spark properties.
:param gateway: Use an existing gateway and JVM, otherwise a new JVM
will be instantiated.
+ :param jsc: The JavaSparkContext instance (optional).
+ :param profiler_cls: A class of custom Profiler used to do profiling
+ (default is pyspark.profiler.BasicProfiler).
>>> from pyspark.context import SparkContext
@@ -102,14 +105,14 @@ class SparkContext(object):
SparkContext._ensure_initialized(self, gateway=gateway)
try:
self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
- conf, jsc)
+ conf, jsc, profiler_cls)
except:
# If an error occurs, clean up in order to allow future SparkContext creation:
self.stop()
raise
def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
- conf, jsc):
+ conf, jsc, profiler_cls):
self.environment = environment or {}
self._conf = conf or SparkConf(_jvm=self._jvm)
self._batchSize = batchSize # -1 represents an unlimited batch size
@@ -192,7 +195,11 @@ class SparkContext(object):
self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()
# profiling stats collected for each PythonRDD
- self._profile_stats = []
+ if self._conf.get("spark.python.profile", "false") == "true":
+ dump_path = self._conf.get("spark.python.profile.dump", None)
+ self.profiler_collector = ProfilerCollector(profiler_cls, dump_path)
+ else:
+ self.profiler_collector = None
def _initialize_context(self, jconf):
"""
@@ -826,39 +833,14 @@ class SparkContext(object):
it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
return list(mappedRDD._collect_iterator_through_file(it))
- def _add_profile(self, id, profileAcc):
- if not self._profile_stats:
- dump_path = self._conf.get("spark.python.profile.dump")
- if dump_path:
- atexit.register(self.dump_profiles, dump_path)
- else:
- atexit.register(self.show_profiles)
-
- self._profile_stats.append([id, profileAcc, False])
-
def show_profiles(self):
""" Print the profile stats to stdout """
- for i, (id, acc, showed) in enumerate(self._profile_stats):
- stats = acc.value
- if not showed and stats:
- print "=" * 60
- print "Profile of RDD<id=%d>" % id
- print "=" * 60
- stats.sort_stats("time", "cumulative").print_stats()
- # mark it as showed
- self._profile_stats[i][2] = True
+ self.profiler_collector.show_profiles()
def dump_profiles(self, path):
""" Dump the profile stats into directory `path`
"""
- if not os.path.exists(path):
- os.makedirs(path)
- for id, acc, _ in self._profile_stats:
- stats = acc.value
- if stats:
- p = os.path.join(path, "rdd_%d.pstats" % id)
- stats.dump_stats(p)
- self._profile_stats = []
+ self.profiler_collector.dump_profiles(path)
def _test():
diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py
new file mode 100644
index 0000000000..4408996db0
--- /dev/null
+++ b/python/pyspark/profiler.py
@@ -0,0 +1,172 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import cProfile
+import pstats
+import os
+import atexit
+
+from pyspark.accumulators import AccumulatorParam
+
+
+class ProfilerCollector(object):
+ """
+ This class keeps track of different profilers on a per
+ stage basis. Also this is used to create new profilers for
+ the different stages.
+ """
+
+ def __init__(self, profiler_cls, dump_path=None):
+ self.profiler_cls = profiler_cls
+ self.profile_dump_path = dump_path
+ self.profilers = []
+
+ def new_profiler(self, ctx):
+ """ Create a new profiler using class `profiler_cls` """
+ return self.profiler_cls(ctx)
+
+ def add_profiler(self, id, profiler):
+ """ Add a profiler for RDD `id` """
+ if not self.profilers:
+ if self.profile_dump_path:
+ atexit.register(self.dump_profiles, self.profile_dump_path)
+ else:
+ atexit.register(self.show_profiles)
+
+ self.profilers.append([id, profiler, False])
+
+ def dump_profiles(self, path):
+ """ Dump the profile stats into directory `path` """
+ for id, profiler, _ in self.profilers:
+ profiler.dump(id, path)
+ self.profilers = []
+
+ def show_profiles(self):
+ """ Print the profile stats to stdout """
+ for i, (id, profiler, showed) in enumerate(self.profilers):
+ if not showed and profiler:
+ profiler.show(id)
+ # mark it as showed
+ self.profilers[i][2] = True
+
+
+class Profiler(object):
+ """
+ .. note:: DeveloperApi
+
+ PySpark supports custom profilers, this is to allow for different profilers to
+ be used as well as outputting to different formats than what is provided in the
+ BasicProfiler.
+
+ A custom profiler has to define or inherit the following methods:
+ profile - will produce a system profile of some sort.
+ stats - return the collected stats.
+ dump - dumps the profiles to a path
+ add - adds a profile to the existing accumulated profile
+
+ The profiler class is chosen when creating a SparkContext
+
+ >>> from pyspark import SparkConf, SparkContext
+ >>> from pyspark import BasicProfiler
+ >>> class MyCustomProfiler(BasicProfiler):
+ ... def show(self, id):
+ ... print "My custom profiles for RDD:%s" % id
+ ...
+ >>> conf = SparkConf().set("spark.python.profile", "true")
+ >>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler)
+ >>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
+ [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
+ >>> sc.show_profiles()
+ My custom profiles for RDD:1
+ My custom profiles for RDD:2
+ >>> sc.stop()
+ """
+
+ def __init__(self, ctx):
+ pass
+
+ def profile(self, func):
+ """ Do profiling on the function `func`"""
+ raise NotImplemented
+
+ def stats(self):
+ """ Return the collected profiling stats (pstats.Stats)"""
+ raise NotImplemented
+
+ def show(self, id):
+ """ Print the profile stats to stdout, id is the RDD id """
+ stats = self.stats()
+ if stats:
+ print "=" * 60
+ print "Profile of RDD<id=%d>" % id
+ print "=" * 60
+ stats.sort_stats("time", "cumulative").print_stats()
+
+ def dump(self, id, path):
+ """ Dump the profile into path, id is the RDD id """
+ if not os.path.exists(path):
+ os.makedirs(path)
+ stats = self.stats()
+ if stats:
+ p = os.path.join(path, "rdd_%d.pstats" % id)
+ stats.dump_stats(p)
+
+
+class PStatsParam(AccumulatorParam):
+ """PStatsParam is used to merge pstats.Stats"""
+
+ @staticmethod
+ def zero(value):
+ return None
+
+ @staticmethod
+ def addInPlace(value1, value2):
+ if value1 is None:
+ return value2
+ value1.add(value2)
+ return value1
+
+
+class BasicProfiler(Profiler):
+ """
+ BasicProfiler is the default profiler, which is implemented based on
+ cProfile and Accumulator
+ """
+ def __init__(self, ctx):
+ Profiler.__init__(self, ctx)
+ # Creates a new accumulator for combining the profiles of different
+ # partitions of a stage
+ self._accumulator = ctx.accumulator(None, PStatsParam)
+
+ def profile(self, func):
+ """ Runs and profiles the method to_profile passed in. A profile object is returned. """
+ pr = cProfile.Profile()
+ pr.runcall(func)
+ st = pstats.Stats(pr)
+ st.stream = None # make it picklable
+ st.strip_dirs()
+
+ # Adds a new profile to the existing accumulated value
+ self._accumulator.add(st)
+
+ def stats(self):
+ return self._accumulator.value
+
+
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod()
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 014c0aa889..b6dd5a3bf0 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -31,7 +31,6 @@ import bisect
import random
from math import sqrt, log, isinf, isnan
-from pyspark.accumulators import PStatsParam
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
PickleSerializer, pack_long, AutoBatchedSerializer
@@ -2132,9 +2131,13 @@ class PipelinedRDD(RDD):
return self._jrdd_val
if self._bypass_serializer:
self._jrdd_deserializer = NoOpSerializer()
- enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true"
- profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None
- command = (self.func, profileStats, self._prev_jrdd_deserializer,
+
+ if self.ctx.profiler_collector:
+ profiler = self.ctx.profiler_collector.new_profiler(self.ctx)
+ else:
+ profiler = None
+
+ command = (self.func, profiler, self._prev_jrdd_deserializer,
self._jrdd_deserializer)
# the serialized command will be compressed by broadcast
ser = CloudPickleSerializer()
@@ -2157,9 +2160,9 @@ class PipelinedRDD(RDD):
broadcast_vars, self.ctx._javaAccumulator)
self._jrdd_val = python_rdd.asJavaRDD()
- if enable_profile:
+ if profiler:
self._id = self._jrdd_val.id()
- self.ctx._add_profile(self._id, profileStats)
+ self.ctx.profiler_collector.add_profiler(self._id, profiler)
return self._jrdd_val
def id(self):
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index e694ffcff5..081a77fbb0 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -53,6 +53,7 @@ from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, External
from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
UserDefinedType, DoubleType
from pyspark import shuffle
+from pyspark.profiler import BasicProfiler
_have_scipy = False
_have_numpy = False
@@ -743,16 +744,12 @@ class ProfilerTests(PySparkTestCase):
self.sc = SparkContext('local[4]', class_name, conf=conf)
def test_profiler(self):
+ self.do_computation()
- def heavy_foo(x):
- for i in range(1 << 20):
- x = 1
- rdd = self.sc.parallelize(range(100))
- rdd.foreach(heavy_foo)
- profiles = self.sc._profile_stats
- self.assertEqual(1, len(profiles))
- id, acc, _ = profiles[0]
- stats = acc.value
+ profilers = self.sc.profiler_collector.profilers
+ self.assertEqual(1, len(profilers))
+ id, profiler, _ = profilers[0]
+ stats = profiler.stats()
self.assertTrue(stats is not None)
width, stat_list = stats.get_print_list([])
func_names = [func_name for fname, n, func_name in stat_list]
@@ -763,6 +760,31 @@ class ProfilerTests(PySparkTestCase):
self.sc.dump_profiles(d)
self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))
+ def test_custom_profiler(self):
+ class TestCustomProfiler(BasicProfiler):
+ def show(self, id):
+ self.result = "Custom formatting"
+
+ self.sc.profiler_collector.profiler_cls = TestCustomProfiler
+
+ self.do_computation()
+
+ profilers = self.sc.profiler_collector.profilers
+ self.assertEqual(1, len(profilers))
+ _, profiler, _ = profilers[0]
+ self.assertTrue(isinstance(profiler, TestCustomProfiler))
+
+ self.sc.show_profiles()
+ self.assertEqual("Custom formatting", profiler.result)
+
+ def do_computation(self):
+ def heavy_foo(x):
+ for i in range(1 << 20):
+ x = 1
+
+ rdd = self.sc.parallelize(range(100))
+ rdd.foreach(heavy_foo)
+
class ExamplePointUDT(UserDefinedType):
"""
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 7e5343c973..8a93c320ec 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -23,8 +23,6 @@ import sys
import time
import socket
import traceback
-import cProfile
-import pstats
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
@@ -90,19 +88,15 @@ def main(infile, outfile):
command = pickleSer._read_with_length(infile)
if isinstance(command, Broadcast):
command = pickleSer.loads(command.value)
- (func, stats, deserializer, serializer) = command
+ (func, profiler, deserializer, serializer) = command
init_time = time.time()
def process():
iterator = deserializer.load_stream(infile)
serializer.dump_stream(func(split_index, iterator), outfile)
- if stats:
- p = cProfile.Profile()
- p.runcall(process)
- st = pstats.Stats(p)
- st.stream = None # make it picklable
- stats.add(st.strip_dirs())
+ if profiler:
+ profiler.profile(process)
else:
process()
except Exception:
diff --git a/python/run-tests b/python/run-tests
index 9ee19ed6e6..53c34557d9 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -57,6 +57,7 @@ function run_core_tests() {
PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py"
PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py"
run_test "pyspark/serializers.py"
+ run_test "pyspark/profiler.py"
run_test "pyspark/shuffle.py"
run_test "pyspark/tests.py"
}