diff options
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/context.py | 6 | ||||
-rw-r--r-- | python/pyspark/java_gateway.py | 89 | ||||
-rw-r--r-- | python/pyspark/tests.py | 131 |
3 files changed, 168 insertions, 58 deletions
diff --git a/python/pyspark/context.py b/python/pyspark/context.py index c74dc5fd4f..c7dc85ea03 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -158,6 +158,12 @@ class SparkContext(object): for path in (pyFiles or []): self.addPyFile(path) + # Deploy code dependencies set by spark-submit; these will already have been added + # with SparkContext.addFile, so we just need to add them + for path in self._conf.get("spark.submit.pyFiles", "").split(","): + if path != "": + self._python_includes.append(os.path.basename(path)) + # Create a temporary directory inside spark.local.dir: local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf()) self._temp_dir = \ diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 032d960e40..3d0936fdca 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -27,39 +27,43 @@ from py4j.java_gateway import java_import, JavaGateway, GatewayClient def launch_gateway(): SPARK_HOME = os.environ["SPARK_HOME"] - set_env_vars_for_yarn() - - # Launch the Py4j gateway using Spark's run command so that we pick up the - # proper classpath and settings from spark-env.sh - on_windows = platform.system() == "Windows" - script = "./bin/spark-class.cmd" if on_windows else "./bin/spark-class" - command = [os.path.join(SPARK_HOME, script), "py4j.GatewayServer", - "--die-on-broken-pipe", "0"] - if not on_windows: - # Don't send ctrl-c / SIGINT to the Java gateway: - def preexec_func(): - signal.signal(signal.SIGINT, signal.SIG_IGN) - proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func) + gateway_port = -1 + if "PYSPARK_GATEWAY_PORT" in os.environ: + gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) else: - # preexec_fn not supported on Windows - proc = Popen(command, stdout=PIPE, stdin=PIPE) - # Determine which ephemeral port the server started on: - port = int(proc.stdout.readline()) - # Create a thread to echo output from the GatewayServer, which is required - # for Java log output to show up: - class EchoOutputThread(Thread): - def __init__(self, stream): - Thread.__init__(self) - self.daemon = True - self.stream = stream + # Launch the Py4j gateway using Spark's run command so that we pick up the + # proper classpath and settings from spark-env.sh + on_windows = platform.system() == "Windows" + script = "./bin/spark-class.cmd" if on_windows else "./bin/spark-class" + command = [os.path.join(SPARK_HOME, script), "py4j.GatewayServer", + "--die-on-broken-pipe", "0"] + if not on_windows: + # Don't send ctrl-c / SIGINT to the Java gateway: + def preexec_func(): + signal.signal(signal.SIGINT, signal.SIG_IGN) + proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func) + else: + # preexec_fn not supported on Windows + proc = Popen(command, stdout=PIPE, stdin=PIPE) + # Determine which ephemeral port the server started on: + gateway_port = int(proc.stdout.readline()) + # Create a thread to echo output from the GatewayServer, which is required + # for Java log output to show up: + class EchoOutputThread(Thread): + def __init__(self, stream): + Thread.__init__(self) + self.daemon = True + self.stream = stream + + def run(self): + while True: + line = self.stream.readline() + sys.stderr.write(line) + EchoOutputThread(proc.stdout).start() - def run(self): - while True: - line = self.stream.readline() - sys.stderr.write(line) - EchoOutputThread(proc.stdout).start() # Connect to the gateway - gateway = JavaGateway(GatewayClient(port=port), auto_convert=False) + gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False) + # Import the classes used by PySpark java_import(gateway.jvm, "org.apache.spark.SparkConf") java_import(gateway.jvm, "org.apache.spark.api.java.*") @@ -70,28 +74,5 @@ def launch_gateway(): java_import(gateway.jvm, "org.apache.spark.sql.hive.LocalHiveContext") java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext") java_import(gateway.jvm, "scala.Tuple2") - return gateway -def set_env_vars_for_yarn(): - # Add the spark jar, which includes the pyspark files, to the python path - env_map = parse_env(os.environ.get("SPARK_YARN_USER_ENV", "")) - if "PYTHONPATH" in env_map: - env_map["PYTHONPATH"] += ":spark.jar" - else: - env_map["PYTHONPATH"] = "spark.jar" - - os.environ["SPARK_YARN_USER_ENV"] = ",".join(k + '=' + v for (k, v) in env_map.items()) - -def parse_env(env_str): - # Turns a comma-separated of env settings into a dict that maps env vars to - # their values. - env = {} - for var_str in env_str.split(","): - parts = var_str.split("=") - if len(parts) == 2: - env[parts[0]] = parts[1] - elif len(var_str) > 0: - print "Invalid entry in SPARK_YARN_USER_ENV: " + var_str - sys.exit(1) - - return env + return gateway diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 8cf9d9cf1b..64f2eeb12b 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -22,11 +22,14 @@ individual modules. from fileinput import input from glob import glob import os +import re import shutil +import subprocess import sys -from tempfile import NamedTemporaryFile +import tempfile import time import unittest +import zipfile from pyspark.context import SparkContext from pyspark.files import SparkFiles @@ -55,7 +58,7 @@ class TestCheckpoint(PySparkTestCase): def setUp(self): PySparkTestCase.setUp(self) - self.checkpointDir = NamedTemporaryFile(delete=False) + self.checkpointDir = tempfile.NamedTemporaryFile(delete=False) os.unlink(self.checkpointDir.name) self.sc.setCheckpointDir(self.checkpointDir.name) @@ -148,7 +151,7 @@ class TestRDDFunctions(PySparkTestCase): # Regression test for SPARK-970 x = u"\u00A1Hola, mundo!" data = self.sc.parallelize([x]) - tempFile = NamedTemporaryFile(delete=True) + tempFile = tempfile.NamedTemporaryFile(delete=True) tempFile.close() data.saveAsTextFile(tempFile.name) raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*"))) @@ -172,7 +175,7 @@ class TestRDDFunctions(PySparkTestCase): def test_deleting_input_files(self): # Regression test for SPARK-1025 - tempFile = NamedTemporaryFile(delete=False) + tempFile = tempfile.NamedTemporaryFile(delete=False) tempFile.write("Hello World!") tempFile.close() data = self.sc.textFile(tempFile.name) @@ -236,5 +239,125 @@ class TestDaemon(unittest.TestCase): from signal import SIGTERM self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) + +class TestSparkSubmit(unittest.TestCase): + def setUp(self): + self.programDir = tempfile.mkdtemp() + self.sparkSubmit = os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit") + + def tearDown(self): + shutil.rmtree(self.programDir) + + def createTempFile(self, name, content): + """ + Create a temp file with the given name and content and return its path. + Strips leading spaces from content up to the first '|' in each line. + """ + pattern = re.compile(r'^ *\|', re.MULTILINE) + content = re.sub(pattern, '', content.strip()) + path = os.path.join(self.programDir, name) + with open(path, "w") as f: + f.write(content) + return path + + def createFileInZip(self, name, content): + """ + Create a zip archive containing a file with the given content and return its path. + Strips leading spaces from content up to the first '|' in each line. + """ + pattern = re.compile(r'^ *\|', re.MULTILINE) + content = re.sub(pattern, '', content.strip()) + path = os.path.join(self.programDir, name + ".zip") + with zipfile.ZipFile(path, 'w') as zip: + zip.writestr(name, content) + return path + + def test_single_script(self): + """Submit and test a single script file""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + | + |sc = SparkContext() + |print sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect() + """) + proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 4, 6]", out) + + def test_script_with_local_functions(self): + """Submit and test a single script file calling a global function""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + | + |def foo(x): + | return x * 3 + | + |sc = SparkContext() + |print sc.parallelize([1, 2, 3]).map(foo).collect() + """) + proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[3, 6, 9]", out) + + def test_module_dependency(self): + """Submit and test a script with a dependency on another module""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + |from mylib import myfunc + | + |sc = SparkContext() + |print sc.parallelize([1, 2, 3]).map(myfunc).collect() + """) + zip = self.createFileInZip("mylib.py", """ + |def myfunc(x): + | return x + 1 + """) + proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, script], + stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 3, 4]", out) + + def test_module_dependency_on_cluster(self): + """Submit and test a script with a dependency on another module on a cluster""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + |from mylib import myfunc + | + |sc = SparkContext() + |print sc.parallelize([1, 2, 3]).map(myfunc).collect() + """) + zip = self.createFileInZip("mylib.py", """ + |def myfunc(x): + | return x + 1 + """) + proc = subprocess.Popen( + [self.sparkSubmit, "--py-files", zip, "--master", "local-cluster[1,1,512]", script], + stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 3, 4]", out) + + def test_single_script_on_cluster(self): + """Submit and test a single script on a cluster""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + | + |def foo(x): + | return x * 2 + | + |sc = SparkContext() + |print sc.parallelize([1, 2, 3]).map(foo).collect() + """) + proc = subprocess.Popen( + [self.sparkSubmit, "--master", "local-cluster[1,1,512]", script], + stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 4, 6]", out) + + if __name__ == "__main__": unittest.main() |