From 32fbd297dd651ba3ce4ce52aeb0488233149cdf9 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Mon, 18 May 2015 12:55:13 -0700
Subject: [SPARK-6216] [PYSPARK] check python version of worker with driver

This PR reverts #5404 and instead passes the driver's Python version into the
JVM; the worker checks it before deserializing the closure, so the check works
even when driver and worker run different major versions of Python.

Author: Davies Liu

Closes #6203 from davies/py_version and squashes the following commits:

b8fb76e [Davies Liu] fix test
6ce5096 [Davies Liu] use string for version
47c6278 [Davies Liu] check python version of worker with driver
---
 .../main/scala/org/apache/spark/api/python/PythonRDD.scala |  3 +++
 python/pyspark/context.py                                  |  1 +
 python/pyspark/rdd.py                                      |  4 ++--
 python/pyspark/sql/context.py                              |  1 +
 python/pyspark/sql/functions.py                            |  4 ++--
 python/pyspark/tests.py                                    |  6 +++---
 python/pyspark/worker.py                                   | 12 +++++++-----
 .../main/scala/org/apache/spark/sql/UDFRegistration.scala  |  2 ++
 .../scala/org/apache/spark/sql/UserDefinedFunction.scala   |  5 +++--
 .../scala/org/apache/spark/sql/execution/pythonUdfs.scala  |  2 ++
 10 files changed, 26 insertions(+), 14 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 7409dc2d86..2d92f6a42b 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -47,6 +47,7 @@ private[spark] class PythonRDD(
     pythonIncludes: JList[String],
     preservePartitoning: Boolean,
     pythonExec: String,
+    pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
     accumulator: Accumulator[JList[Array[Byte]]])
   extends RDD[Array[Byte]](parent) {
@@ -210,6 +211,8 @@ private[spark] class PythonRDD(
         val dataOut = new DataOutputStream(stream)
         // Partition index
         dataOut.writeInt(split.index)
+        // Python version of driver
+        PythonRDD.writeUTF(pythonVer, dataOut)
         // sparkFilesDir
         PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
         // Python includes (*.zip and *.egg files)
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 31992795a9..d25ee85523 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -173,6 +173,7 @@ class SparkContext(object):
             self._jvm.PythonAccumulatorParam(host, port))

         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
+        self.pythonVer = "%d.%d" % sys.version_info[:2]

         # Broadcast's __reduce__ method stores Broadcast instances here.
         # This allows other code to determine which Broadcast instances have
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 545c5ad20c..70db4bbe4c 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2260,7 +2260,7 @@ class RDD(object):
 def _prepare_for_python_RDD(sc, command, obj=None):
     # the serialized command will be compressed by broadcast
     ser = CloudPickleSerializer()
-    pickled_command = ser.dumps((command, sys.version_info[:2]))
+    pickled_command = ser.dumps(command)
     if len(pickled_command) > (1 << 20):  # 1M
         # The broadcast will have same life cycle as created PythonRDD
         broadcast = sc.broadcast(pickled_command)
@@ -2344,7 +2344,7 @@ class PipelinedRDD(RDD):
         python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
                                              bytearray(pickled_cmd),
                                              env, includes, self.preservesPartitioning,
-                                             self.ctx.pythonExec,
+                                             self.ctx.pythonExec, self.ctx.pythonVer,
                                              bvars, self.ctx._javaAccumulator)
         self._jrdd_val = python_rdd.asJavaRDD()
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index f6f107ca32..0bde719124 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -157,6 +157,7 @@ class SQLContext(object):
                                             env,
                                             includes,
                                             self._sc.pythonExec,
+                                            self._sc.pythonVer,
                                             bvars,
                                             self._sc._javaAccumulator,
                                             returnType.json())
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 8d0e766ecd..fbe9bf5b52 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -353,8 +353,8 @@ class UserDefinedFunction(object):
         ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
         jdt = ssql_ctx.parseDataType(self.returnType.json())
         fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
-        judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env,
-                                                 includes, sc.pythonExec, broadcast_vars,
+        judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes,
+                                                 sc.pythonExec, sc.pythonVer, broadcast_vars,
                                                  sc._javaAccumulator, jdt)
         return judf
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 09de4d159f..5e023f6c53 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1543,13 +1543,13 @@ class WorkerTests(ReusedPySparkTestCase):
     def test_with_different_versions_of_python(self):
         rdd = self.sc.parallelize(range(10))
         rdd.count()
-        version = sys.version_info
-        sys.version_info = (2, 0, 0)
+        version = self.sc.pythonVer
+        self.sc.pythonVer = "2.0"
         try:
             with QuietTest(self.sc):
                 self.assertRaises(Py4JJavaError, lambda: rdd.count())
         finally:
-            sys.version_info = version
+            self.sc.pythonVer = version


 class SparkSubmitTests(unittest.TestCase):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index fbdaf3a581..93df9002be 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -57,6 +57,12 @@ def main(infile, outfile):
         if split_index == -1:  # for unit tests
             exit(-1)

+        version = utf8_deserializer.loads(infile)
+        if version != "%d.%d" % sys.version_info[:2]:
+            raise Exception(("Python in worker has different version %s than that in " +
+                             "driver %s, PySpark cannot run with different minor versions") %
+                            ("%d.%d" % sys.version_info[:2], version))
+
         # initialize global state
         shuffle.MemoryBytesSpilled = 0
         shuffle.DiskBytesSpilled = 0
@@ -92,11 +98,7 @@ def main(infile, outfile):
         command = pickleSer._read_with_length(infile)
         if isinstance(command, Broadcast):
             command = pickleSer.loads(command.value)
-        (func, profiler, deserializer, serializer), version = command
-        if version != sys.version_info[:2]:
-            raise Exception(("Python in worker has different version %s than that in " +
-                             "driver %s, PySpark cannot run with different minor versions") %
-                            (sys.version_info[:2], version))
+        func, profiler, deserializer, serializer = command
         init_time = time.time()

         def process():
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index dc3389c41b..3cc5c2441d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -46,6 +46,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
       envVars: JMap[String, String],
       pythonIncludes: JList[String],
       pythonExec: String,
+      pythonVer: String,
       broadcastVars: JList[Broadcast[PythonBroadcast]],
       accumulator: Accumulator[JList[Array[Byte]]],
       stringDataType: String): Unit = {
@@ -70,6 +71,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
         envVars,
         pythonIncludes,
         pythonExec,
+        pythonVer,
         broadcastVars,
         accumulator,
         dataType,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
index 505ab1301e..a02e202d2e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
@@ -58,14 +58,15 @@ private[sql] case class UserDefinedPythonFunction(
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     pythonExec: String,
+    pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
     accumulator: Accumulator[JList[Array[Byte]]],
     dataType: DataType) {

   /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
   def apply(exprs: Column*): Column = {
-    val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, broadcastVars,
-      accumulator, dataType, exprs.map(_.expr))
+    val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer,
+      broadcastVars, accumulator, dataType, exprs.map(_.expr))
     Column(udf)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 65dd7ba020..11b2897f76 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -46,6 +46,7 @@ private[spark] case class PythonUDF(
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     pythonExec: String,
+    pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
     accumulator: Accumulator[JList[Array[Byte]]],
     dataType: DataType,
@@ -251,6 +252,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
       udf.pythonIncludes,
       false,
       udf.pythonExec,
+      udf.pythonVer,
       udf.broadcastVars,
       udf.accumulator
     ).mapPartitions { iter =>
--
cgit v1.2.3
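
The handshake this patch adds is small: the JVM side writes the driver's "major.minor"
Python version as a length-prefixed UTF-8 string at the top of the worker's input stream,
and worker.py reads it back and fails fast on a mismatch before deserializing the pickled
closure. The following standalone sketch (not part of the patch) illustrates that idea
with an in-memory stream; write_utf and read_utf are simplified, hypothetical stand-ins
for PythonRDD.writeUTF and PySpark's UTF8Deserializer, not the real APIs.

import io
import struct
import sys


def write_utf(value, stream):
    # Simplified stand-in for PythonRDD.writeUTF: 4-byte big-endian length, then UTF-8 bytes.
    data = value.encode("utf-8")
    stream.write(struct.pack(">i", len(data)))
    stream.write(data)


def read_utf(stream):
    # Simplified stand-in for UTF8Deserializer.loads: read the length prefix, then the payload.
    length = struct.unpack(">i", stream.read(4))[0]
    return stream.read(length).decode("utf-8")


def check_worker_version(infile):
    # Mirrors the check added to worker.py: compare the driver's version string
    # against the interpreter running this (worker) process.
    driver_ver = read_utf(infile)
    worker_ver = "%d.%d" % sys.version_info[:2]
    if driver_ver != worker_ver:
        raise Exception(
            "Python in worker has different version %s than that in "
            "driver %s, PySpark cannot run with different minor versions"
            % (worker_ver, driver_ver))
    return driver_ver


if __name__ == "__main__":
    # "Driver" side: write its version before any other payload.
    buf = io.BytesIO()
    write_utf("%d.%d" % sys.version_info[:2], buf)
    buf.seek(0)
    # "Worker" side: the check runs before the closure would be read.
    print("driver/worker versions match:", check_worker_version(buf))

Because the version travels as a plain UTF string rather than inside the pickled command
(the approach of #5404 that this reverts), the mismatch can be detected even when the two
interpreters' pickle formats are incompatible.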