about summary refs log tree commit diff
path: root/python/pyspark/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--  python/pyspark/tests.py  42
1 files changed, 41 insertions, 1 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 1db922f513..3e7040eade 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -30,6 +30,7 @@ import sys
import tempfile
import time
import zipfile
+import random
if sys.version_info[:2] <= (2, 6):
import unittest2 as unittest
@@ -37,10 +38,11 @@ else:
import unittest
+from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer
-from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger
+from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
_have_scipy = False
_have_numpy = False
@@ -117,6 +119,44 @@ class TestMerger(unittest.TestCase):
m._cleanup()
+class TestSorter(unittest.TestCase):
+ def test_in_memory_sort(self):
+ l = range(1024)
+ random.shuffle(l)
+ sorter = ExternalSorter(1024)
+ self.assertEquals(sorted(l), list(sorter.sorted(l)))
+ self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
+ self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
+ self.assertEquals(sorted(l, key=lambda x: -x, reverse=True),
+ list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
+
+ def test_external_sort(self):
+ l = range(1024)
+ random.shuffle(l)
+ sorter = ExternalSorter(1)
+ self.assertEquals(sorted(l), list(sorter.sorted(l)))
+ self.assertGreater(sorter._spilled_bytes, 0)
+ last = sorter._spilled_bytes
+ self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
+ self.assertGreater(sorter._spilled_bytes, last)
+ last = sorter._spilled_bytes
+ self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
+ self.assertGreater(sorter._spilled_bytes, last)
+ last = sorter._spilled_bytes
+ self.assertEquals(sorted(l, key=lambda x: -x, reverse=True),
+ list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
+ self.assertGreater(sorter._spilled_bytes, last)
+
+ def test_external_sort_in_rdd(self):
+ conf = SparkConf().set("spark.python.worker.memory", "1m")
+ sc = SparkContext(conf=conf)
+ l = range(10240)
+ random.shuffle(l)
+ rdd = sc.parallelize(l, 10)
+ self.assertEquals(sorted(l), rdd.sortBy(lambda x: x).collect())
+ sc.stop()
+
+
class SerializationTestCase(unittest.TestCase):
def test_namedtuple(self):