Diffstat (limited to 'python/pyspark/java_gateway.py')
 python/pyspark/java_gateway.py | 29 +++++++++++++++++++++++++++--
 1 file changed, 27 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 6bb6c877c9..032d960e40 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -24,10 +24,11 @@ from threading import Thread
 from py4j.java_gateway import java_import, JavaGateway, GatewayClient
 
-SPARK_HOME = os.environ["SPARK_HOME"]
 
+def launch_gateway():
+    SPARK_HOME = os.environ["SPARK_HOME"]
+    set_env_vars_for_yarn()
 
-def launch_gateway():
     # Launch the Py4j gateway using Spark's run command so that we pick up the
     # proper classpath and settings from spark-env.sh
     on_windows = platform.system() == "Windows"
@@ -70,3 +71,27 @@ def launch_gateway():
     java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext")
     java_import(gateway.jvm, "scala.Tuple2")
     return gateway
+
+def set_env_vars_for_yarn():
+    # Add the spark jar, which includes the pyspark files, to the python path
+    env_map = parse_env(os.environ.get("SPARK_YARN_USER_ENV", ""))
+    if "PYTHONPATH" in env_map:
+        env_map["PYTHONPATH"] += ":spark.jar"
+    else:
+        env_map["PYTHONPATH"] = "spark.jar"
+
+    os.environ["SPARK_YARN_USER_ENV"] = ",".join(k + '=' + v for (k, v) in env_map.items())
+
+def parse_env(env_str):
+    # Turns a comma-separated list of env settings into a dict that maps env
+    # vars to their values.
+    env = {}
+    for var_str in env_str.split(","):
+        parts = var_str.split("=")
+        if len(parts) == 2:
+            env[parts[0]] = parts[1]
+        elif len(var_str) > 0:
+            print "Invalid entry in SPARK_YARN_USER_ENV: " + var_str
+            sys.exit(1)
+
+    return env
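
For reference, a minimal self-contained sketch (not part of the commit) of the
round trip the two new helpers perform on SPARK_YARN_USER_ENV. The input value
and the final print are illustrative assumptions, and the parsing/serializing
logic is inlined rather than calling the functions above:

    import os

    # Assumed example input: SPARK_YARN_USER_ENV holds a comma-separated
    # list of KEY=VALUE pairs, the format parse_env() expects.
    os.environ["SPARK_YARN_USER_ENV"] = "PYTHONPATH=/opt/libs,SPARK_LOG_DIR=/tmp/logs"

    # parse_env(): split on commas, then on "=", building a dict.
    env_map = {}
    for var_str in os.environ["SPARK_YARN_USER_ENV"].split(","):
        key, value = var_str.split("=")
        env_map[key] = value

    # set_env_vars_for_yarn(): append spark.jar to PYTHONPATH (or create
    # the entry), then serialize the dict back into the same form.
    if "PYTHONPATH" in env_map:
        env_map["PYTHONPATH"] += ":spark.jar"
    else:
        env_map["PYTHONPATH"] = "spark.jar"
    os.environ["SPARK_YARN_USER_ENV"] = ",".join(
        k + "=" + v for (k, v) in env_map.items())

    print(os.environ["SPARK_YARN_USER_ENV"])
    # -> PYTHONPATH=/opt/libs:spark.jar,SPARK_LOG_DIR=/tmp/logs
    #    (pair order may vary; dict ordering is not guaranteed here)

Note that a value containing "=" or "," would not survive this round trip,
which is why parse_env() above rejects any entry that does not split into
exactly two parts and exits rather than silently mangling the environment.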