aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--python/pyspark/tests.py57
1 files changed, 57 insertions, 0 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 9c5ecd0bb0..a92abbf371 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -34,6 +34,7 @@ import zipfile
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
from pyspark.serializers import read_int
+from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger
_have_scipy = False
try:
@@ -47,6 +48,62 @@ except:
SPARK_HOME = os.environ["SPARK_HOME"]
+class TestMerger(unittest.TestCase):
+ # Unit tests for the InMemoryMerger and ExternalMerger classes imported
+ # from pyspark.shuffle. Every test merges (key, value) pairs and checks an
+ # aggregate (sum or count) over all merged values.
+
+ def setUp(self):
+ # N = 65536 distinct integer keys.
+ self.N = 1 << 16
+ self.l = [i for i in xrange(self.N)]
+ # Pairs (i, i): every key maps to itself as the value.
+ self.data = zip(self.l, self.l)
+ # Aggregator that collects values into lists. The three callables wrap a
+ # value in a list, append a value to a list, and extend one list with
+ # another; presumably these are the create-combiner / merge-value /
+ # merge-combiners roles -- confirm against pyspark.shuffle.Aggregator.
+ # append()/extend() return None, so "or x" yields the mutated list.
+ self.agg = Aggregator(lambda x: [x],
+ lambda x, y: x.append(y) or x,
+ lambda x, y: x.extend(y) or x)
+
+ def test_in_memory(self):
+ # Merge raw values, then check the total of all merged values.
+ m = InMemoryMerger(self.agg)
+ m.mergeValues(self.data)
+ self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
+ sum(xrange(self.N)))
+
+ # Same check with a fresh merger fed single-element lists through
+ # mergeCombiners instead of raw values.
+ m = InMemoryMerger(self.agg)
+ m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data))
+ self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
+ sum(xrange(self.N)))
+
+ def test_small_dataset(self):
+ # Second constructor argument is 1000 here; NOTE(review): presumably a
+ # memory limit -- confirm its meaning and unit in ExternalMerger.
+ # With this setting the test expects no spills at all.
+ m = ExternalMerger(self.agg, 1000)
+ m.mergeValues(self.data)
+ self.assertEqual(m.spills, 0)
+ self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
+ sum(xrange(self.N)))
+
+ # Same expectations when feeding single-element lists via mergeCombiners.
+ m = ExternalMerger(self.agg, 1000)
+ m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data))
+ self.assertEqual(m.spills, 0)
+ self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
+ sum(xrange(self.N)))
+
+ def test_medium_dataset(self):
+ # Second constructor argument lowered to 10: the test expects at least
+ # one spill while the merged total stays correct.
+ m = ExternalMerger(self.agg, 10)
+ m.mergeValues(self.data)
+ self.assertTrue(m.spills >= 1)
+ self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
+ sum(xrange(self.N)))
+
+ # The input list is repeated three times, so the expected total is
+ # three times the single-pass sum.
+ m = ExternalMerger(self.agg, 10)
+ m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data * 3))
+ self.assertTrue(m.spills >= 1)
+ self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
+ sum(xrange(self.N)) * 3)
+
+ def test_huge_dataset(self):
+ # Ten repetitions of the data with each value converted to a string.
+ m = ExternalMerger(self.agg, 10)
+ m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10))
+ self.assertTrue(m.spills >= 1)
+ # Values are strings here, so the check counts merged values (len)
+ # instead of summing them. It calls the private
+ # _recursive_merged_items(0) directly rather than iteritems().
+ self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)),
+ self.N * 10)
+ # Explicit call to the merger's private cleanup hook after the check.
+ m._cleanup()
+
+
class PySparkTestCase(unittest.TestCase):
def setUp(self):