Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--  python/pyspark/tests.py  |  92
1 file changed, 42 insertions(+), 50 deletions(-)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 6fb6bc998c..7f05d48ade 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -67,10 +67,10 @@ except:
SPARK_HOME = os.environ["SPARK_HOME"]
-class TestMerger(unittest.TestCase):
+class MergerTests(unittest.TestCase):
def setUp(self):
- self.N = 1 << 16
+ self.N = 1 << 14
self.l = [i for i in xrange(self.N)]
self.data = zip(self.l, self.l)
self.agg = Aggregator(lambda x: [x],
@@ -115,7 +115,7 @@ class TestMerger(unittest.TestCase):
sum(xrange(self.N)) * 3)
def test_huge_dataset(self):
- m = ExternalMerger(self.agg, 10)
+ m = ExternalMerger(self.agg, 10, partitions=3)
m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10))
self.assertTrue(m.spills >= 1)
self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)),
@@ -123,7 +123,7 @@ class TestMerger(unittest.TestCase):
m._cleanup()
-class TestSorter(unittest.TestCase):
+class SorterTests(unittest.TestCase):
def test_in_memory_sort(self):
l = range(1024)
random.shuffle(l)
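
A minimal sketch of the spilling pattern MergerTests exercises above, assuming the pyspark.shuffle API as used in these hunks (Aggregator's three callbacks, and ExternalMerger taking a memory cap in MB plus the new partitions argument):

    from pyspark.shuffle import Aggregator, ExternalMerger

    N = 1 << 14
    agg = Aggregator(lambda v: [v],                   # createCombiner
                     lambda acc, v: acc + [v],        # mergeValue
                     lambda a, b: a + b)              # mergeCombiners
    m = ExternalMerger(agg, 10, partitions=3)         # ~10 MB cap, 3 spill partitions
    m.mergeValues(zip(xrange(N), xrange(N)))
    # every value comes back exactly once after any spills are merged
    assert sum(v for k, vs in m.iteritems() for v in vs) == sum(xrange(N))

The smaller N (1 << 14 instead of 1 << 16) and the explicit partitions=3 keep the spill path covered while cutting the test's runtime.
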
@@ -244,16 +244,25 @@ class PySparkTestCase(unittest.TestCase):
sys.path = self._old_sys_path
-class TestCheckpoint(PySparkTestCase):
+class ReusedPySparkTestCase(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.sc = SparkContext('local[4]', cls.__name__, batchSize=2)
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.sc.stop()
+
+
+class CheckpointTests(ReusedPySparkTestCase):
def setUp(self):
- PySparkTestCase.setUp(self)
self.checkpointDir = tempfile.NamedTemporaryFile(delete=False)
os.unlink(self.checkpointDir.name)
self.sc.setCheckpointDir(self.checkpointDir.name)
def tearDown(self):
- PySparkTestCase.tearDown(self)
shutil.rmtree(self.checkpointDir.name)
def test_basic_checkpointing(self):
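
The new base class above turns the SparkContext into a class-level fixture: it is created once in setUpClass and stopped in tearDownClass, instead of being rebuilt around every test method. A minimal sketch of a suite reusing it (hypothetical test, not part of this change):

    class ExampleTests(ReusedPySparkTestCase):

        def test_parallelize(self):
            # all tests in the class share the one context built in setUpClass
            self.assertEqual([0, 1, 2],
                             self.sc.parallelize(range(3)).collect())

Per-test state, like the checkpoint directory here, still goes in setUp/tearDown; only the expensive JVM-backed context is shared.
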
@@ -288,7 +297,7 @@ class TestCheckpoint(PySparkTestCase):
self.assertEquals([1, 2, 3, 4], recovered.collect())
-class TestAddFile(PySparkTestCase):
+class AddFileTests(PySparkTestCase):
def test_add_py_file(self):
# To ensure that we're actually testing addPyFile's effects, check that
@@ -354,7 +363,7 @@ class TestAddFile(PySparkTestCase):
self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect())
-class TestRDDFunctions(PySparkTestCase):
+class RDDTests(ReusedPySparkTestCase):
def test_id(self):
rdd = self.sc.parallelize(range(10))
@@ -365,12 +374,6 @@ class TestRDDFunctions(PySparkTestCase):
self.assertEqual(id + 1, id2)
self.assertEqual(id2, rdd2.id())
- def test_failed_sparkcontext_creation(self):
- # Regression test for SPARK-1550
- self.sc.stop()
- self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))
- self.sc = SparkContext("local")
-
def test_save_as_textfile_with_unicode(self):
# Regression test for SPARK-970
x = u"\u00A1Hola, mundo!"
@@ -636,7 +639,7 @@ class TestRDDFunctions(PySparkTestCase):
self.assertEquals(result.count(), 3)
-class TestProfiler(PySparkTestCase):
+class ProfilerTests(PySparkTestCase):
def setUp(self):
self._old_sys_path = list(sys.path)
@@ -666,10 +669,9 @@ class TestProfiler(PySparkTestCase):
self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))
-class TestSQL(PySparkTestCase):
+class SQLTests(ReusedPySparkTestCase):
def setUp(self):
- PySparkTestCase.setUp(self)
self.sqlCtx = SQLContext(self.sc)
def test_udf(self):
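
The body of test_udf falls outside the hunk context. Purely as an illustration of this fixture (a hypothetical body, assuming this era's SQLContext.registerFunction and SQLContext.sql), a Python-UDF round trip might read:

        # hypothetical sketch, not the file's actual body
        self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y)
        [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
        self.assertEqual(row[0], 5)  # len('test') + 1

Note that the SQLContext itself stays per-test in setUp: it is a cheap wrapper around the shared SparkContext, so rebuilding it buys isolation at no real cost.
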
@@ -754,27 +756,19 @@ class TestSQL(PySparkTestCase):
self.assertEqual("2", row.d)
-class TestIO(PySparkTestCase):
-
- def test_stdout_redirection(self):
- import subprocess
-
- def func(x):
- subprocess.check_call('ls', shell=True)
- self.sc.parallelize([1]).foreach(func)
+class InputFormatTests(ReusedPySparkTestCase):
+ @classmethod
+ def setUpClass(cls):
+ ReusedPySparkTestCase.setUpClass()
+ cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+ os.unlink(cls.tempdir.name)
+ cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc)
-class TestInputFormat(PySparkTestCase):
-
- def setUp(self):
- PySparkTestCase.setUp(self)
- self.tempdir = tempfile.NamedTemporaryFile(delete=False)
- os.unlink(self.tempdir.name)
- self.sc._jvm.WriteInputFormatTestDataGenerator.generateData(self.tempdir.name, self.sc._jsc)
-
- def tearDown(self):
- PySparkTestCase.tearDown(self)
- shutil.rmtree(self.tempdir.name)
+ @classmethod
+ def tearDownClass(cls):
+ ReusedPySparkTestCase.tearDownClass()
+ shutil.rmtree(cls.tempdir.name)
def test_sequencefiles(self):
basepath = self.tempdir.name
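
Note the shape of the class-level fixtures above: unittest does not chain an overridden setUpClass to the parent's automatically, so the override calls ReusedPySparkTestCase.setUpClass() explicitly before generating its test data, and tearDownClass mirrors that order. A subclass adding its own class fixture would follow the same pattern (a sketch, with a hypothetical class name):

    class MyFormatTests(ReusedPySparkTestCase):

        @classmethod
        def setUpClass(cls):
            ReusedPySparkTestCase.setUpClass()   # chain first: creates cls.sc
            cls.tempdir = tempfile.mkdtemp()     # then class-level fixtures

        @classmethod
        def tearDownClass(cls):
            ReusedPySparkTestCase.tearDownClass()
            shutil.rmtree(cls.tempdir, ignore_errors=True)
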
@@ -954,15 +948,13 @@ class TestInputFormat(PySparkTestCase):
self.assertEqual(maps, em)
-class TestOutputFormat(PySparkTestCase):
+class OutputFormatTests(ReusedPySparkTestCase):
def setUp(self):
- PySparkTestCase.setUp(self)
self.tempdir = tempfile.NamedTemporaryFile(delete=False)
os.unlink(self.tempdir.name)
def tearDown(self):
- PySparkTestCase.tearDown(self)
shutil.rmtree(self.tempdir.name, ignore_errors=True)
def test_sequencefiles(self):
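
Unlike InputFormatTests, OutputFormatTests keeps its temp dir per test (setUp/tearDown), since each test writes fresh output. A sketch of the write-then-read round trip such a test performs, assuming the saveAsSequenceFile/sequenceFile pair from this era's PySpark (the test name and subpath here are hypothetical):

        def test_roundtrip(self):
            basepath = self.tempdir.name
            data = [(1, u"aa"), (2, u"bb"), (3, u"cc")]
            self.sc.parallelize(data).saveAsSequenceFile(basepath + "/roundtrip/")
            result = sorted(self.sc.sequenceFile(basepath + "/roundtrip/").collect())
            self.assertEqual(data, result)
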
@@ -1243,8 +1235,7 @@ class TestOutputFormat(PySparkTestCase):
basepath + "/malformed/sequence"))
-class TestDaemon(unittest.TestCase):
-
+class DaemonTests(unittest.TestCase):
def connect(self, port):
from socket import socket, AF_INET, SOCK_STREAM
sock = socket(AF_INET, SOCK_STREAM)
@@ -1290,7 +1281,7 @@ class TestDaemon(unittest.TestCase):
self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
-class TestWorker(PySparkTestCase):
+class WorkerTests(PySparkTestCase):
def test_cancel_task(self):
temp = tempfile.NamedTemporaryFile(delete=True)
@@ -1342,11 +1333,6 @@ class TestWorker(PySparkTestCase):
rdd = self.sc.parallelize(range(100), 1)
self.assertEqual(100, rdd.map(str).count())
- def test_fd_leak(self):
- N = 1100 # fd limit is 1024 by default
- rdd = self.sc.parallelize(range(N), N)
- self.assertEquals(N, rdd.count())
-
def test_after_exception(self):
def raise_exception(_):
raise Exception()
@@ -1379,7 +1365,7 @@ class TestWorker(PySparkTestCase):
self.assertEqual(sum(range(100)), acc1.value)
-class TestSparkSubmit(unittest.TestCase):
+class SparkSubmitTests(unittest.TestCase):
def setUp(self):
self.programDir = tempfile.mkdtemp()
@@ -1492,6 +1478,8 @@ class TestSparkSubmit(unittest.TestCase):
|sc = SparkContext()
|print sc.parallelize([1, 2, 3]).map(foo).collect()
""")
+ # this will fail if you have different spark.executor.memory
+ # in conf/spark-defaults.conf
proc = subprocess.Popen(
[self.sparkSubmit, "--master", "local-cluster[1,1,512]", script],
stdout=subprocess.PIPE)
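
The comment flags a real fragility: a spark.executor.memory in conf/spark-defaults.conf larger than the 512 MB worker of local-cluster[1,1,512] makes the submit fail. If a stricter variant were wanted, pinning the memory on the command line would decouple the test from local defaults (a sketch, not part of this change):

    proc = subprocess.Popen(
        [self.sparkSubmit, "--master", "local-cluster[1,1,512]",
         "--conf", "spark.executor.memory=512m", script],
        stdout=subprocess.PIPE)
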
@@ -1500,7 +1488,11 @@ class TestSparkSubmit(unittest.TestCase):
self.assertIn("[2, 4, 6]", out)
-class ContextStopTests(unittest.TestCase):
+class ContextTests(unittest.TestCase):
+
+ def test_failed_sparkcontext_creation(self):
+ # Regression test for SPARK-1550
+ self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))
def test_stop(self):
sc = SparkContext()
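
The regression scenario behind SPARK-1550, in miniature (assumed behavior, mirroring the test removed from RDDTests above): a constructor that fails on a bad master must not leave global state that blocks a later, valid context.

    from pyspark import SparkContext

    try:
        SparkContext("an-invalid-master-name")   # raises
    except Exception:
        pass
    sc = SparkContext("local")                   # must still succeed afterwards
    sc.stop()

Moving the test into ContextTests also removes the awkward stop-and-recreate dance it previously had to do around the suite's shared context.
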