-rwxr-xr-x  dev/run-tests                           2
-rwxr-xr-x  dev/run-tests.py                       24
-rw-r--r--  dev/sparktestsupport/shellutils.py      1
-rw-r--r--  python/pyspark/java_gateway.py          2
-rwxr-xr-x  python/run-tests.py                    97
5 files changed, 101 insertions, 25 deletions
diff --git a/dev/run-tests b/dev/run-tests
index a00d9f0c27..257d1e8d50 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -20,4 +20,4 @@
FWDIR="$(cd "`dirname $0`"/..; pwd)"
cd "$FWDIR"
-exec python -u ./dev/run-tests.py
+exec python -u ./dev/run-tests.py "$@"
diff --git a/dev/run-tests.py b/dev/run-tests.py
index e5c897b94d..4596e07014 100755
--- a/dev/run-tests.py
+++ b/dev/run-tests.py
@@ -19,6 +19,7 @@
from __future__ import print_function
import itertools
+from optparse import OptionParser
import os
import re
import sys
@@ -360,12 +361,13 @@ def run_scala_tests(build_tool, hadoop_version, test_modules):
run_scala_tests_sbt(test_modules, test_profiles)
-def run_python_tests(test_modules):
+def run_python_tests(test_modules, parallelism):
set_title_and_block("Running PySpark tests", "BLOCK_PYSPARK_UNIT_TESTS")
command = [os.path.join(SPARK_HOME, "python", "run-tests")]
if test_modules != [modules.root]:
command.append("--modules=%s" % ','.join(m.name for m in test_modules))
+ command.append("--parallelism=%i" % parallelism)
run_cmd(command)
@@ -379,7 +381,25 @@ def run_sparkr_tests():
print("Ignoring SparkR tests as R was not found in PATH")
+def parse_opts():
+ parser = OptionParser(
+ prog="run-tests"
+ )
+ parser.add_option(
+ "-p", "--parallelism", type="int", default=4,
+ help="The number of suites to test in parallel (default %default)"
+ )
+
+ (opts, args) = parser.parse_args()
+ if args:
+ parser.error("Unsupported arguments: %s" % ' '.join(args))
+ if opts.parallelism < 1:
+ parser.error("Parallelism cannot be less than 1")
+ return opts
+
+
def main():
+ opts = parse_opts()
# Ensure the user home directory (HOME) is valid and is an absolute directory
if not USER_HOME or not os.path.isabs(USER_HOME):
print("[error] Cannot determine your home directory as an absolute path;",
@@ -461,7 +481,7 @@ def main():
modules_with_python_tests = [m for m in test_modules if m.python_test_goals]
if modules_with_python_tests:
- run_python_tests(modules_with_python_tests)
+ run_python_tests(modules_with_python_tests, opts.parallelism)
if any(m.should_run_r_tests for m in test_modules):
run_sparkr_tests()
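
A minimal, self-contained sketch of the flow this file's changes introduce: OptionParser validates -p/--parallelism, and the value is appended to the python/run-tests command line. The SPARK_HOME fallback and the final print stand in for the real constant and run_cmd(); they are illustrative only.

    from __future__ import print_function
    from optparse import OptionParser
    import os

    SPARK_HOME = os.environ.get("SPARK_HOME", ".")  # stand-in for the real constant

    def parse_opts():
        parser = OptionParser(prog="run-tests")
        parser.add_option(
            "-p", "--parallelism", type="int", default=4,
            help="The number of suites to test in parallel (default %default)")
        (opts, args) = parser.parse_args()
        if args:
            parser.error("Unsupported arguments: %s" % ' '.join(args))
        if opts.parallelism < 1:
            parser.error("Parallelism cannot be less than 1")
        return opts

    if __name__ == "__main__":
        opts = parse_opts()
        # run_python_tests() ultimately shells out to a command shaped like this:
        command = [os.path.join(SPARK_HOME, "python", "run-tests"),
                   "--parallelism=%i" % opts.parallelism]
        print(command)
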
diff --git a/dev/sparktestsupport/shellutils.py b/dev/sparktestsupport/shellutils.py
index ad9b0cc89e..12bd0bf3a4 100644
--- a/dev/sparktestsupport/shellutils.py
+++ b/dev/sparktestsupport/shellutils.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
+from __future__ import print_function
import os
import shutil
import subprocess
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 3cee4ea6e3..90cd342a6c 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -51,6 +51,8 @@ def launch_gateway():
on_windows = platform.system() == "Windows"
script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
+ if os.environ.get("SPARK_TESTING"):
+ submit_args = "--conf spark.ui.enabled=false " + submit_args
command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args)
# Start a socket that will be used by PythonGatewayServer to communicate its port to us
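
The new SPARK_TESTING branch prepends a conf that turns off the Spark web UI, which keeps concurrently running test JVMs from each trying to bind a UI port. A hedged sketch of that logic as a standalone helper (the function name is hypothetical, not part of the patch):

    import os

    def build_submit_args(default="pyspark-shell"):
        # Mirrors launch_gateway()'s new behavior: when the test harness sets
        # SPARK_TESTING, disable the Spark web UI before the other submit args.
        submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", default)
        if os.environ.get("SPARK_TESTING"):
            submit_args = "--conf spark.ui.enabled=false " + submit_args
        return submit_args

    # e.g. with SPARK_TESTING=1 and PYSPARK_SUBMIT_ARGS unset, this yields:
    #   "--conf spark.ui.enabled=false pyspark-shell"
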
diff --git a/python/run-tests.py b/python/run-tests.py
index 7d485b500e..aaa35e936a 100755
--- a/python/run-tests.py
+++ b/python/run-tests.py
@@ -18,12 +18,19 @@
#
from __future__ import print_function
+import logging
from optparse import OptionParser
import os
import re
import subprocess
import sys
+import tempfile
+from threading import Thread, Lock
import time
+if sys.version < '3':
+ import Queue
+else:
+ import queue as Queue
# Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module
@@ -43,34 +50,44 @@ def print_red(text):
LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log")
+FAILURE_REPORTING_LOCK = Lock()
+LOGGER = logging.getLogger()
def run_individual_python_test(test_name, pyspark_python):
env = {'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)}
- print(" Running test: %s ..." % test_name, end='')
+ LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name)
start_time = time.time()
- with open(LOG_FILE, 'a') as log_file:
- retcode = subprocess.call(
- [os.path.join(SPARK_HOME, "bin/pyspark"), test_name],
- stderr=log_file, stdout=log_file, env=env)
+ per_test_output = tempfile.TemporaryFile()
+ retcode = subprocess.Popen(
+ [os.path.join(SPARK_HOME, "bin/pyspark"), test_name],
+ stderr=per_test_output, stdout=per_test_output, env=env).wait()
duration = time.time() - start_time
# Exit on the first failure.
if retcode != 0:
- with open(LOG_FILE, 'r') as log_file:
- for line in log_file:
+ with FAILURE_REPORTING_LOCK:
+ with open(LOG_FILE, 'ab') as log_file:
+ per_test_output.seek(0)
+ log_file.writelines(per_test_output.readlines())
+ per_test_output.seek(0)
+ for line in per_test_output:
if not re.match('[0-9]+', line):
print(line, end='')
- print_red("\nHad test failures in %s; see logs." % test_name)
- exit(-1)
+ per_test_output.close()
+ print_red("\nHad test failures in %s with %s; see logs." % (test_name, pyspark_python))
+ # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
+ # this code is invoked from a thread other than the main thread.
+ os._exit(-1)
else:
- print("ok (%is)" % duration)
+ per_test_output.close()
+ LOGGER.info("Finished test(%s): %s (%is)", pyspark_python, test_name, duration)
def get_default_python_executables():
python_execs = [x for x in ["python2.6", "python3.4", "pypy"] if which(x)]
if "python2.6" not in python_execs:
- print("WARNING: Not testing against `python2.6` because it could not be found; falling"
- " back to `python` instead")
+ LOGGER.warning("Not testing against `python2.6` because it could not be found; falling"
+ " back to `python` instead")
python_execs.insert(0, "python")
return python_execs
@@ -88,16 +105,31 @@ def parse_opts():
default=",".join(sorted(python_modules.keys())),
help="A comma-separated list of Python modules to test (default: %default)"
)
+ parser.add_option(
+ "-p", "--parallelism", type="int", default=4,
+ help="The number of suites to test in parallel (default %default)"
+ )
+ parser.add_option(
+ "--verbose", action="store_true",
+ help="Enable additional debug logging"
+ )
(opts, args) = parser.parse_args()
if args:
parser.error("Unsupported arguments: %s" % ' '.join(args))
+ if opts.parallelism < 1:
+ parser.error("Parallelism cannot be less than 1")
return opts
def main():
opts = parse_opts()
- print("Running PySpark tests. Output is in python/%s" % LOG_FILE)
+ if (opts.verbose):
+ log_level = logging.DEBUG
+ else:
+ log_level = logging.INFO
+ logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s")
+ LOGGER.info("Running PySpark tests. Output is in python/%s", LOG_FILE)
if os.path.exists(LOG_FILE):
os.remove(LOG_FILE)
python_execs = opts.python_executables.split(',')
@@ -108,24 +140,45 @@ def main():
else:
print("Error: unrecognized module %s" % module_name)
sys.exit(-1)
- print("Will test against the following Python executables: %s" % python_execs)
- print("Will test the following Python modules: %s" % [x.name for x in modules_to_test])
+ LOGGER.info("Will test against the following Python executables: %s", python_execs)
+ LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test])
- start_time = time.time()
+ task_queue = Queue.Queue()
for python_exec in python_execs:
python_implementation = subprocess.check_output(
[python_exec, "-c", "import platform; print(platform.python_implementation())"],
universal_newlines=True).strip()
- print("Testing with `%s`: " % python_exec, end='')
- subprocess.call([python_exec, "--version"])
-
+ LOGGER.debug("%s python_implementation is %s", python_exec, python_implementation)
+ LOGGER.debug("%s version is: %s", python_exec, subprocess.check_output(
+ [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip())
for module in modules_to_test:
if python_implementation not in module.blacklisted_python_implementations:
- print("Running %s tests ..." % module.name)
for test_goal in module.python_test_goals:
- run_individual_python_test(test_goal, python_exec)
+ task_queue.put((python_exec, test_goal))
+
+ def process_queue(task_queue):
+ while True:
+ try:
+ (python_exec, test_goal) = task_queue.get_nowait()
+ except Queue.Empty:
+ break
+ try:
+ run_individual_python_test(test_goal, python_exec)
+ finally:
+ task_queue.task_done()
+
+ start_time = time.time()
+ for _ in range(opts.parallelism):
+ worker = Thread(target=process_queue, args=(task_queue,))
+ worker.daemon = True
+ worker.start()
+ try:
+ task_queue.join()
+ except (KeyboardInterrupt, SystemExit):
+ print_red("Exiting due to interrupt")
+ sys.exit(-1)
total_duration = time.time() - start_time
- print("Tests passed in %i seconds" % total_duration)
+ LOGGER.info("Tests passed in %i seconds", total_duration)
if __name__ == "__main__":