aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/_common.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/_common.py')
-rw-r--r--python/pyspark/mllib/_common.py11
1 files changed, 10 insertions, 1 deletions
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index db341da85f..bb60d3d0c8 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -16,6 +16,7 @@
#
import struct
+import sys
import numpy
from numpy import ndarray, float64, int64, int32, array_equal, array
from pyspark import SparkContext, RDD
@@ -78,6 +79,14 @@ DENSE_MATRIX_MAGIC = 3
LABELED_POINT_MAGIC = 4
+# Workaround for SPARK-2954: before Python 2.7, struct.unpack couldn't unpack bytearray()s.
+if sys.version_info[:2] <= (2, 6):
+ def _unpack(fmt, string):
+ return struct.unpack(fmt, buffer(string))
+else:
+ _unpack = struct.unpack
+
+
def _deserialize_numpy_array(shape, ba, offset, dtype=float64):
"""
Deserialize a numpy array of the given type from an offset in
@@ -191,7 +200,7 @@ def _deserialize_double(ba, offset=0):
raise TypeError("_deserialize_double called on a %s; wanted bytearray" % type(ba))
if len(ba) - offset != 8:
raise TypeError("_deserialize_double called on a %d-byte array; wanted 8 bytes." % nb)
- return struct.unpack("d", ba[offset:])[0]
+ return _unpack("d", ba[offset:])[0]
def _deserialize_double_vector(ba, offset=0):