Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--  python/pyspark/tests.py | 52 ++++++++++++++++++++++++++++++++++++++++++++++--
1 file changed, 50 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 491e445a21..a01bd8d415 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -32,6 +32,7 @@ import time
 import zipfile
 import random
 import threading
+import hashlib
 
 if sys.version_info[:2] <= (2, 6):
     try:
@@ -47,7 +48,7 @@ from pyspark.conf import SparkConf
 from pyspark.context import SparkContext
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
-    CloudPickleSerializer
+    CloudPickleSerializer, SizeLimitedStream, CompressedSerializer, LargeObjectSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
 from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
     UserDefinedType, DoubleType
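The widened serializer import pulls in the three helpers exercised by the new tests below: SizeLimitedStream, CompressedSerializer, and LargeObjectSerializer. As a rough, self-contained sketch of what a compressing wrapper in the style of CompressedSerializer does (zlib over an inner serializer's bytes; an illustrative stand-in, not PySpark's implementation):

import pickle
import zlib

class CompressingSerializer(object):
    """Illustrative stand-in for CompressedSerializer: zlib-compress the
    inner serializer's output on the way out, decompress on the way in."""

    def dumps(self, obj):
        return zlib.compress(pickle.dumps(obj))

    def loads(self, data):
        return pickle.loads(zlib.decompress(data))

ser = CompressingSerializer()
blob = ser.dumps({"key": list(range(100))})
assert ser.loads(blob) == {"key": list(range(100))}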
@@ -236,6 +237,27 @@ class SerializationTestCase(unittest.TestCase):
         self.assertTrue("exit" in foo.func_code.co_names)
         ser.dumps(foo)
 
+    def _test_serializer(self, ser):
+        from StringIO import StringIO
+        io = StringIO()
+        ser.dump_stream(["abc", u"123", range(5)], io)
+        io.seek(0)
+        self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io)))
+        size = io.tell()
+        ser.dump_stream(range(1000), io)
+        io.seek(0)
+        first = SizeLimitedStream(io, size)
+        self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(first)))
+        self.assertEqual(range(1000), list(ser.load_stream(io)))
+
+    def test_compressed_serializer(self):
+        ser = CompressedSerializer(PickleSerializer())
+        self._test_serializer(ser)
+
+    def test_large_object_serializer(self):
+        ser = LargeObjectSerializer()
+        self._test_serializer(ser)
+
 
 class PySparkTestCase(unittest.TestCase):
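_test_serializer above writes two serialized streams back to back into one buffer, then reads the first through SizeLimitedStream so that loading stops at the recorded boundary instead of running into the second stream. A minimal sketch of that pattern, using a hypothetical LimitedReader and plain pickle in place of PySpark's stream serializers:

import pickle
from io import BytesIO

class LimitedReader(object):
    """Hypothetical stand-in for SizeLimitedStream: serve at most `limit` bytes."""

    def __init__(self, stream, limit):
        self.stream = stream
        self.limit = limit

    def read(self, n=-1):
        if n < 0 or n > self.limit:
            n = self.limit
        data = self.stream.read(n)
        self.limit -= len(data)
        return data

buf = BytesIO()
pickle.dump(["abc", u"123", list(range(5))], buf)  # first stream
size = buf.tell()                                  # remember where it ends
pickle.dump(list(range(1000)), buf)                # second stream, appended
buf.seek(0)

first = pickle.loads(LimitedReader(buf, size).read())  # bounded read of stream one
second = pickle.load(buf)                              # picks up at the boundary
assert first == ["abc", u"123", list(range(5))]
assert second == list(range(1000))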
@@ -440,7 +462,7 @@ class RDDTests(ReusedPySparkTestCase):
         subset = data.takeSample(False, 10)
         self.assertEqual(len(subset), 10)
 
-    def testAggregateByKey(self):
+    def test_aggregate_by_key(self):
         data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)
 
         def seqOp(x, y):
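The renamed test exercises RDD.aggregateByKey(zeroValue, seqOp, combOp), which folds values into a per-partition accumulator and then merges accumulators across partitions. A hedged sketch of those semantics, assuming a live SparkContext sc (the set-based functions are illustrative, since the test body is truncated in this hunk):

data = sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)

def seqOp(acc, value):    # fold one value into a per-partition accumulator
    acc.add(value)
    return acc

def combOp(left, right):  # merge accumulators built on different partitions
    left |= right
    return left

sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect())
# expected: {1: set([1]), 3: set([2]), 5: set([1, 3])}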
@@ -478,6 +500,32 @@ class RDDTests(ReusedPySparkTestCase):
         m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
         self.assertEquals(N, m)
 
+    def test_multiple_broadcasts(self):
+        N = 1 << 21
+        b1 = self.sc.broadcast(set(range(N)))  # multiple blocks in JVM
+        r = range(1 << 15)
+        random.shuffle(r)
+        s = str(r)
+        checksum = hashlib.md5(s).hexdigest()
+        b2 = self.sc.broadcast(s)
+        r = list(set(self.sc.parallelize(range(10), 10).map(
+            lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
+        self.assertEqual(1, len(r))
+        size, csum = r[0]
+        self.assertEqual(N, size)
+        self.assertEqual(checksum, csum)
+
+        random.shuffle(r)
+        s = str(r)
+        checksum = hashlib.md5(s).hexdigest()
+        b2 = self.sc.broadcast(s)
+        r = list(set(self.sc.parallelize(range(10), 10).map(
+            lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
+        self.assertEqual(1, len(r))
+        size, csum = r[0]
+        self.assertEqual(N, size)
+        self.assertEqual(checksum, csum)
+
     def test_large_closure(self):
         N = 1000000
         data = [float(i) for i in xrange(N)]
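test_multiple_broadcasts verifies that executors see byte-identical broadcast data by comparing an md5 digest computed on the driver against digests recomputed on the workers, and it then rebroadcasts to check that a second, replacing broadcast behaves the same way. A hedged sketch of the core pattern, assuming a live SparkContext sc (the .encode() call is an addition for Python 3; the Python 2 test hashes the str directly):

import hashlib
import random

payload = list(range(1 << 15))
random.shuffle(payload)
s = str(payload).encode("utf-8")   # md5 needs bytes on Python 3
expected = hashlib.md5(s).hexdigest()

b = sc.broadcast(s)
digests = set(sc.parallelize(range(10), 10)
                .map(lambda _: hashlib.md5(b.value).hexdigest())
                .collect())
assert digests == set([expected])  # every partition saw the same payload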