aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorDavies Liu <davies.liu@gmail.com>2014-10-06 14:07:53 -0700
committerJosh Rosen <joshrosen@apache.org>2014-10-06 14:07:53 -0700
commit4f01265f7d62e070ba42c251255e385644c1b16c (patch)
treee6dbec031ebe0653ab232ac613548289c720eb48 /python
parent20ea54cc7a5176ebc63bfa9393a9bf84619bfc66 (diff)
downloadspark-4f01265f7d62e070ba42c251255e385644c1b16c.tar.gz
spark-4f01265f7d62e070ba42c251255e385644c1b16c.tar.bz2
spark-4f01265f7d62e070ba42c251255e385644c1b16c.zip
[SPARK-3786] [PySpark] speedup tests
This patch try to speed up tests of PySpark, re-use the SparkContext in tests.py and mllib/tests.py to reduce the overhead of create SparkContext, remove some test cases, which did not make sense. It also improve the performance of some cases, such as MergerTests and SortTests. before this patch: real 21m27.320s user 4m42.967s sys 0m17.343s after this patch: real 9m47.541s user 2m12.947s sys 0m14.543s It almost cut the time by half. Author: Davies Liu <davies.liu@gmail.com> Closes #2646 from davies/tests and squashes the following commits: c54de60 [Davies Liu] revert change about memory limit 6a2a4b0 [Davies Liu] refactor of tests, speedup 100%
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/tests.py2
-rw-r--r--python/pyspark/shuffle.py5
-rw-r--r--python/pyspark/tests.py92
-rwxr-xr-xpython/run-tests74
4 files changed, 82 insertions, 91 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index f72e88ba6e..5c20e100e1 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -32,7 +32,7 @@ else:
from pyspark.serializers import PickleSerializer
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
-from pyspark.tests import PySparkTestCase
+from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
_have_scipy = False
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index ce597cbe91..d57a802e47 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -396,7 +396,6 @@ class ExternalMerger(Merger):
for v in self.data.iteritems():
yield v
self.data.clear()
- gc.collect()
# remove the merged partition
for j in range(self.spills):
@@ -428,7 +427,7 @@ class ExternalMerger(Merger):
subdirs = [os.path.join(d, "parts", str(i))
for d in self.localdirs]
m = ExternalMerger(self.agg, self.memory_limit, self.serializer,
- subdirs, self.scale * self.partitions)
+ subdirs, self.scale * self.partitions, self.partitions)
m.pdata = [{} for _ in range(self.partitions)]
limit = self._next_limit()
@@ -486,7 +485,7 @@ class ExternalSorter(object):
goes above the limit.
"""
global MemoryBytesSpilled, DiskBytesSpilled
- batch = 10
+ batch = 100
chunks, current_chunk = [], []
iterator = iter(iterator)
while True:
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 6fb6bc998c..7f05d48ade 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -67,10 +67,10 @@ except:
SPARK_HOME = os.environ["SPARK_HOME"]
-class TestMerger(unittest.TestCase):
+class MergerTests(unittest.TestCase):
def setUp(self):
- self.N = 1 << 16
+ self.N = 1 << 14
self.l = [i for i in xrange(self.N)]
self.data = zip(self.l, self.l)
self.agg = Aggregator(lambda x: [x],
@@ -115,7 +115,7 @@ class TestMerger(unittest.TestCase):
sum(xrange(self.N)) * 3)
def test_huge_dataset(self):
- m = ExternalMerger(self.agg, 10)
+ m = ExternalMerger(self.agg, 10, partitions=3)
m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10))
self.assertTrue(m.spills >= 1)
self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)),
@@ -123,7 +123,7 @@ class TestMerger(unittest.TestCase):
m._cleanup()
-class TestSorter(unittest.TestCase):
+class SorterTests(unittest.TestCase):
def test_in_memory_sort(self):
l = range(1024)
random.shuffle(l)
@@ -244,16 +244,25 @@ class PySparkTestCase(unittest.TestCase):
sys.path = self._old_sys_path
-class TestCheckpoint(PySparkTestCase):
+class ReusedPySparkTestCase(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.sc = SparkContext('local[4]', cls.__name__, batchSize=2)
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.sc.stop()
+
+
+class CheckpointTests(ReusedPySparkTestCase):
def setUp(self):
- PySparkTestCase.setUp(self)
self.checkpointDir = tempfile.NamedTemporaryFile(delete=False)
os.unlink(self.checkpointDir.name)
self.sc.setCheckpointDir(self.checkpointDir.name)
def tearDown(self):
- PySparkTestCase.tearDown(self)
shutil.rmtree(self.checkpointDir.name)
def test_basic_checkpointing(self):
@@ -288,7 +297,7 @@ class TestCheckpoint(PySparkTestCase):
self.assertEquals([1, 2, 3, 4], recovered.collect())
-class TestAddFile(PySparkTestCase):
+class AddFileTests(PySparkTestCase):
def test_add_py_file(self):
# To ensure that we're actually testing addPyFile's effects, check that
@@ -354,7 +363,7 @@ class TestAddFile(PySparkTestCase):
self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect())
-class TestRDDFunctions(PySparkTestCase):
+class RDDTests(ReusedPySparkTestCase):
def test_id(self):
rdd = self.sc.parallelize(range(10))
@@ -365,12 +374,6 @@ class TestRDDFunctions(PySparkTestCase):
self.assertEqual(id + 1, id2)
self.assertEqual(id2, rdd2.id())
- def test_failed_sparkcontext_creation(self):
- # Regression test for SPARK-1550
- self.sc.stop()
- self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))
- self.sc = SparkContext("local")
-
def test_save_as_textfile_with_unicode(self):
# Regression test for SPARK-970
x = u"\u00A1Hola, mundo!"
@@ -636,7 +639,7 @@ class TestRDDFunctions(PySparkTestCase):
self.assertEquals(result.count(), 3)
-class TestProfiler(PySparkTestCase):
+class ProfilerTests(PySparkTestCase):
def setUp(self):
self._old_sys_path = list(sys.path)
@@ -666,10 +669,9 @@ class TestProfiler(PySparkTestCase):
self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))
-class TestSQL(PySparkTestCase):
+class SQLTests(ReusedPySparkTestCase):
def setUp(self):
- PySparkTestCase.setUp(self)
self.sqlCtx = SQLContext(self.sc)
def test_udf(self):
@@ -754,27 +756,19 @@ class TestSQL(PySparkTestCase):
self.assertEqual("2", row.d)
-class TestIO(PySparkTestCase):
-
- def test_stdout_redirection(self):
- import subprocess
-
- def func(x):
- subprocess.check_call('ls', shell=True)
- self.sc.parallelize([1]).foreach(func)
+class InputFormatTests(ReusedPySparkTestCase):
+ @classmethod
+ def setUpClass(cls):
+ ReusedPySparkTestCase.setUpClass()
+ cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+ os.unlink(cls.tempdir.name)
+ cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc)
-class TestInputFormat(PySparkTestCase):
-
- def setUp(self):
- PySparkTestCase.setUp(self)
- self.tempdir = tempfile.NamedTemporaryFile(delete=False)
- os.unlink(self.tempdir.name)
- self.sc._jvm.WriteInputFormatTestDataGenerator.generateData(self.tempdir.name, self.sc._jsc)
-
- def tearDown(self):
- PySparkTestCase.tearDown(self)
- shutil.rmtree(self.tempdir.name)
+ @classmethod
+ def tearDownClass(cls):
+ ReusedPySparkTestCase.tearDownClass()
+ shutil.rmtree(cls.tempdir.name)
def test_sequencefiles(self):
basepath = self.tempdir.name
@@ -954,15 +948,13 @@ class TestInputFormat(PySparkTestCase):
self.assertEqual(maps, em)
-class TestOutputFormat(PySparkTestCase):
+class OutputFormatTests(ReusedPySparkTestCase):
def setUp(self):
- PySparkTestCase.setUp(self)
self.tempdir = tempfile.NamedTemporaryFile(delete=False)
os.unlink(self.tempdir.name)
def tearDown(self):
- PySparkTestCase.tearDown(self)
shutil.rmtree(self.tempdir.name, ignore_errors=True)
def test_sequencefiles(self):
@@ -1243,8 +1235,7 @@ class TestOutputFormat(PySparkTestCase):
basepath + "/malformed/sequence"))
-class TestDaemon(unittest.TestCase):
-
+class DaemonTests(unittest.TestCase):
def connect(self, port):
from socket import socket, AF_INET, SOCK_STREAM
sock = socket(AF_INET, SOCK_STREAM)
@@ -1290,7 +1281,7 @@ class TestDaemon(unittest.TestCase):
self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
-class TestWorker(PySparkTestCase):
+class WorkerTests(PySparkTestCase):
def test_cancel_task(self):
temp = tempfile.NamedTemporaryFile(delete=True)
@@ -1342,11 +1333,6 @@ class TestWorker(PySparkTestCase):
rdd = self.sc.parallelize(range(100), 1)
self.assertEqual(100, rdd.map(str).count())
- def test_fd_leak(self):
- N = 1100 # fd limit is 1024 by default
- rdd = self.sc.parallelize(range(N), N)
- self.assertEquals(N, rdd.count())
-
def test_after_exception(self):
def raise_exception(_):
raise Exception()
@@ -1379,7 +1365,7 @@ class TestWorker(PySparkTestCase):
self.assertEqual(sum(range(100)), acc1.value)
-class TestSparkSubmit(unittest.TestCase):
+class SparkSubmitTests(unittest.TestCase):
def setUp(self):
self.programDir = tempfile.mkdtemp()
@@ -1492,6 +1478,8 @@ class TestSparkSubmit(unittest.TestCase):
|sc = SparkContext()
|print sc.parallelize([1, 2, 3]).map(foo).collect()
""")
+ # this will fail if you have different spark.executor.memory
+ # in conf/spark-defaults.conf
proc = subprocess.Popen(
[self.sparkSubmit, "--master", "local-cluster[1,1,512]", script],
stdout=subprocess.PIPE)
@@ -1500,7 +1488,11 @@ class TestSparkSubmit(unittest.TestCase):
self.assertIn("[2, 4, 6]", out)
-class ContextStopTests(unittest.TestCase):
+class ContextTests(unittest.TestCase):
+
+ def test_failed_sparkcontext_creation(self):
+ # Regression test for SPARK-1550
+ self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))
def test_stop(self):
sc = SparkContext()
diff --git a/python/run-tests b/python/run-tests
index a7ec270c7d..c713861eb7 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -34,7 +34,7 @@ rm -rf metastore warehouse
function run_test() {
echo "Running test: $1"
- SPARK_TESTING=1 "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log
+ SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log
FAILED=$((PIPESTATUS[0]||$FAILED))
@@ -48,6 +48,37 @@ function run_test() {
fi
}
+function run_core_tests() {
+ echo "Run core tests ..."
+ run_test "pyspark/rdd.py"
+ run_test "pyspark/context.py"
+ run_test "pyspark/conf.py"
+ PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py"
+ PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py"
+ PYSPARK_DOC_TEST=1 run_test "pyspark/serializers.py"
+ run_test "pyspark/shuffle.py"
+ run_test "pyspark/tests.py"
+}
+
+function run_sql_tests() {
+ echo "Run sql tests ..."
+ run_test "pyspark/sql.py"
+}
+
+function run_mllib_tests() {
+ echo "Run mllib tests ..."
+ run_test "pyspark/mllib/classification.py"
+ run_test "pyspark/mllib/clustering.py"
+ run_test "pyspark/mllib/linalg.py"
+ run_test "pyspark/mllib/random.py"
+ run_test "pyspark/mllib/recommendation.py"
+ run_test "pyspark/mllib/regression.py"
+ run_test "pyspark/mllib/stat.py"
+ run_test "pyspark/mllib/tree.py"
+ run_test "pyspark/mllib/util.py"
+ run_test "pyspark/mllib/tests.py"
+}
+
echo "Running PySpark tests. Output is in python/unit-tests.log."
export PYSPARK_PYTHON="python"
@@ -60,29 +91,9 @@ fi
echo "Testing with Python version:"
$PYSPARK_PYTHON --version
-run_test "pyspark/rdd.py"
-run_test "pyspark/context.py"
-run_test "pyspark/conf.py"
-run_test "pyspark/sql.py"
-# These tests are included in the module-level docs, and so must
-# be handled on a higher level rather than within the python file.
-export PYSPARK_DOC_TEST=1
-run_test "pyspark/broadcast.py"
-run_test "pyspark/accumulators.py"
-run_test "pyspark/serializers.py"
-unset PYSPARK_DOC_TEST
-run_test "pyspark/shuffle.py"
-run_test "pyspark/tests.py"
-run_test "pyspark/mllib/classification.py"
-run_test "pyspark/mllib/clustering.py"
-run_test "pyspark/mllib/linalg.py"
-run_test "pyspark/mllib/random.py"
-run_test "pyspark/mllib/recommendation.py"
-run_test "pyspark/mllib/regression.py"
-run_test "pyspark/mllib/stat.py"
-run_test "pyspark/mllib/tests.py"
-run_test "pyspark/mllib/tree.py"
-run_test "pyspark/mllib/util.py"
+run_core_tests
+run_sql_tests
+run_mllib_tests
# Try to test with PyPy
if [ $(which pypy) ]; then
@@ -90,19 +101,8 @@ if [ $(which pypy) ]; then
echo "Testing with PyPy version:"
$PYSPARK_PYTHON --version
- run_test "pyspark/rdd.py"
- run_test "pyspark/context.py"
- run_test "pyspark/conf.py"
- run_test "pyspark/sql.py"
- # These tests are included in the module-level docs, and so must
- # be handled on a higher level rather than within the python file.
- export PYSPARK_DOC_TEST=1
- run_test "pyspark/broadcast.py"
- run_test "pyspark/accumulators.py"
- run_test "pyspark/serializers.py"
- unset PYSPARK_DOC_TEST
- run_test "pyspark/shuffle.py"
- run_test "pyspark/tests.py"
+ run_core_tests
+ run_sql_tests
fi
if [[ $FAILED == 0 ]]; then