aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/_common.py48
1 files changed, 45 insertions, 3 deletions
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index 43b491a971..8e3ad6b783 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -72,9 +72,9 @@ except:
# Python interpreter must agree on what endian the machine is.
-DENSE_VECTOR_MAGIC = 1
+DENSE_VECTOR_MAGIC = 1
SPARSE_VECTOR_MAGIC = 2
-DENSE_MATRIX_MAGIC = 3
+DENSE_MATRIX_MAGIC = 3
LABELED_POINT_MAGIC = 4
@@ -97,8 +97,28 @@ def _deserialize_numpy_array(shape, ba, offset, dtype=float64):
return ar.copy()
+def _serialize_double(d):
+ """
+ Serialize a double (float or numpy.float64) into a mutually understood format.
+ """
+ if type(d) == float or type(d) == float64:
+ d = float64(d)
+ ba = bytearray(8)
+ _copyto(d, buffer=ba, offset=0, shape=[1], dtype=float64)
+ return ba
+ else:
+ raise TypeError("_serialize_double called on non-float input")
+
+
def _serialize_double_vector(v):
- """Serialize a double vector into a mutually understood format.
+ """
+ Serialize a double vector into a mutually understood format.
+
+ Note: we currently do not use a magic byte for double for storage
+ efficiency. This should be reconsidered when we add Ser/De for other
+ 8-byte types (e.g. Long), for safety. The corresponding deserializer,
+ _deserialize_double, needs to be modified as well if the serialization
+ scheme changes.
>>> x = array([1,2,3])
>>> y = _deserialize_double_vector(_serialize_double_vector(x))
@@ -148,6 +168,28 @@ def _serialize_sparse_vector(v):
return ba
+def _deserialize_double(ba, offset=0):
+ """Deserialize a double from a mutually understood format.
+
+ >>> import sys
+ >>> _deserialize_double(_serialize_double(123.0)) == 123.0
+ True
+ >>> _deserialize_double(_serialize_double(float64(0.0))) == 0.0
+ True
+ >>> x = sys.float_info.max
+ >>> _deserialize_double(_serialize_double(sys.float_info.max)) == x
+ True
+ >>> y = float64(sys.float_info.max)
+ >>> _deserialize_double(_serialize_double(sys.float_info.max)) == y
+ True
+ """
+ if type(ba) != bytearray:
+ raise TypeError("_deserialize_double called on a %s; wanted bytearray" % type(ba))
+ if len(ba) - offset != 8:
+ raise TypeError("_deserialize_double called on a %d-byte array; wanted 8 bytes." % nb)
+ return struct.unpack("d", ba[offset:])[0]
+
+
def _deserialize_double_vector(ba, offset=0):
"""Deserialize a double vector from a mutually understood format.