aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala23
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala8
-rw-r--r--python/pyspark/mllib/_common.py48
3 files changed, 76 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index c44173793b..954621ee8b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -54,6 +54,13 @@ class PythonMLLibAPI extends Serializable {
}
}
+  private[python] def deserializeDouble(bytes: Array[Byte], offset: Int = 0): Double = {
+    // Reads the 8 bytes starting at offset as a Double in the platform's native byte order.
+    val remaining = bytes.length - offset
+    require(remaining == 8, "Wrong size byte array for Double")
+    ByteBuffer.wrap(bytes, offset, remaining).order(ByteOrder.nativeOrder()).getDouble
+  }
+
private def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
val packetLength = bytes.length - offset
require(packetLength >= 5, "Byte array too short")
@@ -89,6 +96,22 @@ class PythonMLLibAPI extends Serializable {
Vectors.sparse(size, indices, values)
}
+  /**
+   * Serializes a Double into a freshly allocated 8-byte array in native byte order.
+   *
+   * Note: no magic byte is written for a double, for storage efficiency. This should be
+   * revisited once Ser/De is added for other 8-byte types (e.g. Long), for safety. If the
+   * serialization scheme changes, the corresponding deserializer, deserializeDouble, must be
+   * updated to match.
+   */
+  private[python] def serializeDouble(double: Double): Array[Byte] = {
+    // Heap buffer backed by a new 8-byte array; array() hands that backing array back.
+    val buffer = ByteBuffer.allocate(8)
+    buffer.order(ByteOrder.nativeOrder())
+    buffer.putDouble(double)
+    buffer.array()
+  }
+
private def serializeDenseVector(doubles: Array[Double]): Array[Byte] = {
val len = doubles.length
val bytes = new Array[Byte](5 + 8 * len)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
index 642843f902..d94cfa2fce 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
@@ -57,4 +57,12 @@ class PythonMLLibAPISuite extends FunSuite {
assert(q.features === p.features)
}
}
+
+  test("double serialization") {
+    // Each sample must survive a serialize/deserialize round trip unchanged.
+    List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue).foreach { x =>
+      val roundTripped = py.deserializeDouble(py.serializeDouble(x))
+      assert(x === roundTripped)
+    }
+  }
}
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index 43b491a971..8e3ad6b783 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -72,9 +72,9 @@ except:
# Python interpreter must agree on what endian the machine is.
-DENSE_VECTOR_MAGIC = 1
+DENSE_VECTOR_MAGIC = 1
SPARSE_VECTOR_MAGIC = 2
-DENSE_MATRIX_MAGIC = 3
+DENSE_MATRIX_MAGIC = 3
LABELED_POINT_MAGIC = 4
@@ -97,8 +97,28 @@ def _deserialize_numpy_array(shape, ba, offset, dtype=float64):
return ar.copy()
+def _serialize_double(d):
+    """
+    Pack one double (a Python float or a numpy.float64) into an 8-byte bytearray.
+
+    Raises TypeError for input of any other type.
+    """
+    if type(d) != float and type(d) != float64:
+        raise TypeError("_serialize_double called on non-float input")
+    ba = bytearray(8)
+    _copyto(float64(d), buffer=ba, offset=0, shape=[1], dtype=float64)
+    return ba
+
+
def _serialize_double_vector(v):
- """Serialize a double vector into a mutually understood format.
+ """
+ Serialize a double vector into a mutually understood format.
+
+ Note: we currently do not use a magic byte for double for storage
+ efficiency. This should be reconsidered when we add Ser/De for other
+ 8-byte types (e.g. Long), for safety. The corresponding deserializer,
+ _deserialize_double, needs to be modified as well if the serialization
+ scheme changes.
>>> x = array([1,2,3])
>>> y = _deserialize_double_vector(_serialize_double_vector(x))
@@ -148,6 +168,28 @@ def _serialize_sparse_vector(v):
return ba
+def _deserialize_double(ba, offset=0):
+    """Deserialize a double from a mutually understood format.
+
+    `ba` must be a bytearray with exactly 8 bytes remaining after `offset`;
+    those bytes are unpacked as one native-order double. A TypeError is
+    raised if `ba` is not a bytearray or the remaining length is not 8.
+
+    >>> import sys
+    >>> _deserialize_double(_serialize_double(123.0)) == 123.0
+    True
+    >>> _deserialize_double(_serialize_double(float64(0.0))) == 0.0
+    True
+    >>> x = sys.float_info.max
+    >>> _deserialize_double(_serialize_double(x)) == x
+    True
+    >>> y = float64(sys.float_info.max)
+    >>> _deserialize_double(_serialize_double(y)) == y
+    True
+    """
+    if type(ba) != bytearray:
+        raise TypeError("_deserialize_double called on a %s; wanted bytearray" % type(ba))
+    # Number of bytes available from offset onward; also reported in the error message.
+    nb = len(ba) - offset
+    if nb != 8:
+        raise TypeError("_deserialize_double called on a %d-byte array; wanted 8 bytes." % nb)
+    return struct.unpack("d", ba[offset:])[0]
+
+
def _deserialize_double_vector(ba, offset=0):
"""Deserialize a double vector from a mutually understood format.