aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-04-10 14:04:53 -0700
committerJosh Rosen <joshrosen@databricks.com>2015-04-10 14:04:53 -0700
commit4740d6a158cb4d35408a95265c5b950b9e9628a3 (patch)
tree82ddbf26be5bc78e074c5b4034eaa3e7f0655f59 /python
parent0375134f42197f2e29aa865a513cda381f0a1445 (diff)
downloadspark-4740d6a158cb4d35408a95265c5b950b9e9628a3.tar.gz
spark-4740d6a158cb4d35408a95265c5b950b9e9628a3.tar.bz2
spark-4740d6a158cb4d35408a95265c5b950b9e9628a3.zip
[SPARK-6216] [PySpark] check the python version in worker
Author: Davies Liu <davies@databricks.com> Closes #5404 from davies/check_version and squashes the following commits: e559248 [Davies Liu] add tests ec33b5f [Davies Liu] check the python version in worker
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/rdd.py2
-rw-r--r--python/pyspark/tests.py16
-rw-r--r--python/pyspark/worker.py6
3 files changed, 22 insertions, 2 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index c8e54ed5c6..c9ac95d117 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2233,7 +2233,7 @@ class RDD(object):
def _prepare_for_python_RDD(sc, command, obj=None):
# the serialized command will be compressed by broadcast
ser = CloudPickleSerializer()
- pickled_command = ser.dumps(command)
+ pickled_command = ser.dumps((command, sys.version_info[:2]))
if len(pickled_command) > (1 << 20): # 1M
broadcast = sc.broadcast(pickled_command)
pickled_command = ser.dumps(broadcast)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 0e3721b55a..b938b9ce12 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -35,6 +35,8 @@ import itertools
import threading
import hashlib
+from py4j.protocol import Py4JJavaError
+
if sys.version_info[:2] <= (2, 6):
try:
import unittest2 as unittest
@@ -1494,6 +1496,20 @@ class WorkerTests(PySparkTestCase):
self.assertTrue(not t.isAlive())
self.assertEqual(100000, rdd.count())
+ def test_with_different_versions_of_python(self):
+ rdd = self.sc.parallelize(range(10))
+ rdd.count()
+ version = sys.version_info
+ sys.version_info = (2, 0, 0)
+ log4j = self.sc._jvm.org.apache.log4j
+ old_level = log4j.LogManager.getRootLogger().getLevel()
+ log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL)
+ try:
+ self.assertRaises(Py4JJavaError, lambda: rdd.count())
+ finally:
+ sys.version_info = version
+ log4j.LogManager.getRootLogger().setLevel(old_level)
+
class SparkSubmitTests(unittest.TestCase):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 8a93c320ec..452d6fabdc 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -88,7 +88,11 @@ def main(infile, outfile):
command = pickleSer._read_with_length(infile)
if isinstance(command, Broadcast):
command = pickleSer.loads(command.value)
- (func, profiler, deserializer, serializer) = command
+ (func, profiler, deserializer, serializer), version = command
+ if version != sys.version_info[:2]:
+ raise Exception(("Python in worker has different version %s than that in " +
+ "driver %s, PySpark cannot run with different minor versions") %
+ (sys.version_info[:2], version))
init_time = time.time()
def process():