From 60050f42885582a699fc7a6fa0529964162bb8a3 Mon Sep 17 00:00:00 2001 From: Aaron Staple Date: Mon, 15 Sep 2014 19:28:17 -0700 Subject: [SPARK-1087] Move python traceback utilities into new traceback_utils.py file. Also made some cosmetic cleanups. Author: Aaron Staple Closes #2385 from staple/SPARK-1087 and squashes the following commits: 7b3bb13 [Aaron Staple] Address review comments, cosmetic cleanups. 10ba6e1 [Aaron Staple] [SPARK-1087] Move python traceback utilities into new traceback_utils.py file. --- python/pyspark/context.py | 8 +--- python/pyspark/rdd.py | 58 ++--------------------------- python/pyspark/traceback_utils.py | 78 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 83 insertions(+), 61 deletions(-) create mode 100644 python/pyspark/traceback_utils.py (limited to 'python') diff --git a/python/pyspark/context.py b/python/pyspark/context.py index ea28e8cd8c..a33aae87f6 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -20,7 +20,6 @@ import shutil import sys from threading import Lock from tempfile import NamedTemporaryFile -from collections import namedtuple from pyspark import accumulators from pyspark.accumulators import Accumulator @@ -33,6 +32,7 @@ from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deseria from pyspark.storagelevel import StorageLevel from pyspark import rdd from pyspark.rdd import RDD +from pyspark.traceback_utils import CallSite, first_spark_call from py4j.java_collections import ListConverter @@ -99,11 +99,7 @@ class SparkContext(object): ... ValueError:... """ - if rdd._extract_concise_traceback() is not None: - self._callsite = rdd._extract_concise_traceback() - else: - tempNamedTuple = namedtuple("Callsite", "function file linenum") - self._callsite = tempNamedTuple(function=None, file=None, linenum=None) + self._callsite = first_spark_call() or CallSite(None, None, None) SparkContext._ensure_initialized(self, gateway=gateway) try: self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer, diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 6ad5ab2a2d..21f182b0ff 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -18,13 +18,11 @@ from base64 import standard_b64encode as b64enc import copy from collections import defaultdict -from collections import namedtuple from itertools import chain, ifilter, imap import operator import os import sys import shlex -import traceback from subprocess import Popen, PIPE from tempfile import NamedTemporaryFile from threading import Thread @@ -45,6 +43,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ get_used_memory, ExternalSorter +from pyspark.traceback_utils import SCCallSiteSync from py4j.java_collections import ListConverter, MapConverter @@ -81,57 +80,6 @@ def portable_hash(x): return hash(x) -def _extract_concise_traceback(): - """ - This function returns the traceback info for a callsite, returns a dict - with function name, file name and line number - """ - tb = traceback.extract_stack() - callsite = namedtuple("Callsite", "function file linenum") - if len(tb) == 0: - return None - file, line, module, what = tb[len(tb) - 1] - sparkpath = os.path.dirname(file) - first_spark_frame = len(tb) - 1 - for i in range(0, len(tb)): - file, line, fun, what = tb[i] - if file.startswith(sparkpath): - first_spark_frame = i - break - if first_spark_frame == 0: - file, line, fun, what = tb[0] - return callsite(function=fun, file=file, linenum=line) - sfile, sline, sfun, swhat = tb[first_spark_frame] - ufile, uline, ufun, uwhat = tb[first_spark_frame - 1] - return callsite(function=sfun, file=ufile, linenum=uline) - -_spark_stack_depth = 0 - - -class _JavaStackTrace(object): - - def __init__(self, sc): - tb = _extract_concise_traceback() - if tb is not None: - self._traceback = "%s at %s:%s" % ( - tb.function, tb.file, tb.linenum) - else: - self._traceback = "Error! Could not extract traceback info" - self._context = sc - - def __enter__(self): - global _spark_stack_depth - if _spark_stack_depth == 0: - self._context._jsc.setCallSite(self._traceback) - _spark_stack_depth += 1 - - def __exit__(self, type, value, tb): - global _spark_stack_depth - _spark_stack_depth -= 1 - if _spark_stack_depth == 0: - self._context._jsc.setCallSite(None) - - class BoundedFloat(float): """ Bounded value is generated by approximate job, with confidence and low @@ -704,7 +652,7 @@ class RDD(object): """ Return a list that contains all of the elements in this RDD. """ - with _JavaStackTrace(self.context) as st: + with SCCallSiteSync(self.context) as css: bytesInJava = self._jrdd.collect().iterator() return list(self._collect_iterator_through_file(bytesInJava)) @@ -1515,7 +1463,7 @@ class RDD(object): keyed = self.mapPartitionsWithIndex(add_shuffle_key) keyed._bypass_serializer = True - with _JavaStackTrace(self.context) as st: + with SCCallSiteSync(self.context) as css: pairRDD = self.ctx._jvm.PairwiseRDD( keyed._jrdd.rdd()).asJavaPairRDD() partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, diff --git a/python/pyspark/traceback_utils.py b/python/pyspark/traceback_utils.py new file mode 100644 index 0000000000..bb8646df2b --- /dev/null +++ b/python/pyspark/traceback_utils.py @@ -0,0 +1,78 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import namedtuple +import os +import traceback + + +CallSite = namedtuple("CallSite", "function file linenum") + + +def first_spark_call(): + """ + Return a CallSite representing the first Spark call in the current call stack. + """ + tb = traceback.extract_stack() + if len(tb) == 0: + return None + file, line, module, what = tb[len(tb) - 1] + sparkpath = os.path.dirname(file) + first_spark_frame = len(tb) - 1 + for i in range(0, len(tb)): + file, line, fun, what = tb[i] + if file.startswith(sparkpath): + first_spark_frame = i + break + if first_spark_frame == 0: + file, line, fun, what = tb[0] + return CallSite(function=fun, file=file, linenum=line) + sfile, sline, sfun, swhat = tb[first_spark_frame] + ufile, uline, ufun, uwhat = tb[first_spark_frame - 1] + return CallSite(function=sfun, file=ufile, linenum=uline) + + +class SCCallSiteSync(object): + """ + Helper for setting the spark context call site. + + Example usage: + from pyspark.context import SCCallSiteSync + with SCCallSiteSync() as css: + + """ + + _spark_stack_depth = 0 + + def __init__(self, sc): + call_site = first_spark_call() + if call_site is not None: + self._call_site = "%s at %s:%s" % ( + call_site.function, call_site.file, call_site.linenum) + else: + self._call_site = "Error! Could not extract traceback info" + self._context = sc + + def __enter__(self): + if SCCallSiteSync._spark_stack_depth == 0: + self._context._jsc.setCallSite(self._call_site) + SCCallSiteSync._spark_stack_depth += 1 + + def __exit__(self, type, value, tb): + SCCallSiteSync._spark_stack_depth -= 1 + if SCCallSiteSync._spark_stack_depth == 0: + self._context._jsc.setCallSite(None) -- cgit v1.2.3