aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala23
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala8
2 files changed, 31 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index c44173793b..954621ee8b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -54,6 +54,13 @@ class PythonMLLibAPI extends Serializable {
}
}
+ private[python] def deserializeDouble(bytes: Array[Byte], offset: Int = 0): Double = {
+ require(bytes.length - offset == 8, "Wrong size byte array for Double")
+ val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
+ bb.order(ByteOrder.nativeOrder())
+ bb.getDouble
+ }
+
private def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
val packetLength = bytes.length - offset
require(packetLength >= 5, "Byte array too short")
@@ -89,6 +96,22 @@ class PythonMLLibAPI extends Serializable {
Vectors.sparse(size, indices, values)
}
+ /**
+ * Returns an 8-byte array for the input Double.
+ *
+ * Note: we currently do not use a magic byte for double for storage efficiency.
+ * This should be reconsidered when we add Ser/De for other 8-byte types (e.g. Long), for safety.
+ * The corresponding deserializer, deserializeDouble, needs to be modified as well if the
+ * serialization scheme changes.
+ */
+ private[python] def serializeDouble(double: Double): Array[Byte] = {
+ val bytes = new Array[Byte](8)
+ val bb = ByteBuffer.wrap(bytes)
+ bb.order(ByteOrder.nativeOrder())
+ bb.putDouble(double)
+ bytes
+ }
+
private def serializeDenseVector(doubles: Array[Double]): Array[Byte] = {
val len = doubles.length
val bytes = new Array[Byte](5 + 8 * len)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
index 642843f902..d94cfa2fce 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
@@ -57,4 +57,12 @@ class PythonMLLibAPISuite extends FunSuite {
assert(q.features === p.features)
}
}
+
+ test("double serialization") {
+ for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue)) {
+ val bytes = py.serializeDouble(x)
+ val deser = py.deserializeDouble(bytes)
+ assert(x === deser)
+ }
+ }
}