diff options
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r-- | python/pyspark/tests.py | 52 |
1 files changed, 50 insertions, 2 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 491e445a21..a01bd8d415 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -32,6 +32,7 @@ import time
 import zipfile
 import random
 import threading
+import hashlib
 
 if sys.version_info[:2] <= (2, 6):
     try:
@@ -47,7 +48,7 @@ from pyspark.conf import SparkConf
 from pyspark.context import SparkContext
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
-    CloudPickleSerializer
+    CloudPickleSerializer, SizeLimitedStream, CompressedSerializer, LargeObjectSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
 from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
     UserDefinedType, DoubleType
@@ -236,6 +237,27 @@ class SerializationTestCase(unittest.TestCase):
         self.assertTrue("exit" in foo.func_code.co_names)
         ser.dumps(foo)
 
+    def _test_serializer(self, ser):
+        from StringIO import StringIO
+        io = StringIO()
+        ser.dump_stream(["abc", u"123", range(5)], io)
+        io.seek(0)
+        self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io)))
+        size = io.tell()
+        ser.dump_stream(range(1000), io)
+        io.seek(0)
+        first = SizeLimitedStream(io, size)
+        self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(first)))
+        self.assertEqual(range(1000), list(ser.load_stream(io)))
+
+    def test_compressed_serializer(self):
+        ser = CompressedSerializer(PickleSerializer())
+        self._test_serializer(ser)
+
+    def test_large_object_serializer(self):
+        ser = LargeObjectSerializer()
+        self._test_serializer(ser)
+
 
 class PySparkTestCase(unittest.TestCase):
 
@@ -440,7 +462,7 @@ class RDDTests(ReusedPySparkTestCase):
         subset = data.takeSample(False, 10)
         self.assertEqual(len(subset), 10)
 
-    def testAggregateByKey(self):
+    def test_aggregate_by_key(self):
         data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)
 
         def seqOp(x, y):
@@ -478,6 +500,32 @@ class RDDTests(ReusedPySparkTestCase):
         m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
         self.assertEquals(N, m)
 
+    def test_multiple_broadcasts(self):
+        N = 1 << 21
+        b1 = self.sc.broadcast(set(range(N)))  # multiple blocks in JVM
+        r = range(1 << 15)
+        random.shuffle(r)
+        s = str(r)
+        checksum = hashlib.md5(s).hexdigest()
+        b2 = self.sc.broadcast(s)
+        r = list(set(self.sc.parallelize(range(10), 10).map(
+            lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
+        self.assertEqual(1, len(r))
+        size, csum = r[0]
+        self.assertEqual(N, size)
+        self.assertEqual(checksum, csum)
+
+        random.shuffle(r)
+        s = str(r)
+        checksum = hashlib.md5(s).hexdigest()
+        b2 = self.sc.broadcast(s)
+        r = list(set(self.sc.parallelize(range(10), 10).map(
+            lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
+        self.assertEqual(1, len(r))
+        size, csum = r[0]
+        self.assertEqual(N, size)
+        self.assertEqual(checksum, csum)
+
     def test_large_closure(self):
         N = 1000000
         data = [float(i) for i in xrange(N)]