Diffstat (limited to 'python')
 python/.gitignore              |  3 +++
 python/lib/PY4J_VERSION.txt    |  1 -
 python/pyspark/__init__.py     |  7 -------
 python/pyspark/java_gateway.py | 29 +++++++++++++++++++++++++++---
 python/pyspark/tests.py        |  4 +++-
 5 files changed, 33 insertions(+), 11 deletions(-)
diff --git a/python/.gitignore b/python/.gitignore
index 5c56e638f9..80b361ffbd 100644
--- a/python/.gitignore
+++ b/python/.gitignore
@@ -1,2 +1,5 @@
*.pyc
docs/
+pyspark.egg-info
+build/
+dist/
diff --git a/python/lib/PY4J_VERSION.txt b/python/lib/PY4J_VERSION.txt
deleted file mode 100644
index 04a0cd52a8..0000000000
--- a/python/lib/PY4J_VERSION.txt
+++ /dev/null
@@ -1 +0,0 @@
-b7924aabe9c5e63f0a4d8bbd17019534c7ec014e
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 73fe7378ff..07df8697bd 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -49,13 +49,6 @@ Hive:
Main entry point for accessing data stored in Apache Hive.
"""
-
-
-import sys
-import os
-sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j-0.8.1-src.zip"))
-
-
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.sql import SQLContext
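With the lines above removed, importing pyspark no longer pins the bundled py4j zip onto sys.path at import time. For reference, a minimal sketch of doing the same setup explicitly before importing py4j (illustrative only, not part of this commit; it assumes SPARK_HOME is set, exactly as the deleted code did):

    import os
    import sys

    # Reproduce the deleted behavior as an explicit, opt-in step.
    py4j_zip = os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j-0.8.1-src.zip")
    if py4j_zip not in sys.path:
        sys.path.insert(0, py4j_zip)

    import py4j  # resolves from the bundled zip once it is on sys.path
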
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 6bb6c877c9..032d960e40 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -24,10 +24,11 @@ from threading import Thread
from py4j.java_gateway import java_import, JavaGateway, GatewayClient
-SPARK_HOME = os.environ["SPARK_HOME"]
+def launch_gateway():
+    SPARK_HOME = os.environ["SPARK_HOME"]
+    set_env_vars_for_yarn()
-def launch_gateway():
    # Launch the Py4j gateway using Spark's run command so that we pick up the
    # proper classpath and settings from spark-env.sh
    on_windows = platform.system() == "Windows"
@@ -70,3 +71,27 @@ def launch_gateway():
    java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext")
    java_import(gateway.jvm, "scala.Tuple2")
    return gateway
+
+def set_env_vars_for_yarn():
+    # Add the spark jar, which includes the pyspark files, to the python path
+    env_map = parse_env(os.environ.get("SPARK_YARN_USER_ENV", ""))
+    if "PYTHONPATH" in env_map:
+        env_map["PYTHONPATH"] += ":spark.jar"
+    else:
+        env_map["PYTHONPATH"] = "spark.jar"
+
+    os.environ["SPARK_YARN_USER_ENV"] = ",".join(k + '=' + v for (k, v) in env_map.items())
+
+def parse_env(env_str):
+    # Turns a comma-separated list of env settings into a dict that maps env
+    # vars to their values.
+    env = {}
+    for var_str in env_str.split(","):
+        parts = var_str.split("=")
+        if len(parts) == 2:
+            env[parts[0]] = parts[1]
+        elif len(var_str) > 0:
+            print "Invalid entry in SPARK_YARN_USER_ENV: " + var_str
+            sys.exit(1)
+
+    return env
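To make the new SPARK_YARN_USER_ENV handling concrete, here is a small self-contained sketch of the round trip these helpers perform (illustration only: it re-implements the functions from the hunk above in Python 3 form, swapping the print/sys.exit error path for an exception so it runs standalone; the input value is made up):

    import os

    def parse_env(env_str):
        # "A=1,B=2" -> {"A": "1", "B": "2"}; an empty string yields {}.
        env = {}
        for var_str in env_str.split(","):
            parts = var_str.split("=")
            if len(parts) == 2:
                env[parts[0]] = parts[1]
            elif len(var_str) > 0:
                raise ValueError("Invalid entry in SPARK_YARN_USER_ENV: " + var_str)
        return env

    os.environ["SPARK_YARN_USER_ENV"] = "FOO=bar,PYTHONPATH=/opt/libs"  # made-up example value
    env_map = parse_env(os.environ.get("SPARK_YARN_USER_ENV", ""))
    if "PYTHONPATH" in env_map:
        env_map["PYTHONPATH"] += ":spark.jar"
    else:
        env_map["PYTHONPATH"] = "spark.jar"
    os.environ["SPARK_YARN_USER_ENV"] = ",".join(k + "=" + v for k, v in env_map.items())
    print(os.environ["SPARK_YARN_USER_ENV"])  # FOO=bar,PYTHONPATH=/opt/libs:spark.jar
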
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 527104587f..8cf9d9cf1b 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -30,10 +30,12 @@ import unittest
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
-from pyspark.java_gateway import SPARK_HOME
from pyspark.serializers import read_int
+SPARK_HOME = os.environ["SPARK_HOME"]
+
+
class PySparkTestCase(unittest.TestCase):
    def setUp(self):
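
Since the tests now read SPARK_HOME directly at import time, a bare KeyError is raised when the variable is unset. A hedged sketch of a friendlier guard (hypothetical, not part of this commit):

    import os
    import unittest

    SPARK_HOME = os.environ.get("SPARK_HOME")  # None instead of KeyError when unset

    class PySparkTestCase(unittest.TestCase):
        def setUp(self):
            if SPARK_HOME is None:
                # Skip rather than crash during test collection.
                self.skipTest("SPARK_HOME is not set")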