about summary refs log tree commit diff
path: root/python/pyspark/worker.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/worker.py')
-rw-r--r-- python/pyspark/worker.py 14
1 files changed, 10 insertions, 4 deletions
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 61b8a74d06..252176ac65 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -23,16 +23,14 @@ import sys
import time
import socket
import traceback
-# CloudPickler needs to be imported so that depicklers are registered using the
-# copy_reg module.
+
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
-from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
CompressedSerializer
-
+from pyspark import shuffle
pickleSer = PickleSerializer()
utf8_deserializer = UTF8Deserializer()
@@ -52,6 +50,11 @@ def main(infile, outfile):
if split_index == -1: # for unit tests
return
+ # initialize global state
+ shuffle.MemoryBytesSpilled = 0
+ shuffle.DiskBytesSpilled = 0
+ _accumulatorRegistry.clear()
+
# fetch name of workdir
spark_files_dir = utf8_deserializer.loads(infile)
SparkFiles._root_directory = spark_files_dir
@@ -97,6 +100,9 @@ def main(infile, outfile):
exit(-1)
finish_time = time.time()
report_times(outfile, boot_time, init_time, finish_time)
+ write_long(shuffle.MemoryBytesSpilled, outfile)
+ write_long(shuffle.DiskBytesSpilled, outfile)
+
# Mark the beginning of the accumulators section of the output
write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
write_int(len(_accumulatorRegistry), outfile)