aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/java_gateway.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/java_gateway.py')
-rw-r--r--python/pyspark/java_gateway.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index e615c1e9b6..c15add5237 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -31,7 +31,7 @@ def launch_gateway():
# Launch the Py4j gateway using Spark's run command so that we pick up the
# proper classpath and SPARK_MEM settings from spark-env.sh
on_windows = platform.system() == "Windows"
- script = "spark-class.cmd" if on_windows else "spark-class"
+ script = "./bin/spark-class.cmd" if on_windows else "./bin/spark-class"
command = [os.path.join(SPARK_HOME, script), "py4j.GatewayServer",
"--die-on-broken-pipe", "0"]
if not on_windows:
@@ -60,7 +60,9 @@ def launch_gateway():
# Connect to the gateway
gateway = JavaGateway(GatewayClient(port=port), auto_convert=False)
# Import the classes used by PySpark
+ java_import(gateway.jvm, "org.apache.spark.SparkConf")
java_import(gateway.jvm, "org.apache.spark.api.java.*")
java_import(gateway.jvm, "org.apache.spark.api.python.*")
+ java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
java_import(gateway.jvm, "scala.Tuple2")
return gateway