Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--  python/pyspark/tests.py | 50 ++++++++++++++++++++++++++++++++++++++++++++++----
1 file changed, 46 insertions(+), 4 deletions(-)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index dd8d3b1c53..0bd5d20f78 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -31,6 +31,7 @@ import tempfile
 import time
 import zipfile
 import random
+import itertools
 import threading
 import hashlib
@@ -76,7 +77,7 @@ SPARK_HOME = os.environ["SPARK_HOME"]
 class MergerTests(unittest.TestCase):

     def setUp(self):
-        self.N = 1 << 14
+        self.N = 1 << 12
         self.l = [i for i in xrange(self.N)]
         self.data = zip(self.l, self.l)
         self.agg = Aggregator(lambda x: [x],
@@ -108,7 +109,7 @@ class MergerTests(unittest.TestCase):
                          sum(xrange(self.N)))

     def test_medium_dataset(self):
-        m = ExternalMerger(self.agg, 10)
+        m = ExternalMerger(self.agg, 30)
         m.mergeValues(self.data)
         self.assertTrue(m.spills >= 1)
         self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
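Note: the hunk above raises the merger's memory limit (the second argument, in MB) from 10 to 30 while setUp shrinks the dataset. For context, a minimal sketch of the merge flow this test drives, outside the test harness; the two merge lambdas are assumptions here (the setUp hunk only shows the createCombiner), modeled on the usual list-append/extend combiners:

    # Illustrative sketch only (Python 2, PySpark 1.x internals).
    from pyspark.shuffle import Aggregator, ExternalMerger

    agg = Aggregator(lambda x: [x],                   # createCombiner (from setUp)
                     lambda x, y: x.append(y) or x,   # mergeValue (assumed)
                     lambda x, y: x.extend(y) or x)   # mergeCombiners (assumed)

    m = ExternalMerger(agg, 30)                       # spill once ~30 MB is in use
    m.mergeValues(zip(xrange(1 << 12), xrange(1 << 12)))
    assert sum(sum(v) for k, v in m.iteritems()) == sum(xrange(1 << 12))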
@@ -124,10 +125,36 @@ class MergerTests(unittest.TestCase):
         m = ExternalMerger(self.agg, 10, partitions=3)
         m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10))
         self.assertTrue(m.spills >= 1)
-        self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)),
+        self.assertEqual(sum(len(v) for k, v in m.iteritems()),
                          self.N * 10)
         m._cleanup()

+    def test_group_by_key(self):
+
+        def gen_data(N, step):
+            for i in range(1, N + 1, step):
+                for j in range(i):
+                    yield (i, [j])
+
+        def gen_gs(N, step=1):
+            return shuffle.GroupByKey(gen_data(N, step))
+
+        self.assertEqual(1, len(list(gen_gs(1))))
+        self.assertEqual(2, len(list(gen_gs(2))))
+        self.assertEqual(100, len(list(gen_gs(100))))
+        self.assertEqual(range(1, 101), [k for k, _ in gen_gs(100)])
+        self.assertTrue(all(range(k) == list(vs) for k, vs in gen_gs(100)))
+
+        for k, vs in gen_gs(50002, 10000):
+            self.assertEqual(k, len(vs))
+            self.assertEqual(range(k), list(vs))
+
+        ser = PickleSerializer()
+        l = ser.loads(ser.dumps(list(gen_gs(50002, 30000))))
+        for k, vs in l:
+            self.assertEqual(k, len(vs))
+            self.assertEqual(range(k), list(vs))
+

 class SorterTests(unittest.TestCase):
     def test_in_memory_sort(self):
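The new test_group_by_key feeds shuffle.GroupByKey a key-sorted stream of (key, [values]) chunks and expects one (key, value-iterator) pair per distinct key. A rough model of that contract, built on itertools.groupby rather than the real class (which, per the RDD test further down, can also back oversized groups with an on-disk ExternalList), is:

    # Rough model of the GroupByKey contract, assuming key-sorted input;
    # the real shuffle.GroupByKey also handles groups too big for memory.
    import itertools

    def group_sorted(pairs):
        # pairs: iterator of (key, [values]) with equal keys adjacent
        for key, chunk in itertools.groupby(pairs, key=lambda kv: kv[0]):
            yield key, itertools.chain.from_iterable(v for _, v in chunk)

    data = [(1, [0]), (2, [0]), (2, [1])]
    assert [(k, list(vs)) for k, vs in group_sorted(data)] == \
        [(1, [0]), (2, [0, 1])]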
@@ -702,6 +729,21 @@ class RDDTests(ReusedPySparkTestCase):
         self.assertEquals(result.getNumPartitions(), 5)
         self.assertEquals(result.count(), 3)

+    def test_external_group_by_key(self):
+        self.sc._conf.set("spark.python.worker.memory", "5m")
+        N = 200001
+        kv = self.sc.parallelize(range(N)).map(lambda x: (x % 3, x))
+        gkv = kv.groupByKey().cache()
+        self.assertEqual(3, gkv.count())
+        filtered = gkv.filter(lambda (k, vs): k == 1)
+        self.assertEqual(1, filtered.count())
+        self.assertEqual([(1, N/3)], filtered.mapValues(len).collect())
+        self.assertEqual([(N/3, N/3)],
+                         filtered.values().map(lambda x: (len(x), len(list(x)))).collect())
+        result = filtered.collect()[0][1]
+        self.assertEqual(N/3, len(result))
+        self.assertTrue(isinstance(result.data, shuffle.ExternalList))
+
     def test_sort_on_empty_rdd(self):
         self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect())
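This RDD-level test pins spark.python.worker.memory to 5 MB so that the three groups of roughly N/3 values each must overflow to disk; result.data being a shuffle.ExternalList is the observable sign that the spill path ran. A hedged standalone version of the same experiment, assuming a local master and the Python 2 era API this patch targets:

    # Standalone sketch of the spill scenario above (illustrative, local mode).
    from pyspark import SparkConf, SparkContext

    conf = SparkConf().set("spark.python.worker.memory", "5m")
    sc = SparkContext("local[2]", "external-gbk", conf=conf)

    N = 200001
    kv = sc.parallelize(range(N)).map(lambda x: (x % 3, x))
    sizes = kv.groupByKey().mapValues(len).collect()   # each group holds ~N/3 values
    print(sorted(sizes))
    sc.stop()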
@@ -752,9 +794,9 @@ class RDDTests(ReusedPySparkTestCase):
         self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions())
         self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions())

-        self.sc.setJobGroup("test1", "test", True)
         tracker = self.sc.statusTracker()
+        self.sc.setJobGroup("test1", "test", True)

         d = sorted(parted.join(parted).collect())
         self.assertEqual(10, len(d))
         self.assertEqual((0, (0, 0)), d[0])
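The last hunk only swaps two lines, so that setJobGroup is called immediately before the join job it is meant to tag. A hedged sketch of the lookup the tracker then supports (assuming an existing SparkContext `sc`, and using the public StatusTracker API):

    # Tag a job with a group, then look it up through the StatusTracker.
    sc.setJobGroup("test1", "test", True)
    sc.parallelize(range(10)).count()        # runs under group "test1"
    tracker = sc.statusTracker()
    for job_id in tracker.getJobIdsForGroup("test1"):
        info = tracker.getJobInfo(job_id)
        print(job_id, info.status)           # e.g. 'SUCCEEDED'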