aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/_common.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/_common.py')
-rw-r--r--python/pyspark/mllib/_common.py72
1 files changed, 51 insertions, 21 deletions
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index 802a27a8da..a411a5d591 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -22,6 +22,7 @@ from pyspark import SparkContext, RDD
from pyspark.mllib.linalg import SparseVector
from pyspark.serializers import Serializer
+
"""
Common utilities shared throughout MLlib, primarily for dealing with
different data types. These include:
@@ -147,7 +148,7 @@ def _serialize_sparse_vector(v):
return ba
-def _deserialize_double_vector(ba):
+def _deserialize_double_vector(ba, offset=0):
"""Deserialize a double vector from a mutually understood format.
>>> x = array([1.0, 2.0, 3.0, 4.0, -1.0, 0.0, -0.0])
@@ -160,43 +161,46 @@ def _deserialize_double_vector(ba):
if type(ba) != bytearray:
raise TypeError("_deserialize_double_vector called on a %s; "
"wanted bytearray" % type(ba))
- if len(ba) < 5:
+ nb = len(ba) - offset
+ if nb < 5:
raise TypeError("_deserialize_double_vector called on a %d-byte array, "
- "which is too short" % len(ba))
- if ba[0] == DENSE_VECTOR_MAGIC:
- return _deserialize_dense_vector(ba)
- elif ba[0] == SPARSE_VECTOR_MAGIC:
- return _deserialize_sparse_vector(ba)
+ "which is too short" % nb)
+ if ba[offset] == DENSE_VECTOR_MAGIC:
+ return _deserialize_dense_vector(ba, offset)
+ elif ba[offset] == SPARSE_VECTOR_MAGIC:
+ return _deserialize_sparse_vector(ba, offset)
else:
raise TypeError("_deserialize_double_vector called on bytearray "
"with wrong magic")
-def _deserialize_dense_vector(ba):
+def _deserialize_dense_vector(ba, offset=0):
"""Deserialize a dense vector into a numpy array."""
- if len(ba) < 5:
+ nb = len(ba) - offset
+ if nb < 5:
raise TypeError("_deserialize_dense_vector called on a %d-byte array, "
- "which is too short" % len(ba))
- length = ndarray(shape=[1], buffer=ba, offset=1, dtype=int32)[0]
- if len(ba) != 8 * length + 5:
+ "which is too short" % nb)
+ length = ndarray(shape=[1], buffer=ba, offset=offset + 1, dtype=int32)[0]
+ if nb < 8 * length + 5:
raise TypeError("_deserialize_dense_vector called on bytearray "
"with wrong length")
- return _deserialize_numpy_array([length], ba, 5)
+ return _deserialize_numpy_array([length], ba, offset + 5)
-def _deserialize_sparse_vector(ba):
+def _deserialize_sparse_vector(ba, offset=0):
"""Deserialize a sparse vector into a MLlib SparseVector object."""
- if len(ba) < 9:
+ nb = len(ba) - offset
+ if nb < 9:
raise TypeError("_deserialize_sparse_vector called on a %d-byte array, "
- "which is too short" % len(ba))
- header = ndarray(shape=[2], buffer=ba, offset=1, dtype=int32)
+ "which is too short" % nb)
+ header = ndarray(shape=[2], buffer=ba, offset=offset + 1, dtype=int32)
size = header[0]
nonzeros = header[1]
- if len(ba) != 9 + 12 * nonzeros:
+ if nb < 9 + 12 * nonzeros:
raise TypeError("_deserialize_sparse_vector called on bytearray "
"with wrong length")
- indices = _deserialize_numpy_array([nonzeros], ba, 9, dtype=int32)
- values = _deserialize_numpy_array([nonzeros], ba, 9 + 4 * nonzeros, dtype=float64)
+ indices = _deserialize_numpy_array([nonzeros], ba, offset + 9, dtype=int32)
+ values = _deserialize_numpy_array([nonzeros], ba, offset + 9 + 4 * nonzeros, dtype=float64)
return SparseVector(int(size), indices, values)
@@ -243,7 +247,23 @@ def _deserialize_double_matrix(ba):
def _serialize_labeled_point(p):
- """Serialize a LabeledPoint with a features vector of any type."""
+ """
+ Serialize a LabeledPoint with a features vector of any type.
+
+ >>> from pyspark.mllib.regression import LabeledPoint
+ >>> dp0 = LabeledPoint(0.5, array([1.0, 2.0, 3.0, 4.0, -1.0, 0.0, -0.0]))
+ >>> dp1 = _deserialize_labeled_point(_serialize_labeled_point(dp0))
+ >>> dp1.label == dp0.label
+ True
+ >>> array_equal(dp1.features, dp0.features)
+ True
+ >>> sp0 = LabeledPoint(0.0, SparseVector(4, [1, 3], [3.0, 5.5]))
+ >>> sp1 = _deserialize_labeled_point(_serialize_labeled_point(sp0))
+ >>> sp1.label == sp1.label
+ True
+ >>> sp1.features == sp0.features
+ True
+ """
from pyspark.mllib.regression import LabeledPoint
serialized_features = _serialize_double_vector(p.features)
header = bytearray(9)
@@ -252,6 +272,16 @@ def _serialize_labeled_point(p):
header_float[0] = p.label
return header + serialized_features
+def _deserialize_labeled_point(ba, offset=0):
+ """Deserialize a LabeledPoint from a mutually understood format."""
+ from pyspark.mllib.regression import LabeledPoint
+ if type(ba) != bytearray:
+ raise TypeError("Expecting a bytearray but got %s" % type(ba))
+ if ba[offset] != LABELED_POINT_MAGIC:
+ raise TypeError("Expecting magic number %d but got %d" % (LABELED_POINT_MAGIC, ba[0]))
+ label = ndarray(shape=[1], buffer=ba, offset=offset + 1, dtype=float64)[0]
+ features = _deserialize_double_vector(ba, offset + 9)
+ return LabeledPoint(label, features)
def _copyto(array, buffer, offset, shape, dtype):
"""