author    Davies Liu <davies@databricks.com>    2015-05-18 12:55:13 -0700
committer Josh Rosen <joshrosen@databricks.com>    2015-05-18 12:55:13 -0700
commit    32fbd297dd651ba3ce4ce52aeb0488233149cdf9 (patch)
tree      c18970db228448c856f8ce9538d708511e9e95fd
parent    9dadf019b93038e1e18336ccd06c5eecb4bae32f (diff)
[SPARK-6216] [PYSPARK] check python version of worker with driver
This PR reverts #5404. Instead, it passes the driver's Python version into the JVM, and the worker checks it before deserializing the closure, so the check works even when the driver and worker run different major versions of Python.

Author: Davies Liu <davies@databricks.com>

Closes #6203 from davies/py_version and squashes the following commits:

b8fb76e [Davies Liu] fix test
6ce5096 [Davies Liu] use string for version
47c6278 [Davies Liu] check python version of worker with driver
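The gist of the patch, as a minimal standalone sketch: the driver ships its version as a plain "major.minor" string, and the worker compares it against its own interpreter before unpickling the closure. The helper name check_python_version below is illustrative only and not part of the patch.

import sys

def check_python_version(driver_ver):
    # Mirrors the check added to python/pyspark/worker.py in this patch:
    # compare this interpreter's major.minor against the driver's string.
    worker_ver = "%d.%d" % sys.version_info[:2]
    if worker_ver != driver_ver:
        raise Exception(("Python in worker has different version %s than that in "
                         "driver %s, PySpark cannot run with different minor versions")
                        % (worker_ver, driver_ver))

# e.g. on a Python 2.7 worker:
#   check_python_version("2.7")  -> passes quietly
#   check_python_version("3.4")  -> raises before the closure is unpickled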
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala          |  3
-rw-r--r--  python/pyspark/context.py                                                 |  1
-rw-r--r--  python/pyspark/rdd.py                                                     |  4
-rw-r--r--  python/pyspark/sql/context.py                                             |  1
-rw-r--r--  python/pyspark/sql/functions.py                                           |  4
-rw-r--r--  python/pyspark/tests.py                                                   |  6
-rw-r--r--  python/pyspark/worker.py                                                  | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala       |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala   |  5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala  |  2
10 files changed, 26 insertions(+), 14 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 7409dc2d86..2d92f6a42b 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -47,6 +47,7 @@ private[spark] class PythonRDD(
pythonIncludes: JList[String],
preservePartitoning: Boolean,
pythonExec: String,
+ pythonVer: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: Accumulator[JList[Array[Byte]]])
extends RDD[Array[Byte]](parent) {
@@ -210,6 +211,8 @@ private[spark] class PythonRDD(
val dataOut = new DataOutputStream(stream)
// Partition index
dataOut.writeInt(split.index)
+ // Python version of driver
+ PythonRDD.writeUTF(pythonVer, dataOut)
// sparkFilesDir
PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
// Python includes (*.zip and *.egg files)
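On the JVM side, the version string is written into the worker's input stream right after the partition index. A rough sketch of what the worker reads back, assuming writeInt and PythonRDD.writeUTF produce a 4-byte big-endian integer and a length-prefixed UTF-8 string respectively (the helper names here are illustrative, not PySpark's actual deserializer API):

import struct

def read_int(stream):
    # DataOutputStream.writeInt emits a 4-byte big-endian integer
    return struct.unpack(">i", stream.read(4))[0]

def read_utf8(stream):
    # assumed layout of PythonRDD.writeUTF: 4-byte length, then UTF-8 bytes
    length = read_int(stream)
    return stream.read(length).decode("utf-8")

def read_handshake(stream):
    split_index = read_int(stream)   # partition index, written first
    driver_ver = read_utf8(stream)   # driver's "major.minor", added by this patch
    return split_index, driver_ver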
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 31992795a9..d25ee85523 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -173,6 +173,7 @@ class SparkContext(object):
self._jvm.PythonAccumulatorParam(host, port))
self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
+ self.pythonVer = "%d.%d" % sys.version_info[:2]
# Broadcast's __reduce__ method stores Broadcast instances here.
# This allows other code to determine which Broadcast instances have
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 545c5ad20c..70db4bbe4c 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2260,7 +2260,7 @@ class RDD(object):
def _prepare_for_python_RDD(sc, command, obj=None):
# the serialized command will be compressed by broadcast
ser = CloudPickleSerializer()
- pickled_command = ser.dumps((command, sys.version_info[:2]))
+ pickled_command = ser.dumps(command)
if len(pickled_command) > (1 << 20): # 1M
# The broadcast will have same life cycle as created PythonRDD
broadcast = sc.broadcast(pickled_command)
@@ -2344,7 +2344,7 @@ class PipelinedRDD(RDD):
python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
bytearray(pickled_cmd),
env, includes, self.preservesPartitioning,
- self.ctx.pythonExec,
+ self.ctx.pythonExec, self.ctx.pythonVer,
bvars, self.ctx._javaAccumulator)
self._jrdd_val = python_rdd.asJavaRDD()
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index f6f107ca32..0bde719124 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -157,6 +157,7 @@ class SQLContext(object):
env,
includes,
self._sc.pythonExec,
+ self._sc.pythonVer,
bvars,
self._sc._javaAccumulator,
returnType.json())
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 8d0e766ecd..fbe9bf5b52 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -353,8 +353,8 @@ class UserDefinedFunction(object):
ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
jdt = ssql_ctx.parseDataType(self.returnType.json())
fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
- judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env,
- includes, sc.pythonExec, broadcast_vars,
+ judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes,
+ sc.pythonExec, sc.pythonVer, broadcast_vars,
sc._javaAccumulator, jdt)
return judf
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 09de4d159f..5e023f6c53 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1543,13 +1543,13 @@ class WorkerTests(ReusedPySparkTestCase):
def test_with_different_versions_of_python(self):
rdd = self.sc.parallelize(range(10))
rdd.count()
- version = sys.version_info
- sys.version_info = (2, 0, 0)
+ version = self.sc.pythonVer
+ self.sc.pythonVer = "2.0"
try:
with QuietTest(self.sc):
self.assertRaises(Py4JJavaError, lambda: rdd.count())
finally:
- sys.version_info = version
+ self.sc.pythonVer = version
class SparkSubmitTests(unittest.TestCase):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index fbdaf3a581..93df9002be 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -57,6 +57,12 @@ def main(infile, outfile):
if split_index == -1: # for unit tests
exit(-1)
+ version = utf8_deserializer.loads(infile)
+ if version != "%d.%d" % sys.version_info[:2]:
+ raise Exception(("Python in worker has different version %s than that in " +
+ "driver %s, PySpark cannot run with different minor versions") %
+ ("%d.%d" % sys.version_info[:2], version))
+
# initialize global state
shuffle.MemoryBytesSpilled = 0
shuffle.DiskBytesSpilled = 0
@@ -92,11 +98,7 @@ def main(infile, outfile):
command = pickleSer._read_with_length(infile)
if isinstance(command, Broadcast):
command = pickleSer.loads(command.value)
- (func, profiler, deserializer, serializer), version = command
- if version != sys.version_info[:2]:
- raise Exception(("Python in worker has different version %s than that in " +
- "driver %s, PySpark cannot run with different minor versions") %
- (sys.version_info[:2], version))
+ func, profiler, deserializer, serializer = command
init_time = time.time()
def process():
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index dc3389c41b..3cc5c2441d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -46,6 +46,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
envVars: JMap[String, String],
pythonIncludes: JList[String],
pythonExec: String,
+ pythonVer: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: Accumulator[JList[Array[Byte]]],
stringDataType: String): Unit = {
@@ -70,6 +71,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
envVars,
pythonIncludes,
pythonExec,
+ pythonVer,
broadcastVars,
accumulator,
dataType,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
index 505ab1301e..a02e202d2e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
@@ -58,14 +58,15 @@ private[sql] case class UserDefinedPythonFunction(
envVars: JMap[String, String],
pythonIncludes: JList[String],
pythonExec: String,
+ pythonVer: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: Accumulator[JList[Array[Byte]]],
dataType: DataType) {
/** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
def apply(exprs: Column*): Column = {
- val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, broadcastVars,
- accumulator, dataType, exprs.map(_.expr))
+ val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer,
+ broadcastVars, accumulator, dataType, exprs.map(_.expr))
Column(udf)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 65dd7ba020..11b2897f76 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -46,6 +46,7 @@ private[spark] case class PythonUDF(
envVars: JMap[String, String],
pythonIncludes: JList[String],
pythonExec: String,
+ pythonVer: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: Accumulator[JList[Array[Byte]]],
dataType: DataType,
@@ -251,6 +252,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
udf.pythonIncludes,
false,
udf.pythonExec,
+ udf.pythonVer,
udf.broadcastVars,
udf.accumulator
).mapPartitions { iter =>
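In practice the mismatch surfaces on the driver as a Py4JJavaError wrapping the worker-side exception, which is what the updated test asserts. A hedged usage note: keeping driver and workers on the same interpreter is usually a matter of pointing both at the same binary before the SparkContext is created, e.g. via the PYSPARK_PYTHON environment variable that context.py reads; the interpreter path below is illustrative.

import os

# Set before creating the SparkContext; context.py picks up PYSPARK_PYTHON
# and passes it (together with pythonVer) to the JVM for launching workers.
os.environ["PYSPARK_PYTHON"] = "/usr/bin/python2.7"  # illustrative path

from pyspark import SparkContext

sc = SparkContext(appName="version-check-demo")
print(sc.pythonVer)                        # e.g. "2.7" -- the string shipped to every worker
print(sc.parallelize(range(10)).count())   # fails fast if a worker runs a different major.minor
sc.stop()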