author     Sandy Ryza <sandy@cloudera.com>        2014-04-29 23:24:34 -0700
committer  Patrick Wendell <pwendell@gmail.com>   2014-04-29 23:24:34 -0700
commit     ff5be9a41e52454e0f9cae83dd1fd50fbeaa684a (patch)
tree       5bd17eaa50b3120317370821104c9c6d9e238b56 /python
parent     7025dda8fa84b57d6f12bc770df2fa10eef21d88 (diff)
SPARK-1004. PySpark on YARN
This reopens https://github.com/apache/incubator-spark/pull/640 against the new repo

Author: Sandy Ryza <sandy@cloudera.com>

Closes #30 from sryza/sandy-spark-1004 and squashes the following commits:

89889d4 [Sandy Ryza] Move unzipping py4j to the generate-resources phase so that it gets included in the jar the first time
5165a02 [Sandy Ryza] Fix docs
fd0df79 [Sandy Ryza] PySpark on YARN
Diffstat (limited to 'python')
-rw-r--r--  python/.gitignore                3
-rw-r--r--  python/lib/PY4J_VERSION.txt      1
-rw-r--r--  python/pyspark/__init__.py       7
-rw-r--r--  python/pyspark/java_gateway.py  29
-rw-r--r--  python/pyspark/tests.py          4
5 files changed, 33 insertions(+), 11 deletions(-)
diff --git a/python/.gitignore b/python/.gitignore
index 5c56e638f9..80b361ffbd 100644
--- a/python/.gitignore
+++ b/python/.gitignore
@@ -1,2 +1,5 @@
*.pyc
docs/
+pyspark.egg-info
+build/
+dist/
diff --git a/python/lib/PY4J_VERSION.txt b/python/lib/PY4J_VERSION.txt
deleted file mode 100644
index 04a0cd52a8..0000000000
--- a/python/lib/PY4J_VERSION.txt
+++ /dev/null
@@ -1 +0,0 @@
-b7924aabe9c5e63f0a4d8bbd17019534c7ec014e
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 73fe7378ff..07df8697bd 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -49,13 +49,6 @@ Hive:
Main entry point for accessing data stored in Apache Hive..
"""
-
-
-import sys
-import os
-sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j-0.8.1-src.zip"))
-
-
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.sql import SQLContext
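The lines removed above had pushed the bundled Py4J sources onto sys.path at import time. As a point of reference, here is a minimal standalone sketch of that pattern, assuming SPARK_HOME points at a Spark installation and the bundled archive is named py4j-0.8.1-src.zip as in the removed code; with this patch the Py4J sources are instead expected to reach the interpreter by other means (for example bundled into the Spark jar, per the commit message):

    # Sketch only: Python can import packages directly from a .zip archive
    # placed on sys.path, which is what the removed lines relied on.
    import os
    import sys

    spark_home = os.environ["SPARK_HOME"]  # assumed to point at a Spark install
    py4j_zip = os.path.join(spark_home, "python/lib/py4j-0.8.1-src.zip")
    sys.path.insert(0, py4j_zip)

    import py4j  # resolvable straight from the zip archive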
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 6bb6c877c9..032d960e40 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -24,10 +24,11 @@ from threading import Thread
from py4j.java_gateway import java_import, JavaGateway, GatewayClient
-SPARK_HOME = os.environ["SPARK_HOME"]
+def launch_gateway():
+    SPARK_HOME = os.environ["SPARK_HOME"]
+    set_env_vars_for_yarn()
-def launch_gateway():
    # Launch the Py4j gateway using Spark's run command so that we pick up the
    # proper classpath and settings from spark-env.sh
    on_windows = platform.system() == "Windows"
@@ -70,3 +71,27 @@ def launch_gateway():
    java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext")
    java_import(gateway.jvm, "scala.Tuple2")
    return gateway
+
+def set_env_vars_for_yarn():
+    # Add the spark jar, which includes the pyspark files, to the python path
+    env_map = parse_env(os.environ.get("SPARK_YARN_USER_ENV", ""))
+    if "PYTHONPATH" in env_map:
+        env_map["PYTHONPATH"] += ":spark.jar"
+    else:
+        env_map["PYTHONPATH"] = "spark.jar"
+
+    os.environ["SPARK_YARN_USER_ENV"] = ",".join(k + '=' + v for (k, v) in env_map.items())
+
+def parse_env(env_str):
+    # Turns a comma-separated list of env settings into a dict that maps env vars to
+    # their values.
+    env = {}
+    for var_str in env_str.split(","):
+        parts = var_str.split("=")
+        if len(parts) == 2:
+            env[parts[0]] = parts[1]
+        elif len(var_str) > 0:
+            print "Invalid entry in SPARK_YARN_USER_ENV: " + var_str
+            sys.exit(1)
+
+    return env
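For illustration, a standalone sketch of the round trip that set_env_vars_for_yarn() and parse_env() above perform on SPARK_YARN_USER_ENV; the starting value is invented for the example, while the "spark.jar" suffix and the parsing rules mirror the patch:

    # Invented example value; real SPARK_YARN_USER_ENV contents depend on the job.
    env_str = "FOO=bar,PYTHONPATH=/opt/libs"

    # parse_env(): comma-separated "KEY=VALUE" pairs -> dict
    env_map = {}
    for var_str in env_str.split(","):
        parts = var_str.split("=")
        if len(parts) == 2:
            env_map[parts[0]] = parts[1]

    # set_env_vars_for_yarn(): append spark.jar (which bundles the PySpark
    # files) to the PYTHONPATH that YARN passes through to the executors.
    if "PYTHONPATH" in env_map:
        env_map["PYTHONPATH"] += ":spark.jar"
    else:
        env_map["PYTHONPATH"] = "spark.jar"

    result = ",".join(k + "=" + v for (k, v) in env_map.items())
    # e.g. "FOO=bar,PYTHONPATH=/opt/libs:spark.jar" (dict ordering may vary)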
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 527104587f..8cf9d9cf1b 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -30,10 +30,12 @@ import unittest
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
-from pyspark.java_gateway import SPARK_HOME
from pyspark.serializers import read_int
+SPARK_HOME = os.environ["SPARK_HOME"]
+
+
class PySparkTestCase(unittest.TestCase):
    def setUp(self):