author    Aaron Staple <aaron.staple@gmail.com>    2014-09-15 19:28:17 -0700
committer Josh Rosen <joshrosen@apache.org>    2014-09-15 19:28:17 -0700
commit    60050f42885582a699fc7a6fa0529964162bb8a3 (patch)
tree      4014b35f39ee99dcd3d26b288dba0e7555a97ad6 /python
parent    da33acb8b681eca5e787d546fe922af76a151398 (diff)
[SPARK-1087] Move python traceback utilities into new traceback_utils.py file.
Also made some cosmetic cleanups.

Author: Aaron Staple <aaron.staple@gmail.com>

Closes #2385 from staple/SPARK-1087 and squashes the following commits:

7b3bb13 [Aaron Staple] Address review comments, cosmetic cleanups.
10ba6e1 [Aaron Staple] [SPARK-1087] Move python traceback utilities into new traceback_utils.py file.
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/context.py           8
-rw-r--r--  python/pyspark/rdd.py              58
-rw-r--r--  python/pyspark/traceback_utils.py  78
3 files changed, 83 insertions, 61 deletions
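
Taken together, the diff replaces private helpers in rdd.py with a small dedicated module. A minimal sketch of the new import surface this commit introduces (names are taken from the diff below; SparkContext setup is elided):

    from pyspark.traceback_utils import CallSite, first_spark_call, SCCallSiteSync

    # Fall back to an empty CallSite when no Spark frame is on the stack.
    callsite = first_spark_call() or CallSite(None, None, None)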
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index ea28e8cd8c..a33aae87f6 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -20,7 +20,6 @@ import shutil
import sys
from threading import Lock
from tempfile import NamedTemporaryFile
-from collections import namedtuple
from pyspark import accumulators
from pyspark.accumulators import Accumulator
@@ -33,6 +32,7 @@ from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark import rdd
from pyspark.rdd import RDD
+from pyspark.traceback_utils import CallSite, first_spark_call
from py4j.java_collections import ListConverter
@@ -99,11 +99,7 @@ class SparkContext(object):
...
ValueError:...
"""
- if rdd._extract_concise_traceback() is not None:
- self._callsite = rdd._extract_concise_traceback()
- else:
- tempNamedTuple = namedtuple("Callsite", "function file linenum")
- self._callsite = tempNamedTuple(function=None, file=None, linenum=None)
+ self._callsite = first_spark_call() or CallSite(None, None, None)
SparkContext._ensure_initialized(self, gateway=gateway)
try:
self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
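
The one-line replacement works because first_spark_call() returns either a CallSite namedtuple (truthy) or None (falsy). A small self-contained sketch of that fallback idiom, with a hypothetical lookup() standing in for first_spark_call():

    from collections import namedtuple

    CallSite = namedtuple("CallSite", "function file linenum")

    def lookup():
        # hypothetical stand-in for first_spark_call(); pretend no
        # Spark frame was found on the stack
        return None

    # None is falsy, so the right-hand CallSite of Nones is used instead;
    # a real CallSite is truthy and would be kept as-is
    callsite = lookup() or CallSite(None, None, None)
    assert callsite == CallSite(None, None, None)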
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 6ad5ab2a2d..21f182b0ff 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -18,13 +18,11 @@
from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
-from collections import namedtuple
from itertools import chain, ifilter, imap
import operator
import os
import sys
import shlex
-import traceback
from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
from threading import Thread
@@ -45,6 +43,7 @@ from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
get_used_memory, ExternalSorter
+from pyspark.traceback_utils import SCCallSiteSync
from py4j.java_collections import ListConverter, MapConverter
@@ -81,57 +80,6 @@ def portable_hash(x):
return hash(x)
-def _extract_concise_traceback():
- """
- This function returns the traceback info for a callsite, returns a dict
- with function name, file name and line number
- """
- tb = traceback.extract_stack()
- callsite = namedtuple("Callsite", "function file linenum")
- if len(tb) == 0:
- return None
- file, line, module, what = tb[len(tb) - 1]
- sparkpath = os.path.dirname(file)
- first_spark_frame = len(tb) - 1
- for i in range(0, len(tb)):
- file, line, fun, what = tb[i]
- if file.startswith(sparkpath):
- first_spark_frame = i
- break
- if first_spark_frame == 0:
- file, line, fun, what = tb[0]
- return callsite(function=fun, file=file, linenum=line)
- sfile, sline, sfun, swhat = tb[first_spark_frame]
- ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]
- return callsite(function=sfun, file=ufile, linenum=uline)
-
-_spark_stack_depth = 0
-
-
-class _JavaStackTrace(object):
-
- def __init__(self, sc):
- tb = _extract_concise_traceback()
- if tb is not None:
- self._traceback = "%s at %s:%s" % (
- tb.function, tb.file, tb.linenum)
- else:
- self._traceback = "Error! Could not extract traceback info"
- self._context = sc
-
- def __enter__(self):
- global _spark_stack_depth
- if _spark_stack_depth == 0:
- self._context._jsc.setCallSite(self._traceback)
- _spark_stack_depth += 1
-
- def __exit__(self, type, value, tb):
- global _spark_stack_depth
- _spark_stack_depth -= 1
- if _spark_stack_depth == 0:
- self._context._jsc.setCallSite(None)
-
-
class BoundedFloat(float):
"""
Bounded value is generated by approximate job, with confidence and low
@@ -704,7 +652,7 @@ class RDD(object):
"""
Return a list that contains all of the elements in this RDD.
"""
- with _JavaStackTrace(self.context) as st:
+ with SCCallSiteSync(self.context) as css:
bytesInJava = self._jrdd.collect().iterator()
return list(self._collect_iterator_through_file(bytesInJava))
@@ -1515,7 +1463,7 @@ class RDD(object):
keyed = self.mapPartitionsWithIndex(add_shuffle_key)
keyed._bypass_serializer = True
- with _JavaStackTrace(self.context) as st:
+ with SCCallSiteSync(self.context) as css:
pairRDD = self.ctx._jvm.PairwiseRDD(
keyed._jrdd.rdd()).asJavaPairRDD()
partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
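
Both call sites now route through SCCallSiteSync, whose class-level depth counter ensures that only the outermost with block touches the JVM-side call site. A hedged standalone sketch of that pattern, with a hypothetical backend object in place of the Py4J SparkContext:

    class CallSiteGuard(object):
        # hypothetical analogue of SCCallSiteSync, decoupled from Py4J
        _depth = 0

        def __init__(self, backend, site):
            self._backend = backend
            self._site = site

        def __enter__(self):
            if CallSiteGuard._depth == 0:
                # only the outermost guard publishes the call site
                self._backend.set_call_site(self._site)
            CallSiteGuard._depth += 1

        def __exit__(self, exc_type, exc_value, tb):
            CallSiteGuard._depth -= 1
            if CallSiteGuard._depth == 0:
                # ...and only the outermost guard clears it again
                self._backend.set_call_site(None)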
diff --git a/python/pyspark/traceback_utils.py b/python/pyspark/traceback_utils.py
new file mode 100644
index 0000000000..bb8646df2b
--- /dev/null
+++ b/python/pyspark/traceback_utils.py
@@ -0,0 +1,78 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from collections import namedtuple
+import os
+import traceback
+
+
+CallSite = namedtuple("CallSite", "function file linenum")
+
+
+def first_spark_call():
+ """
+ Return a CallSite representing the first Spark call in the current call stack.
+ """
+ tb = traceback.extract_stack()
+ if len(tb) == 0:
+ return None
+ file, line, module, what = tb[len(tb) - 1]
+ sparkpath = os.path.dirname(file)
+ first_spark_frame = len(tb) - 1
+ for i in range(0, len(tb)):
+ file, line, fun, what = tb[i]
+ if file.startswith(sparkpath):
+ first_spark_frame = i
+ break
+ if first_spark_frame == 0:
+ file, line, fun, what = tb[0]
+ return CallSite(function=fun, file=file, linenum=line)
+ sfile, sline, sfun, swhat = tb[first_spark_frame]
+ ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]
+ return CallSite(function=sfun, file=ufile, linenum=uline)
+
+
+class SCCallSiteSync(object):
+ """
+ Helper for setting the spark context call site.
+
+ Example usage:
+ from pyspark.context import SCCallSiteSync
+ with SCCallSiteSync(<relevant SparkContext>) as css:
+ <a Spark call>
+ """
+
+ _spark_stack_depth = 0
+
+ def __init__(self, sc):
+ call_site = first_spark_call()
+ if call_site is not None:
+ self._call_site = "%s at %s:%s" % (
+ call_site.function, call_site.file, call_site.linenum)
+ else:
+ self._call_site = "Error! Could not extract traceback info"
+ self._context = sc
+
+ def __enter__(self):
+ if SCCallSiteSync._spark_stack_depth == 0:
+ self._context._jsc.setCallSite(self._call_site)
+ SCCallSiteSync._spark_stack_depth += 1
+
+ def __exit__(self, type, value, tb):
+ SCCallSiteSync._spark_stack_depth -= 1
+ if SCCallSiteSync._spark_stack_depth == 0:
+ self._context._jsc.setCallSite(None)
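
To make the scan in first_spark_call() concrete, here is a self-contained simulation over a hand-built frame list in traceback.extract_stack() order (oldest frame first). All paths and line numbers are invented for illustration:

    import os

    # each tuple mirrors (file, line, function, text) from traceback.extract_stack()
    frames = [
        ("/home/alice/app.py", 10, "<module>", "rdd.collect()"),
        ("/opt/spark/python/pyspark/rdd.py", 652, "collect", "..."),
        ("/opt/spark/python/pyspark/traceback_utils.py", 25, "first_spark_call", "..."),
    ]

    # the newest frame lives in the pyspark source tree, so its directory
    # identifies which files count as "Spark" frames
    sparkpath = os.path.dirname(frames[-1][0])
    first = next(i for i, f in enumerate(frames) if f[0].startswith(sparkpath))

    # the result pairs the Spark-side function name with the user-side location
    site = (frames[first][2], frames[first - 1][0], frames[first - 1][1])
    print(site)  # -> ('collect', '/home/alice/app.py', 10)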