aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/tests.py
diff options
context:
space:
mode:
authorDavies Liu <davies.liu@gmail.com>2014-10-06 14:07:53 -0700
committerJosh Rosen <joshrosen@apache.org>2014-10-06 14:07:53 -0700
commit4f01265f7d62e070ba42c251255e385644c1b16c (patch)
treee6dbec031ebe0653ab232ac613548289c720eb48 /python/pyspark/tests.py
parent20ea54cc7a5176ebc63bfa9393a9bf84619bfc66 (diff)
downloadspark-4f01265f7d62e070ba42c251255e385644c1b16c.tar.gz
spark-4f01265f7d62e070ba42c251255e385644c1b16c.tar.bz2
spark-4f01265f7d62e070ba42c251255e385644c1b16c.zip
[SPARK-3786] [PySpark] speedup tests
This patch try to speed up tests of PySpark, re-use the SparkContext in tests.py and mllib/tests.py to reduce the overhead of create SparkContext, remove some test cases, which did not make sense. It also improve the performance of some cases, such as MergerTests and SortTests. before this patch: real 21m27.320s user 4m42.967s sys 0m17.343s after this patch: real 9m47.541s user 2m12.947s sys 0m14.543s It almost cut the time by half. Author: Davies Liu <davies.liu@gmail.com> Closes #2646 from davies/tests and squashes the following commits: c54de60 [Davies Liu] revert change about memory limit 6a2a4b0 [Davies Liu] refactor of tests, speedup 100%
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--python/pyspark/tests.py92
1 files changed, 42 insertions, 50 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 6fb6bc998c..7f05d48ade 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -67,10 +67,10 @@ except:
SPARK_HOME = os.environ["SPARK_HOME"]
-class TestMerger(unittest.TestCase):
+class MergerTests(unittest.TestCase):
def setUp(self):
- self.N = 1 << 16
+ self.N = 1 << 14
self.l = [i for i in xrange(self.N)]
self.data = zip(self.l, self.l)
self.agg = Aggregator(lambda x: [x],
@@ -115,7 +115,7 @@ class TestMerger(unittest.TestCase):
sum(xrange(self.N)) * 3)
def test_huge_dataset(self):
- m = ExternalMerger(self.agg, 10)
+ m = ExternalMerger(self.agg, 10, partitions=3)
m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10))
self.assertTrue(m.spills >= 1)
self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)),
@@ -123,7 +123,7 @@ class TestMerger(unittest.TestCase):
m._cleanup()
-class TestSorter(unittest.TestCase):
+class SorterTests(unittest.TestCase):
def test_in_memory_sort(self):
l = range(1024)
random.shuffle(l)
@@ -244,16 +244,25 @@ class PySparkTestCase(unittest.TestCase):
sys.path = self._old_sys_path
-class TestCheckpoint(PySparkTestCase):
+class ReusedPySparkTestCase(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.sc = SparkContext('local[4]', cls.__name__, batchSize=2)
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.sc.stop()
+
+
+class CheckpointTests(ReusedPySparkTestCase):
def setUp(self):
- PySparkTestCase.setUp(self)
self.checkpointDir = tempfile.NamedTemporaryFile(delete=False)
os.unlink(self.checkpointDir.name)
self.sc.setCheckpointDir(self.checkpointDir.name)
def tearDown(self):
- PySparkTestCase.tearDown(self)
shutil.rmtree(self.checkpointDir.name)
def test_basic_checkpointing(self):
@@ -288,7 +297,7 @@ class TestCheckpoint(PySparkTestCase):
self.assertEquals([1, 2, 3, 4], recovered.collect())
-class TestAddFile(PySparkTestCase):
+class AddFileTests(PySparkTestCase):
def test_add_py_file(self):
# To ensure that we're actually testing addPyFile's effects, check that
@@ -354,7 +363,7 @@ class TestAddFile(PySparkTestCase):
self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect())
-class TestRDDFunctions(PySparkTestCase):
+class RDDTests(ReusedPySparkTestCase):
def test_id(self):
rdd = self.sc.parallelize(range(10))
@@ -365,12 +374,6 @@ class TestRDDFunctions(PySparkTestCase):
self.assertEqual(id + 1, id2)
self.assertEqual(id2, rdd2.id())
- def test_failed_sparkcontext_creation(self):
- # Regression test for SPARK-1550
- self.sc.stop()
- self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))
- self.sc = SparkContext("local")
-
def test_save_as_textfile_with_unicode(self):
# Regression test for SPARK-970
x = u"\u00A1Hola, mundo!"
@@ -636,7 +639,7 @@ class TestRDDFunctions(PySparkTestCase):
self.assertEquals(result.count(), 3)
-class TestProfiler(PySparkTestCase):
+class ProfilerTests(PySparkTestCase):
def setUp(self):
self._old_sys_path = list(sys.path)
@@ -666,10 +669,9 @@ class TestProfiler(PySparkTestCase):
self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))
-class TestSQL(PySparkTestCase):
+class SQLTests(ReusedPySparkTestCase):
def setUp(self):
- PySparkTestCase.setUp(self)
self.sqlCtx = SQLContext(self.sc)
def test_udf(self):
@@ -754,27 +756,19 @@ class TestSQL(PySparkTestCase):
self.assertEqual("2", row.d)
-class TestIO(PySparkTestCase):
-
- def test_stdout_redirection(self):
- import subprocess
-
- def func(x):
- subprocess.check_call('ls', shell=True)
- self.sc.parallelize([1]).foreach(func)
+class InputFormatTests(ReusedPySparkTestCase):
+ @classmethod
+ def setUpClass(cls):
+ ReusedPySparkTestCase.setUpClass()
+ cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+ os.unlink(cls.tempdir.name)
+ cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc)
-class TestInputFormat(PySparkTestCase):
-
- def setUp(self):
- PySparkTestCase.setUp(self)
- self.tempdir = tempfile.NamedTemporaryFile(delete=False)
- os.unlink(self.tempdir.name)
- self.sc._jvm.WriteInputFormatTestDataGenerator.generateData(self.tempdir.name, self.sc._jsc)
-
- def tearDown(self):
- PySparkTestCase.tearDown(self)
- shutil.rmtree(self.tempdir.name)
+ @classmethod
+ def tearDownClass(cls):
+ ReusedPySparkTestCase.tearDownClass()
+ shutil.rmtree(cls.tempdir.name)
def test_sequencefiles(self):
basepath = self.tempdir.name
@@ -954,15 +948,13 @@ class TestInputFormat(PySparkTestCase):
self.assertEqual(maps, em)
-class TestOutputFormat(PySparkTestCase):
+class OutputFormatTests(ReusedPySparkTestCase):
def setUp(self):
- PySparkTestCase.setUp(self)
self.tempdir = tempfile.NamedTemporaryFile(delete=False)
os.unlink(self.tempdir.name)
def tearDown(self):
- PySparkTestCase.tearDown(self)
shutil.rmtree(self.tempdir.name, ignore_errors=True)
def test_sequencefiles(self):
@@ -1243,8 +1235,7 @@ class TestOutputFormat(PySparkTestCase):
basepath + "/malformed/sequence"))
-class TestDaemon(unittest.TestCase):
-
+class DaemonTests(unittest.TestCase):
def connect(self, port):
from socket import socket, AF_INET, SOCK_STREAM
sock = socket(AF_INET, SOCK_STREAM)
@@ -1290,7 +1281,7 @@ class TestDaemon(unittest.TestCase):
self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
-class TestWorker(PySparkTestCase):
+class WorkerTests(PySparkTestCase):
def test_cancel_task(self):
temp = tempfile.NamedTemporaryFile(delete=True)
@@ -1342,11 +1333,6 @@ class TestWorker(PySparkTestCase):
rdd = self.sc.parallelize(range(100), 1)
self.assertEqual(100, rdd.map(str).count())
- def test_fd_leak(self):
- N = 1100 # fd limit is 1024 by default
- rdd = self.sc.parallelize(range(N), N)
- self.assertEquals(N, rdd.count())
-
def test_after_exception(self):
def raise_exception(_):
raise Exception()
@@ -1379,7 +1365,7 @@ class TestWorker(PySparkTestCase):
self.assertEqual(sum(range(100)), acc1.value)
-class TestSparkSubmit(unittest.TestCase):
+class SparkSubmitTests(unittest.TestCase):
def setUp(self):
self.programDir = tempfile.mkdtemp()
@@ -1492,6 +1478,8 @@ class TestSparkSubmit(unittest.TestCase):
|sc = SparkContext()
|print sc.parallelize([1, 2, 3]).map(foo).collect()
""")
+ # this will fail if you have different spark.executor.memory
+ # in conf/spark-defaults.conf
proc = subprocess.Popen(
[self.sparkSubmit, "--master", "local-cluster[1,1,512]", script],
stdout=subprocess.PIPE)
@@ -1500,7 +1488,11 @@ class TestSparkSubmit(unittest.TestCase):
self.assertIn("[2, 4, 6]", out)
-class ContextStopTests(unittest.TestCase):
+class ContextTests(unittest.TestCase):
+
+ def test_failed_sparkcontext_creation(self):
+ # Regression test for SPARK-1550
+ self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))
def test_stop(self):
sc = SparkContext()