-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala |  4
-rw-r--r--  python/pyspark/shuffle.py                                       | 19
-rw-r--r--  python/pyspark/tests.py                                         | 15
-rw-r--r--  python/pyspark/worker.py                                        | 14
4 files changed, 38 insertions(+), 14 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index ca8eef5f99..d5002fa029 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -124,6 +124,10 @@ private[spark] class PythonRDD(
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
init, finish))
+ val memoryBytesSpilled = stream.readLong()
+ val diskBytesSpilled = stream.readLong()
+ context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
+ context.taskMetrics.diskBytesSpilled += diskBytesSpilled
read()
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
// Signals that an exception has been thrown in python
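
The two readLong() calls above consume values that the Python worker writes right after its timing report and just before the END_OF_DATA_SECTION marker (see the worker.py hunk below), so the read order here has to mirror the write order there. A minimal sketch of that worker-side framing, using the write_long/write_int helpers from pyspark.serializers; report_spill_metrics is a hypothetical name used only for illustration:

    from pyspark.serializers import write_long, write_int, SpecialLengths
    from pyspark import shuffle

    def report_spill_metrics(outfile):
        # Memory first, then disk: the order must match the two
        # stream.readLong() calls added above in PythonRDD.scala.
        write_long(shuffle.MemoryBytesSpilled, outfile)
        write_long(shuffle.DiskBytesSpilled, outfile)
        # worker.py then marks the end of the task's data section.
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
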
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 49829f5280..ce597cbe91 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -68,6 +68,11 @@ def _get_local_dirs(sub):
return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs]
+# global stats
+MemoryBytesSpilled = 0L
+DiskBytesSpilled = 0L
+
+
class Aggregator(object):
"""
@@ -313,10 +318,12 @@ class ExternalMerger(Merger):
It will dump the data in batch for better performance.
"""
+ global MemoryBytesSpilled, DiskBytesSpilled
path = self._get_spill_dir(self.spills)
if not os.path.exists(path):
os.makedirs(path)
+ used_memory = get_used_memory()
if not self.pdata:
# The data has not been partitioned, it will iterator the
# dataset once, write them into different files, has no
@@ -334,6 +341,7 @@ class ExternalMerger(Merger):
self.serializer.dump_stream([(k, v)], streams[h])
for s in streams:
+ DiskBytesSpilled += s.tell()
s.close()
self.data.clear()
@@ -346,9 +354,11 @@ class ExternalMerger(Merger):
# dump items in batch
self.serializer.dump_stream(self.pdata[i].iteritems(), f)
self.pdata[i].clear()
+ DiskBytesSpilled += os.path.getsize(p)
self.spills += 1
gc.collect() # release the memory as much as possible
+ MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
def iteritems(self):
""" Return all merged items as iterator """
@@ -462,7 +472,6 @@ class ExternalSorter(object):
self.memory_limit = memory_limit
self.local_dirs = _get_local_dirs("sort")
self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)
- self._spilled_bytes = 0
def _get_path(self, n):
""" Choose one directory for spill by number n """
@@ -476,6 +485,7 @@ class ExternalSorter(object):
Sort the elements in iterator, do external sort when the memory
goes above the limit.
"""
+ global MemoryBytesSpilled, DiskBytesSpilled
batch = 10
chunks, current_chunk = [], []
iterator = iter(iterator)
@@ -486,15 +496,18 @@ class ExternalSorter(object):
if len(chunk) < batch:
break
- if get_used_memory() > self.memory_limit:
+ used_memory = get_used_memory()
+ if used_memory > self.memory_limit:
# sort them inplace will save memory
current_chunk.sort(key=key, reverse=reverse)
path = self._get_path(len(chunks))
with open(path, 'w') as f:
self.serializer.dump_stream(current_chunk, f)
- self._spilled_bytes += os.path.getsize(path)
chunks.append(self.serializer.load_stream(open(path)))
current_chunk = []
+ gc.collect()
+ MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+ DiskBytesSpilled += os.path.getsize(path)
elif not chunks:
batch = min(batch * 2, 10000)
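
The MemoryBytesSpilled bookkeeping works by sampling process memory before the spill and again after gc.collect(): get_used_memory() in this module reports resident memory in megabytes, so the difference is shifted left by 20 bits to convert MB back to bytes. A worked example with illustrative numbers:

    used_memory = 512            # MB observed before dumping the chunk
    after_gc = 480               # MB observed after dump_stream() + gc.collect()
    memory_bytes_spilled = (used_memory - after_gc) << 20
    assert memory_bytes_spilled == 32 * 1024 * 1024   # 32 MB counted as spilled
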
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 747cd1767d..f3309a20fc 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -46,6 +46,7 @@ from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer,
CloudPickleSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
from pyspark.sql import SQLContext, IntegerType
+from pyspark import shuffle
_have_scipy = False
_have_numpy = False
@@ -138,17 +139,17 @@ class TestSorter(unittest.TestCase):
random.shuffle(l)
sorter = ExternalSorter(1)
self.assertEquals(sorted(l), list(sorter.sorted(l)))
- self.assertGreater(sorter._spilled_bytes, 0)
- last = sorter._spilled_bytes
+ self.assertGreater(shuffle.DiskBytesSpilled, 0)
+ last = shuffle.DiskBytesSpilled
self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
- self.assertGreater(sorter._spilled_bytes, last)
- last = sorter._spilled_bytes
+ self.assertGreater(shuffle.DiskBytesSpilled, last)
+ last = shuffle.DiskBytesSpilled
self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
- self.assertGreater(sorter._spilled_bytes, last)
- last = sorter._spilled_bytes
+ self.assertGreater(shuffle.DiskBytesSpilled, last)
+ last = shuffle.DiskBytesSpilled
self.assertEquals(sorted(l, key=lambda x: -x, reverse=True),
list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
- self.assertGreater(sorter._spilled_bytes, last)
+ self.assertGreater(shuffle.DiskBytesSpilled, last)
def test_external_sort_in_rdd(self):
conf = SparkConf().set("spark.python.worker.memory", "1m")
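
The sorter tests now read the module-level DiskBytesSpilled counter instead of the removed per-instance _spilled_bytes attribute, and they deliberately import the shuffle module rather than the name itself: a from-import would copy the initial value (0) at import time and never observe later rebinding, while attribute access through the module always sees the current global. A minimal sketch of the difference:

    from pyspark import shuffle                      # attribute lookup sees later updates
    # from pyspark.shuffle import DiskBytesSpilled   # would freeze the value captured at import time

    shuffle.DiskBytesSpilled = 0                     # reset before a test case
    # ... run a sort or merge large enough to spill to disk ...
    spilled = shuffle.DiskBytesSpilled               # reads the live global, not a stale copy
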
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 61b8a74d06..252176ac65 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -23,16 +23,14 @@ import sys
import time
import socket
import traceback
-# CloudPickler needs to be imported so that depicklers are registered using the
-# copy_reg module.
+
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
-from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
CompressedSerializer
-
+from pyspark import shuffle
pickleSer = PickleSerializer()
utf8_deserializer = UTF8Deserializer()
@@ -52,6 +50,11 @@ def main(infile, outfile):
if split_index == -1: # for unit tests
return
+ # initialize global state
+ shuffle.MemoryBytesSpilled = 0
+ shuffle.DiskBytesSpilled = 0
+ _accumulatorRegistry.clear()
+
# fetch name of workdir
spark_files_dir = utf8_deserializer.loads(infile)
SparkFiles._root_directory = spark_files_dir
@@ -97,6 +100,9 @@ def main(infile, outfile):
exit(-1)
finish_time = time.time()
report_times(outfile, boot_time, init_time, finish_time)
+ write_long(shuffle.MemoryBytesSpilled, outfile)
+ write_long(shuffle.DiskBytesSpilled, outfile)
+
# Mark the beginning of the accumulators section of the output
write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
write_int(len(_accumulatorRegistry), outfile)
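
Because a worker process can be reused for more than one task, the counters are zeroed at the start of main() and written out once at the end, so each pair of longs covers exactly one task; the Scala side then folds them into that task's metrics with +=. A compressed sketch of the per-task lifecycle, with run_task as a hypothetical stand-in for the body of main():

    from pyspark import shuffle
    from pyspark.serializers import write_long

    def run_task(infile, outfile):
        shuffle.MemoryBytesSpilled = 0       # reset: the process may have served a previous task
        shuffle.DiskBytesSpilled = 0
        # ... run the task; ExternalMerger/ExternalSorter bump the globals on every spill ...
        write_long(shuffle.MemoryBytesSpilled, outfile)   # reported once per task
        write_long(shuffle.DiskBytesSpilled, outfile)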