-rwxr-xr-x  dev/lint-python                      |   3
-rwxr-xr-x  dev/run-tests.py                     | 435
-rw-r--r--  dev/sparktestsupport/__init__.py     |  21
-rw-r--r--  dev/sparktestsupport/modules.py      | 385
-rw-r--r--  dev/sparktestsupport/shellutils.py   |  81
-rw-r--r--  python/pyspark/streaming/tests.py    |  16
-rw-r--r--  python/pyspark/tests.py              |   3
-rwxr-xr-x  python/run-tests                     | 164
-rwxr-xr-x  python/run-tests.py                  | 132
9 files changed, 700 insertions, 540 deletions
diff --git a/dev/lint-python b/dev/lint-python
index f50d149dc4..0c3586462c 100755
--- a/dev/lint-python
+++ b/dev/lint-python
@@ -19,7 +19,8 @@
SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")"
-PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/"
+PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/ ./dev/sparktestsupport"
+PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py"
PYTHON_LINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/python-lint-report.txt"
cd "$SPARK_ROOT_DIR"
diff --git a/dev/run-tests.py b/dev/run-tests.py
index e7c09b0f40..c51b0d3010 100755
--- a/dev/run-tests.py
+++ b/dev/run-tests.py
@@ -17,297 +17,23 @@
# limitations under the License.
#
+from __future__ import print_function
import itertools
import os
import re
import sys
-import shutil
import subprocess
from collections import namedtuple
-SPARK_HOME = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
-USER_HOME = os.environ.get("HOME")
-
+from sparktestsupport import SPARK_HOME, USER_HOME
+from sparktestsupport.shellutils import exit_from_command_with_retcode, run_cmd, rm_r, which
+import sparktestsupport.modules as modules
# -------------------------------------------------------------------------------------------------
-# Test module definitions and functions for traversing module dependency graph
+# Functions for traversing module dependency graph
# -------------------------------------------------------------------------------------------------
-all_modules = []
-
-
-class Module(object):
- """
- A module is the basic abstraction in our test runner script. Each module consists of a set of
- source files, a set of test commands, and a set of dependencies on other modules. We use modules
- to define a dependency graph that lets determine which tests to run based on which files have
- changed.
- """
-
- def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(),
- sbt_test_goals=(), should_run_python_tests=False, should_run_r_tests=False):
- """
- Define a new module.
-
- :param name: A short module name, for display in logging and error messages.
- :param dependencies: A set of dependencies for this module. This should only include direct
- dependencies; transitive dependencies are resolved automatically.
- :param source_file_regexes: a set of regexes that match source files belonging to this
- module. These regexes are applied by attempting to match at the beginning of the
- filename strings.
- :param build_profile_flags: A set of profile flags that should be passed to Maven or SBT in
- order to build and test this module (e.g. '-PprofileName').
- :param sbt_test_goals: A set of SBT test goals for testing this module.
- :param should_run_python_tests: If true, changes in this module will trigger Python tests.
- For now, this has the effect of causing _all_ Python tests to be run, although in the
- future this should be changed to run only a subset of the Python tests that depend
- on this module.
- :param should_run_r_tests: If true, changes in this module will trigger all R tests.
- """
- self.name = name
- self.dependencies = dependencies
- self.source_file_prefixes = source_file_regexes
- self.sbt_test_goals = sbt_test_goals
- self.build_profile_flags = build_profile_flags
- self.should_run_python_tests = should_run_python_tests
- self.should_run_r_tests = should_run_r_tests
-
- self.dependent_modules = set()
- for dep in dependencies:
- dep.dependent_modules.add(self)
- all_modules.append(self)
-
- def contains_file(self, filename):
- return any(re.match(p, filename) for p in self.source_file_prefixes)
-
-
-sql = Module(
- name="sql",
- dependencies=[],
- source_file_regexes=[
- "sql/(?!hive-thriftserver)",
- "bin/spark-sql",
- ],
- build_profile_flags=[
- "-Phive",
- ],
- sbt_test_goals=[
- "catalyst/test",
- "sql/test",
- "hive/test",
- ]
-)
-
-
-hive_thriftserver = Module(
- name="hive-thriftserver",
- dependencies=[sql],
- source_file_regexes=[
- "sql/hive-thriftserver",
- "sbin/start-thriftserver.sh",
- ],
- build_profile_flags=[
- "-Phive-thriftserver",
- ],
- sbt_test_goals=[
- "hive-thriftserver/test",
- ]
-)
-
-
-graphx = Module(
- name="graphx",
- dependencies=[],
- source_file_regexes=[
- "graphx/",
- ],
- sbt_test_goals=[
- "graphx/test"
- ]
-)
-
-
-streaming = Module(
- name="streaming",
- dependencies=[],
- source_file_regexes=[
- "streaming",
- ],
- sbt_test_goals=[
- "streaming/test",
- ]
-)
-
-
-streaming_kinesis_asl = Module(
- name="kinesis-asl",
- dependencies=[streaming],
- source_file_regexes=[
- "extras/kinesis-asl/",
- ],
- build_profile_flags=[
- "-Pkinesis-asl",
- ],
- sbt_test_goals=[
- "kinesis-asl/test",
- ]
-)
-
-
-streaming_zeromq = Module(
- name="streaming-zeromq",
- dependencies=[streaming],
- source_file_regexes=[
- "external/zeromq",
- ],
- sbt_test_goals=[
- "streaming-zeromq/test",
- ]
-)
-
-
-streaming_twitter = Module(
- name="streaming-twitter",
- dependencies=[streaming],
- source_file_regexes=[
- "external/twitter",
- ],
- sbt_test_goals=[
- "streaming-twitter/test",
- ]
-)
-
-
-streaming_mqtt = Module(
- name="streaming-mqtt",
- dependencies=[streaming],
- source_file_regexes=[
- "external/mqtt",
- ],
- sbt_test_goals=[
- "streaming-mqtt/test",
- ]
-)
-
-
-streaming_kafka = Module(
- name="streaming-kafka",
- dependencies=[streaming],
- source_file_regexes=[
- "external/kafka",
- "external/kafka-assembly",
- ],
- sbt_test_goals=[
- "streaming-kafka/test",
- ]
-)
-
-
-streaming_flume_sink = Module(
- name="streaming-flume-sink",
- dependencies=[streaming],
- source_file_regexes=[
- "external/flume-sink",
- ],
- sbt_test_goals=[
- "streaming-flume-sink/test",
- ]
-)
-
-
-streaming_flume = Module(
- name="streaming_flume",
- dependencies=[streaming],
- source_file_regexes=[
- "external/flume",
- ],
- sbt_test_goals=[
- "streaming-flume/test",
- ]
-)
-
-
-mllib = Module(
- name="mllib",
- dependencies=[streaming, sql],
- source_file_regexes=[
- "data/mllib/",
- "mllib/",
- ],
- sbt_test_goals=[
- "mllib/test",
- ]
-)
-
-
-examples = Module(
- name="examples",
- dependencies=[graphx, mllib, streaming, sql],
- source_file_regexes=[
- "examples/",
- ],
- sbt_test_goals=[
- "examples/test",
- ]
-)
-
-
-pyspark = Module(
- name="pyspark",
- dependencies=[mllib, streaming, streaming_kafka, sql],
- source_file_regexes=[
- "python/"
- ],
- should_run_python_tests=True
-)
-
-
-sparkr = Module(
- name="sparkr",
- dependencies=[sql, mllib],
- source_file_regexes=[
- "R/",
- ],
- should_run_r_tests=True
-)
-
-
-docs = Module(
- name="docs",
- dependencies=[],
- source_file_regexes=[
- "docs/",
- ]
-)
-
-
-ec2 = Module(
- name="ec2",
- dependencies=[],
- source_file_regexes=[
- "ec2/",
- ]
-)
-
-
-# The root module is a dummy module which is used to run all of the tests.
-# No other modules should directly depend on this module.
-root = Module(
- name="root",
- dependencies=[],
- source_file_regexes=[],
- # In order to run all of the tests, enable every test profile:
- build_profile_flags=
- list(set(itertools.chain.from_iterable(m.build_profile_flags for m in all_modules))),
- sbt_test_goals=[
- "test",
- ],
- should_run_python_tests=True,
- should_run_r_tests=True
-)
-
-
def determine_modules_for_files(filenames):
"""
Given a list of filenames, return the set of modules that contain those files.
@@ -315,19 +41,19 @@ def determine_modules_for_files(filenames):
file to belong to the 'root' module.
>>> sorted(x.name for x in determine_modules_for_files(["python/pyspark/a.py", "sql/test/foo"]))
- ['pyspark', 'sql']
+ ['pyspark-core', 'sql']
>>> [x.name for x in determine_modules_for_files(["file_not_matched_by_any_subproject"])]
['root']
"""
changed_modules = set()
for filename in filenames:
matched_at_least_one_module = False
- for module in all_modules:
+ for module in modules.all_modules:
if module.contains_file(filename):
changed_modules.add(module)
matched_at_least_one_module = True
if not matched_at_least_one_module:
- changed_modules.add(root)
+ changed_modules.add(modules.root)
return changed_modules
@@ -352,7 +78,8 @@ def identify_changed_files_from_git_commits(patch_sha, target_branch=None, targe
run_cmd(['git', 'fetch', 'origin', str(target_branch+':'+target_branch)])
else:
diff_target = target_ref
- raw_output = subprocess.check_output(['git', 'diff', '--name-only', patch_sha, diff_target])
+ raw_output = subprocess.check_output(['git', 'diff', '--name-only', patch_sha, diff_target],
+ universal_newlines=True)
# Remove any empty strings
return [f for f in raw_output.split('\n') if f]
@@ -362,18 +89,20 @@ def determine_modules_to_test(changed_modules):
Given a set of modules that have changed, compute the transitive closure of those modules'
dependent modules in order to determine the set of modules that should be tested.
- >>> sorted(x.name for x in determine_modules_to_test([root]))
+ >>> sorted(x.name for x in determine_modules_to_test([modules.root]))
['root']
- >>> sorted(x.name for x in determine_modules_to_test([graphx]))
+ >>> sorted(x.name for x in determine_modules_to_test([modules.graphx]))
['examples', 'graphx']
- >>> sorted(x.name for x in determine_modules_to_test([sql]))
- ['examples', 'hive-thriftserver', 'mllib', 'pyspark', 'sparkr', 'sql']
+ >>> x = sorted(x.name for x in determine_modules_to_test([modules.sql]))
+ >>> x # doctest: +NORMALIZE_WHITESPACE
+ ['examples', 'hive-thriftserver', 'mllib', 'pyspark-core', 'pyspark-ml', \
+ 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming', 'sparkr', 'sql']
"""
# If we're going to have to run all of the tests, then we can just short-circuit
# and return 'root'. No module depends on root, so if it appears then it will be
# in changed_modules.
- if root in changed_modules:
- return [root]
+ if modules.root in changed_modules:
+ return [modules.root]
modules_to_test = set()
for module in changed_modules:
modules_to_test = modules_to_test.union(determine_modules_to_test(module.dependent_modules))
@@ -398,60 +127,6 @@ def get_error_codes(err_code_file):
ERROR_CODES = get_error_codes(os.path.join(SPARK_HOME, "dev/run-tests-codes.sh"))
-def exit_from_command_with_retcode(cmd, retcode):
- print "[error] running", ' '.join(cmd), "; received return code", retcode
- sys.exit(int(os.environ.get("CURRENT_BLOCK", 255)))
-
-
-def rm_r(path):
- """Given an arbitrary path properly remove it with the correct python
- construct if it exists
- - from: http://stackoverflow.com/a/9559881"""
-
- if os.path.isdir(path):
- shutil.rmtree(path)
- elif os.path.exists(path):
- os.remove(path)
-
-
-def run_cmd(cmd):
- """Given a command as a list of arguments will attempt to execute the
- command from the determined SPARK_HOME directory and, on failure, print
- an error message"""
-
- if not isinstance(cmd, list):
- cmd = cmd.split()
- try:
- subprocess.check_call(cmd)
- except subprocess.CalledProcessError as e:
- exit_from_command_with_retcode(e.cmd, e.returncode)
-
-
-def is_exe(path):
- """Check if a given path is an executable file
- - from: http://stackoverflow.com/a/377028"""
-
- return os.path.isfile(path) and os.access(path, os.X_OK)
-
-
-def which(program):
- """Find and return the given program by its absolute path or 'None'
- - from: http://stackoverflow.com/a/377028"""
-
- fpath = os.path.split(program)[0]
-
- if fpath:
- if is_exe(program):
- return program
- else:
- for path in os.environ.get("PATH").split(os.pathsep):
- path = path.strip('"')
- exe_file = os.path.join(path, program)
- if is_exe(exe_file):
- return exe_file
- return None
-
-
def determine_java_executable():
"""Will return the path of the java executable that will be used by Spark's
tests or `None`"""
@@ -476,7 +151,8 @@ def determine_java_version(java_exe):
with accessors '.major', '.minor', '.patch', '.update'"""
raw_output = subprocess.check_output([java_exe, "-version"],
- stderr=subprocess.STDOUT)
+ stderr=subprocess.STDOUT,
+ universal_newlines=True)
raw_output_lines = raw_output.split('\n')
@@ -504,10 +180,10 @@ def set_title_and_block(title, err_block):
os.environ["CURRENT_BLOCK"] = ERROR_CODES[err_block]
line_str = '=' * 72
- print
- print line_str
- print title
- print line_str
+ print('')
+ print(line_str)
+ print(title)
+ print(line_str)
def run_apache_rat_checks():
@@ -534,8 +210,8 @@ def build_spark_documentation():
jekyll_bin = which("jekyll")
if not jekyll_bin:
- print "[error] Cannot find a version of `jekyll` on the system; please",
- print "install one and retry to build documentation."
+ print("[error] Cannot find a version of `jekyll` on the system; please"
+ " install one and retry to build documentation.")
sys.exit(int(os.environ.get("CURRENT_BLOCK", 255)))
else:
run_cmd([jekyll_bin, "build"])
@@ -571,7 +247,7 @@ def exec_sbt(sbt_args=()):
echo_proc.wait()
for line in iter(sbt_proc.stdout.readline, ''):
if not sbt_output_filter.match(line):
- print line,
+ print(line, end='')
retcode = sbt_proc.wait()
if retcode > 0:
@@ -594,33 +270,33 @@ def get_hadoop_profiles(hadoop_version):
if hadoop_version in sbt_maven_hadoop_profiles:
return sbt_maven_hadoop_profiles[hadoop_version]
else:
- print "[error] Could not find", hadoop_version, "in the list. Valid options",
- print "are", sbt_maven_hadoop_profiles.keys()
+ print("[error] Could not find", hadoop_version, "in the list. Valid options"
+ " are", sbt_maven_hadoop_profiles.keys())
sys.exit(int(os.environ.get("CURRENT_BLOCK", 255)))
def build_spark_maven(hadoop_version):
# Enable all of the profiles for the build:
- build_profiles = get_hadoop_profiles(hadoop_version) + root.build_profile_flags
+ build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags
mvn_goals = ["clean", "package", "-DskipTests"]
profiles_and_goals = build_profiles + mvn_goals
- print "[info] Building Spark (w/Hive 0.13.1) using Maven with these arguments:",
- print " ".join(profiles_and_goals)
+ print("[info] Building Spark (w/Hive 0.13.1) using Maven with these arguments: "
+ " ".join(profiles_and_goals))
exec_maven(profiles_and_goals)
def build_spark_sbt(hadoop_version):
# Enable all of the profiles for the build:
- build_profiles = get_hadoop_profiles(hadoop_version) + root.build_profile_flags
+ build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags
sbt_goals = ["package",
"assembly/assembly",
"streaming-kafka-assembly/assembly"]
profiles_and_goals = build_profiles + sbt_goals
- print "[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments:",
- print " ".join(profiles_and_goals)
+ print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: "
+ " ".join(profiles_and_goals))
exec_sbt(profiles_and_goals)
@@ -648,8 +324,8 @@ def run_scala_tests_maven(test_profiles):
mvn_test_goals = ["test", "--fail-at-end"]
profiles_and_goals = test_profiles + mvn_test_goals
- print "[info] Running Spark tests using Maven with these arguments:",
- print " ".join(profiles_and_goals)
+ print("[info] Running Spark tests using Maven with these arguments: "
+ " ".join(profiles_and_goals))
exec_maven(profiles_and_goals)
@@ -663,8 +339,8 @@ def run_scala_tests_sbt(test_modules, test_profiles):
profiles_and_goals = test_profiles + list(sbt_test_goals)
- print "[info] Running Spark tests using SBT with these arguments:",
- print " ".join(profiles_and_goals)
+ print("[info] Running Spark tests using SBT with these arguments: "
+ " ".join(profiles_and_goals))
exec_sbt(profiles_and_goals)
@@ -684,10 +360,13 @@ def run_scala_tests(build_tool, hadoop_version, test_modules):
run_scala_tests_sbt(test_modules, test_profiles)
-def run_python_tests():
+def run_python_tests(test_modules):
set_title_and_block("Running PySpark tests", "BLOCK_PYSPARK_UNIT_TESTS")
- run_cmd([os.path.join(SPARK_HOME, "python", "run-tests")])
+ command = [os.path.join(SPARK_HOME, "python", "run-tests")]
+ if test_modules != [modules.root]:
+ command.append("--modules=%s" % ','.join(m.name for m in modules))
+ run_cmd(command)
def run_sparkr_tests():
@@ -697,14 +376,14 @@ def run_sparkr_tests():
run_cmd([os.path.join(SPARK_HOME, "R", "install-dev.sh")])
run_cmd([os.path.join(SPARK_HOME, "R", "run-tests.sh")])
else:
- print "Ignoring SparkR tests as R was not found in PATH"
+ print("Ignoring SparkR tests as R was not found in PATH")
def main():
# Ensure the user home directory (HOME) is valid and is an absolute directory
if not USER_HOME or not os.path.isabs(USER_HOME):
- print "[error] Cannot determine your home directory as an absolute path;",
- print "ensure the $HOME environment variable is set properly."
+ print("[error] Cannot determine your home directory as an absolute path;"
+ " ensure the $HOME environment variable is set properly.")
sys.exit(1)
os.chdir(SPARK_HOME)
@@ -718,14 +397,14 @@ def main():
java_exe = determine_java_executable()
if not java_exe:
- print "[error] Cannot find a version of `java` on the system; please",
- print "install one and retry."
+ print("[error] Cannot find a version of `java` on the system; please"
+ " install one and retry.")
sys.exit(2)
java_version = determine_java_version(java_exe)
if java_version.minor < 8:
- print "[warn] Java 8 tests will not run because JDK version is < 1.8."
+ print("[warn] Java 8 tests will not run because JDK version is < 1.8.")
if os.environ.get("AMPLAB_JENKINS"):
# if we're on the Amplab Jenkins build servers setup variables
@@ -741,8 +420,8 @@ def main():
hadoop_version = "hadoop2.3"
test_env = "local"
- print "[info] Using build tool", build_tool, "with Hadoop profile", hadoop_version,
- print "under environment", test_env
+ print("[info] Using build tool", build_tool, "with Hadoop profile", hadoop_version,
+ "under environment", test_env)
changed_modules = None
changed_files = None
@@ -751,8 +430,9 @@ def main():
changed_files = identify_changed_files_from_git_commits("HEAD", target_branch=target_branch)
changed_modules = determine_modules_for_files(changed_files)
if not changed_modules:
- changed_modules = [root]
- print "[info] Found the following changed modules:", ", ".join(x.name for x in changed_modules)
+ changed_modules = [modules.root]
+ print("[info] Found the following changed modules:",
+ ", ".join(x.name for x in changed_modules))
test_modules = determine_modules_to_test(changed_modules)
@@ -779,8 +459,9 @@ def main():
# run the test suites
run_scala_tests(build_tool, hadoop_version, test_modules)
- if any(m.should_run_python_tests for m in test_modules):
- run_python_tests()
+ modules_with_python_tests = [m for m in test_modules if m.python_test_goals]
+ if modules_with_python_tests:
+ run_python_tests(modules_with_python_tests)
if any(m.should_run_r_tests for m in test_modules):
run_sparkr_tests()
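For illustration (module names taken from dev/sparktestsupport/modules.py below, paths assumed): with the run_python_tests() change above, a run in which only the pyspark-sql and pyspark-streaming modules changed would build a command list along the lines of

    [os.path.join(SPARK_HOME, "python", "run-tests"), "--modules=pyspark-sql,pyspark-streaming"]

while a change that maps to the root module omits the --modules flag, so the Python runner falls back to its default of testing every module.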
diff --git a/dev/sparktestsupport/__init__.py b/dev/sparktestsupport/__init__.py
new file mode 100644
index 0000000000..12696d98fb
--- /dev/null
+++ b/dev/sparktestsupport/__init__.py
@@ -0,0 +1,21 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+
+SPARK_HOME = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../"))
+USER_HOME = os.environ.get("HOME")
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
new file mode 100644
index 0000000000..efe3a897e9
--- /dev/null
+++ b/dev/sparktestsupport/modules.py
@@ -0,0 +1,385 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import itertools
+import re
+
+all_modules = []
+
+
+class Module(object):
+ """
+ A module is the basic abstraction in our test runner script. Each module consists of a set of
+ source files, a set of test commands, and a set of dependencies on other modules. We use modules
+    to define a dependency graph that determines which tests to run based on which files have
+ changed.
+ """
+
+ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(),
+ sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(),
+ should_run_r_tests=False):
+ """
+ Define a new module.
+
+ :param name: A short module name, for display in logging and error messages.
+ :param dependencies: A set of dependencies for this module. This should only include direct
+ dependencies; transitive dependencies are resolved automatically.
+ :param source_file_regexes: a set of regexes that match source files belonging to this
+ module. These regexes are applied by attempting to match at the beginning of the
+ filename strings.
+ :param build_profile_flags: A set of profile flags that should be passed to Maven or SBT in
+ order to build and test this module (e.g. '-PprofileName').
+ :param sbt_test_goals: A set of SBT test goals for testing this module.
+ :param python_test_goals: A set of Python test goals for testing this module.
+ :param blacklisted_python_implementations: A set of Python implementations that are not
+ supported by this module's Python components. The values in this set should match
+ strings returned by Python's `platform.python_implementation()`.
+ :param should_run_r_tests: If true, changes in this module will trigger all R tests.
+ """
+ self.name = name
+ self.dependencies = dependencies
+ self.source_file_prefixes = source_file_regexes
+ self.sbt_test_goals = sbt_test_goals
+ self.build_profile_flags = build_profile_flags
+ self.python_test_goals = python_test_goals
+ self.blacklisted_python_implementations = blacklisted_python_implementations
+ self.should_run_r_tests = should_run_r_tests
+
+ self.dependent_modules = set()
+ for dep in dependencies:
+ dep.dependent_modules.add(self)
+ all_modules.append(self)
+
+ def contains_file(self, filename):
+ return any(re.match(p, filename) for p in self.source_file_prefixes)
+
+
+sql = Module(
+ name="sql",
+ dependencies=[],
+ source_file_regexes=[
+ "sql/(?!hive-thriftserver)",
+ "bin/spark-sql",
+ ],
+ build_profile_flags=[
+ "-Phive",
+ ],
+ sbt_test_goals=[
+ "catalyst/test",
+ "sql/test",
+ "hive/test",
+ ]
+)
+
+
+hive_thriftserver = Module(
+ name="hive-thriftserver",
+ dependencies=[sql],
+ source_file_regexes=[
+ "sql/hive-thriftserver",
+ "sbin/start-thriftserver.sh",
+ ],
+ build_profile_flags=[
+ "-Phive-thriftserver",
+ ],
+ sbt_test_goals=[
+ "hive-thriftserver/test",
+ ]
+)
+
+
+graphx = Module(
+ name="graphx",
+ dependencies=[],
+ source_file_regexes=[
+ "graphx/",
+ ],
+ sbt_test_goals=[
+ "graphx/test"
+ ]
+)
+
+
+streaming = Module(
+ name="streaming",
+ dependencies=[],
+ source_file_regexes=[
+ "streaming",
+ ],
+ sbt_test_goals=[
+ "streaming/test",
+ ]
+)
+
+
+streaming_kinesis_asl = Module(
+ name="kinesis-asl",
+ dependencies=[streaming],
+ source_file_regexes=[
+ "extras/kinesis-asl/",
+ ],
+ build_profile_flags=[
+ "-Pkinesis-asl",
+ ],
+ sbt_test_goals=[
+ "kinesis-asl/test",
+ ]
+)
+
+
+streaming_zeromq = Module(
+ name="streaming-zeromq",
+ dependencies=[streaming],
+ source_file_regexes=[
+ "external/zeromq",
+ ],
+ sbt_test_goals=[
+ "streaming-zeromq/test",
+ ]
+)
+
+
+streaming_twitter = Module(
+ name="streaming-twitter",
+ dependencies=[streaming],
+ source_file_regexes=[
+ "external/twitter",
+ ],
+ sbt_test_goals=[
+ "streaming-twitter/test",
+ ]
+)
+
+
+streaming_mqtt = Module(
+ name="streaming-mqtt",
+ dependencies=[streaming],
+ source_file_regexes=[
+ "external/mqtt",
+ ],
+ sbt_test_goals=[
+ "streaming-mqtt/test",
+ ]
+)
+
+
+streaming_kafka = Module(
+ name="streaming-kafka",
+ dependencies=[streaming],
+ source_file_regexes=[
+ "external/kafka",
+ "external/kafka-assembly",
+ ],
+ sbt_test_goals=[
+ "streaming-kafka/test",
+ ]
+)
+
+
+streaming_flume_sink = Module(
+ name="streaming-flume-sink",
+ dependencies=[streaming],
+ source_file_regexes=[
+ "external/flume-sink",
+ ],
+ sbt_test_goals=[
+ "streaming-flume-sink/test",
+ ]
+)
+
+
+streaming_flume = Module(
+ name="streaming_flume",
+ dependencies=[streaming],
+ source_file_regexes=[
+ "external/flume",
+ ],
+ sbt_test_goals=[
+ "streaming-flume/test",
+ ]
+)
+
+
+mllib = Module(
+ name="mllib",
+ dependencies=[streaming, sql],
+ source_file_regexes=[
+ "data/mllib/",
+ "mllib/",
+ ],
+ sbt_test_goals=[
+ "mllib/test",
+ ]
+)
+
+
+examples = Module(
+ name="examples",
+ dependencies=[graphx, mllib, streaming, sql],
+ source_file_regexes=[
+ "examples/",
+ ],
+ sbt_test_goals=[
+ "examples/test",
+ ]
+)
+
+
+pyspark_core = Module(
+ name="pyspark-core",
+ dependencies=[mllib, streaming, streaming_kafka],
+ source_file_regexes=[
+ "python/(?!pyspark/(ml|mllib|sql|streaming))"
+ ],
+ python_test_goals=[
+ "pyspark.rdd",
+ "pyspark.context",
+ "pyspark.conf",
+ "pyspark.broadcast",
+ "pyspark.accumulators",
+ "pyspark.serializers",
+ "pyspark.profiler",
+ "pyspark.shuffle",
+ "pyspark.tests",
+ ]
+)
+
+
+pyspark_sql = Module(
+ name="pyspark-sql",
+ dependencies=[pyspark_core, sql],
+ source_file_regexes=[
+ "python/pyspark/sql"
+ ],
+ python_test_goals=[
+ "pyspark.sql.types",
+ "pyspark.sql.context",
+ "pyspark.sql.column",
+ "pyspark.sql.dataframe",
+ "pyspark.sql.group",
+ "pyspark.sql.functions",
+ "pyspark.sql.readwriter",
+ "pyspark.sql.window",
+ "pyspark.sql.tests",
+ ]
+)
+
+
+pyspark_streaming = Module(
+ name="pyspark-streaming",
+ dependencies=[pyspark_core, streaming, streaming_kafka],
+ source_file_regexes=[
+ "python/pyspark/streaming"
+ ],
+ python_test_goals=[
+ "pyspark.streaming.util",
+ "pyspark.streaming.tests",
+ ]
+)
+
+
+pyspark_mllib = Module(
+ name="pyspark-mllib",
+ dependencies=[pyspark_core, pyspark_streaming, pyspark_sql, mllib],
+ source_file_regexes=[
+ "python/pyspark/mllib"
+ ],
+ python_test_goals=[
+ "pyspark.mllib.classification",
+ "pyspark.mllib.clustering",
+ "pyspark.mllib.evaluation",
+ "pyspark.mllib.feature",
+ "pyspark.mllib.fpm",
+ "pyspark.mllib.linalg",
+ "pyspark.mllib.random",
+ "pyspark.mllib.recommendation",
+ "pyspark.mllib.regression",
+ "pyspark.mllib.stat._statistics",
+ "pyspark.mllib.stat.KernelDensity",
+ "pyspark.mllib.tree",
+ "pyspark.mllib.util",
+ "pyspark.mllib.tests",
+ ],
+ blacklisted_python_implementations=[
+ "PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there
+ ]
+)
+
+
+pyspark_ml = Module(
+ name="pyspark-ml",
+ dependencies=[pyspark_core, pyspark_mllib],
+ source_file_regexes=[
+ "python/pyspark/ml/"
+ ],
+ python_test_goals=[
+ "pyspark.ml.feature",
+ "pyspark.ml.classification",
+ "pyspark.ml.recommendation",
+ "pyspark.ml.regression",
+ "pyspark.ml.tuning",
+ "pyspark.ml.tests",
+ "pyspark.ml.evaluation",
+ ],
+ blacklisted_python_implementations=[
+ "PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there
+ ]
+)
+
+sparkr = Module(
+ name="sparkr",
+ dependencies=[sql, mllib],
+ source_file_regexes=[
+ "R/",
+ ],
+ should_run_r_tests=True
+)
+
+
+docs = Module(
+ name="docs",
+ dependencies=[],
+ source_file_regexes=[
+ "docs/",
+ ]
+)
+
+
+ec2 = Module(
+ name="ec2",
+ dependencies=[],
+ source_file_regexes=[
+ "ec2/",
+ ]
+)
+
+
+# The root module is a dummy module which is used to run all of the tests.
+# No other modules should directly depend on this module.
+root = Module(
+ name="root",
+ dependencies=[],
+ source_file_regexes=[],
+ # In order to run all of the tests, enable every test profile:
+ build_profile_flags=list(set(
+ itertools.chain.from_iterable(m.build_profile_flags for m in all_modules))),
+ sbt_test_goals=[
+ "test",
+ ],
+ python_test_goals=list(itertools.chain.from_iterable(m.python_test_goals for m in all_modules)),
+ should_run_r_tests=True
+)
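The registration loop in Module.__init__ above is what turns these declarations into a graph: every module adds itself to the dependent_modules set of each of its dependencies. Below is a minimal sketch of how dev/run-tests.py walks that graph, equivalent to the determine_modules_to_test() shown earlier in this diff; the function name here is ours, and it assumes dev/ is on the Python path, as python/run-tests.py arranges further down.

    import sparktestsupport.modules as modules

    def transitive_modules_to_test(changed_modules):
        # A change that maps to root short-circuits everything: run the full suite.
        if modules.root in changed_modules:
            return set([modules.root])
        to_test = set(changed_modules)
        for module in changed_modules:
            # dependent_modules is populated by Module.__init__ at import time.
            to_test |= transitive_modules_to_test(module.dependent_modules)
        return to_test

    # A file under python/pyspark/sql/ matches pyspark_sql, which pulls in
    # pyspark_mllib and pyspark_ml through their dependency declarations:
    print(sorted(m.name for m in transitive_modules_to_test(set([modules.pyspark_sql]))))
    # ['pyspark-ml', 'pyspark-mllib', 'pyspark-sql']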
diff --git a/dev/sparktestsupport/shellutils.py b/dev/sparktestsupport/shellutils.py
new file mode 100644
index 0000000000..ad9b0cc89e
--- /dev/null
+++ b/dev/sparktestsupport/shellutils.py
@@ -0,0 +1,81 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import shutil
+import subprocess
+import sys
+
+
+def exit_from_command_with_retcode(cmd, retcode):
+ print("[error] running", ' '.join(cmd), "; received return code", retcode)
+ sys.exit(int(os.environ.get("CURRENT_BLOCK", 255)))
+
+
+def rm_r(path):
+ """
+ Given an arbitrary path, properly remove it with the correct Python construct if it exists.
+ From: http://stackoverflow.com/a/9559881
+ """
+
+ if os.path.isdir(path):
+ shutil.rmtree(path)
+ elif os.path.exists(path):
+ os.remove(path)
+
+
+def run_cmd(cmd):
+ """
+    Given a command as a list of arguments, attempt to execute the command
+ and, on failure, print an error message and exit.
+ """
+
+ if not isinstance(cmd, list):
+ cmd = cmd.split()
+ try:
+ subprocess.check_call(cmd)
+ except subprocess.CalledProcessError as e:
+ exit_from_command_with_retcode(e.cmd, e.returncode)
+
+
+def is_exe(path):
+ """
+ Check if a given path is an executable file.
+ From: http://stackoverflow.com/a/377028
+ """
+
+ return os.path.isfile(path) and os.access(path, os.X_OK)
+
+
+def which(program):
+ """
+ Find and return the given program by its absolute path or 'None' if the program cannot be found.
+ From: http://stackoverflow.com/a/377028
+ """
+
+ fpath = os.path.split(program)[0]
+
+ if fpath:
+ if is_exe(program):
+ return program
+ else:
+ for path in os.environ.get("PATH").split(os.pathsep):
+ path = path.strip('"')
+ exe_file = os.path.join(path, program)
+ if is_exe(exe_file):
+ return exe_file
+ return None
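The helpers above are thin wrappers around shutil, subprocess and os; a typical (hypothetical) call pattern from the scripts in this patch looks like:

    from sparktestsupport.shellutils import rm_r, run_cmd, which

    rm_r("python/unit-tests.log")        # remove a file or directory, if it exists
    java_exe = which("java")             # resolve an executable on PATH, or None
    if java_exe:
        run_cmd([java_exe, "-version"])  # exits via exit_from_command_with_retcode on failure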
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 57049beea4..91ce681fbe 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
+import glob
import os
import sys
from itertools import chain
@@ -677,4 +678,19 @@ class KafkaStreamTests(PySparkStreamingTestCase):
self._validateRddResult(sendData, rdd)
if __name__ == "__main__":
+ SPARK_HOME = os.environ["SPARK_HOME"]
+ kafka_assembly_dir = os.path.join(SPARK_HOME, "external/kafka-assembly")
+ jars = glob.glob(
+ os.path.join(kafka_assembly_dir, "target/scala-*/spark-streaming-kafka-assembly-*.jar"))
+ if not jars:
+ raise Exception(
+ ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) +
+ "You need to build Spark with "
+ "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or "
+ "'build/mvn package' before running this test")
+ elif len(jars) > 1:
+ raise Exception(("Found multiple Spark Streaming Kafka assembly JARs in %s; please "
+ "remove all but one") % kafka_assembly_dir)
+ else:
+ os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars[0]
unittest.main()
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 7826542368..17256dfc95 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1421,7 +1421,8 @@ class DaemonTests(unittest.TestCase):
# start daemon
daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py")
- daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE)
+ python_exec = sys.executable or os.environ.get("PYSPARK_PYTHON")
+ daemon = Popen([python_exec, daemon_path], stdin=PIPE, stdout=PIPE)
# read the port number
port = read_int(daemon.stdout)
diff --git a/python/run-tests b/python/run-tests
index 4468fdb3f2..24949657ed 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -18,165 +18,7 @@
#
-# Figure out where the Spark framework is installed
-FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)"
+FWDIR="$(cd "`dirname $0`"/..; pwd)"
+cd "$FWDIR"
-. "$FWDIR"/bin/load-spark-env.sh
-
-# CD into the python directory to find things on the right path
-cd "$FWDIR/python"
-
-FAILED=0
-LOG_FILE=unit-tests.log
-START=$(date +"%s")
-
-rm -f $LOG_FILE
-
-# Remove the metastore and warehouse directory created by the HiveContext tests in Spark SQL
-rm -rf metastore warehouse
-
-function run_test() {
- echo -en "Running test: $1 ... " | tee -a $LOG_FILE
- start=$(date +"%s")
- SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 > $LOG_FILE 2>&1
-
- FAILED=$((PIPESTATUS[0]||$FAILED))
-
- # Fail and exit on the first test failure.
- if [[ $FAILED != 0 ]]; then
- cat $LOG_FILE | grep -v "^[0-9][0-9]*" # filter all lines starting with a number.
- echo -en "\033[31m" # Red
- echo "Had test failures; see logs."
- echo -en "\033[0m" # No color
- exit -1
- else
- now=$(date +"%s")
- echo "ok ($(($now - $start))s)"
- fi
-}
-
-function run_core_tests() {
- echo "Run core tests ..."
- run_test "pyspark.rdd"
- run_test "pyspark.context"
- run_test "pyspark.conf"
- run_test "pyspark.broadcast"
- run_test "pyspark.accumulators"
- run_test "pyspark.serializers"
- run_test "pyspark.profiler"
- run_test "pyspark.shuffle"
- run_test "pyspark.tests"
-}
-
-function run_sql_tests() {
- echo "Run sql tests ..."
- run_test "pyspark.sql.types"
- run_test "pyspark.sql.context"
- run_test "pyspark.sql.column"
- run_test "pyspark.sql.dataframe"
- run_test "pyspark.sql.group"
- run_test "pyspark.sql.functions"
- run_test "pyspark.sql.readwriter"
- run_test "pyspark.sql.window"
- run_test "pyspark.sql.tests"
-}
-
-function run_mllib_tests() {
- echo "Run mllib tests ..."
- run_test "pyspark.mllib.classification"
- run_test "pyspark.mllib.clustering"
- run_test "pyspark.mllib.evaluation"
- run_test "pyspark.mllib.feature"
- run_test "pyspark.mllib.fpm"
- run_test "pyspark.mllib.linalg"
- run_test "pyspark.mllib.random"
- run_test "pyspark.mllib.recommendation"
- run_test "pyspark.mllib.regression"
- run_test "pyspark.mllib.stat._statistics"
- run_test "pyspark.mllib.stat.KernelDensity"
- run_test "pyspark.mllib.tree"
- run_test "pyspark.mllib.util"
- run_test "pyspark.mllib.tests"
-}
-
-function run_ml_tests() {
- echo "Run ml tests ..."
- run_test "pyspark.ml.feature"
- run_test "pyspark.ml.classification"
- run_test "pyspark.ml.recommendation"
- run_test "pyspark.ml.regression"
- run_test "pyspark.ml.tuning"
- run_test "pyspark.ml.tests"
- run_test "pyspark.ml.evaluation"
-}
-
-function run_streaming_tests() {
- echo "Run streaming tests ..."
-
- KAFKA_ASSEMBLY_DIR="$FWDIR"/external/kafka-assembly
- JAR_PATH="${KAFKA_ASSEMBLY_DIR}/target/scala-${SPARK_SCALA_VERSION}"
- for f in "${JAR_PATH}"/spark-streaming-kafka-assembly-*.jar; do
- if [[ ! -e "$f" ]]; then
- echo "Failed to find Spark Streaming Kafka assembly jar in $KAFKA_ASSEMBLY_DIR" 1>&2
- echo "You need to build Spark with " \
- "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or" \
- "'build/mvn package' before running this program" 1>&2
- exit 1
- fi
- KAFKA_ASSEMBLY_JAR="$f"
- done
-
- export PYSPARK_SUBMIT_ARGS="--jars ${KAFKA_ASSEMBLY_JAR} pyspark-shell"
- run_test "pyspark.streaming.util"
- run_test "pyspark.streaming.tests"
-}
-
-echo "Running PySpark tests. Output is in python/$LOG_FILE."
-
-export PYSPARK_PYTHON="python"
-
-# Try to test with Python 2.6, since that's the minimum version that we support:
-if [ $(which python2.6) ]; then
- export PYSPARK_PYTHON="python2.6"
-fi
-
-echo "Testing with Python version:"
-$PYSPARK_PYTHON --version
-
-run_core_tests
-run_sql_tests
-run_mllib_tests
-run_ml_tests
-run_streaming_tests
-
-# Try to test with Python 3
-if [ $(which python3.4) ]; then
- export PYSPARK_PYTHON="python3.4"
- echo "Testing with Python3.4 version:"
- $PYSPARK_PYTHON --version
-
- run_core_tests
- run_sql_tests
- run_mllib_tests
- run_ml_tests
- run_streaming_tests
-fi
-
-# Try to test with PyPy
-if [ $(which pypy) ]; then
- export PYSPARK_PYTHON="pypy"
- echo "Testing with PyPy version:"
- $PYSPARK_PYTHON --version
-
- run_core_tests
- run_sql_tests
- run_streaming_tests
-fi
-
-if [[ $FAILED == 0 ]]; then
- now=$(date +"%s")
- echo -e "\033[32mTests passed \033[0min $(($now - $START)) seconds"
-fi
-
-# TODO: in the long-run, it would be nice to use a test runner like `nose`.
-# The doctest fixtures are the current barrier to doing this.
+exec python -u ./python/run-tests.py "$@"
diff --git a/python/run-tests.py b/python/run-tests.py
new file mode 100755
index 0000000000..7d485b500e
--- /dev/null
+++ b/python/run-tests.py
@@ -0,0 +1,132 @@
+#!/usr/bin/env python
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+from optparse import OptionParser
+import os
+import re
+import subprocess
+import sys
+import time
+
+
+# Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../dev/"))
+
+
+from sparktestsupport import SPARK_HOME # noqa (suppress pep8 warnings)
+from sparktestsupport.shellutils import which # noqa
+from sparktestsupport.modules import all_modules # noqa
+
+
+python_modules = dict((m.name, m) for m in all_modules if m.python_test_goals if m.name != 'root')
+
+
+def print_red(text):
+ print('\033[31m' + text + '\033[0m')
+
+
+LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log")
+
+
+def run_individual_python_test(test_name, pyspark_python):
+    env = dict(os.environ)
+    env.update({'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)})
+ print(" Running test: %s ..." % test_name, end='')
+ start_time = time.time()
+ with open(LOG_FILE, 'a') as log_file:
+ retcode = subprocess.call(
+ [os.path.join(SPARK_HOME, "bin/pyspark"), test_name],
+ stderr=log_file, stdout=log_file, env=env)
+ duration = time.time() - start_time
+ # Exit on the first failure.
+ if retcode != 0:
+ with open(LOG_FILE, 'r') as log_file:
+ for line in log_file:
+ if not re.match('[0-9]+', line):
+ print(line, end='')
+ print_red("\nHad test failures in %s; see logs." % test_name)
+ exit(-1)
+ else:
+ print("ok (%is)" % duration)
+
+
+def get_default_python_executables():
+ python_execs = [x for x in ["python2.6", "python3.4", "pypy"] if which(x)]
+ if "python2.6" not in python_execs:
+ print("WARNING: Not testing against `python2.6` because it could not be found; falling"
+ " back to `python` instead")
+ python_execs.insert(0, "python")
+ return python_execs
+
+
+def parse_opts():
+ parser = OptionParser(
+ prog="run-tests"
+ )
+ parser.add_option(
+ "--python-executables", type="string", default=','.join(get_default_python_executables()),
+ help="A comma-separated list of Python executables to test against (default: %default)"
+ )
+ parser.add_option(
+ "--modules", type="string",
+ default=",".join(sorted(python_modules.keys())),
+ help="A comma-separated list of Python modules to test (default: %default)"
+ )
+
+ (opts, args) = parser.parse_args()
+ if args:
+ parser.error("Unsupported arguments: %s" % ' '.join(args))
+ return opts
+
+
+def main():
+ opts = parse_opts()
+ print("Running PySpark tests. Output is in python/%s" % LOG_FILE)
+ if os.path.exists(LOG_FILE):
+ os.remove(LOG_FILE)
+ python_execs = opts.python_executables.split(',')
+ modules_to_test = []
+ for module_name in opts.modules.split(','):
+ if module_name in python_modules:
+ modules_to_test.append(python_modules[module_name])
+ else:
+ print("Error: unrecognized module %s" % module_name)
+ sys.exit(-1)
+ print("Will test against the following Python executables: %s" % python_execs)
+ print("Will test the following Python modules: %s" % [x.name for x in modules_to_test])
+
+ start_time = time.time()
+ for python_exec in python_execs:
+ python_implementation = subprocess.check_output(
+ [python_exec, "-c", "import platform; print(platform.python_implementation())"],
+ universal_newlines=True).strip()
+ print("Testing with `%s`: " % python_exec, end='')
+ subprocess.call([python_exec, "--version"])
+
+ for module in modules_to_test:
+ if python_implementation not in module.blacklisted_python_implementations:
+ print("Running %s tests ..." % module.name)
+ for test_goal in module.python_test_goals:
+ run_individual_python_test(test_goal, python_exec)
+ total_duration = time.time() - start_time
+ print("Tests passed in %i seconds" % total_duration)
+
+
+if __name__ == "__main__":
+ main()
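Together with the bash wrapper above, which simply execs this script with its arguments, the runner can also be invoked directly; for example (flag names as defined in parse_opts(), module and interpreter values illustrative), ./python/run-tests.py --modules=pyspark-sql,pyspark-streaming --python-executables=python2.6,pypy restricts the run to the named modules and interpreters, while the defaults test every Python module against every supported interpreter found on PATH.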