aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorDoris Xin <doris.s.xin@gmail.com>2014-07-27 07:21:07 -0700
committerXiangrui Meng <meng@databricks.com>2014-07-27 07:21:07 -0700
commit3a69c72e5cbe270b76f6ab6a84a2e334e87cce8c (patch)
treeaf1d111d1e51099ca915a13ed786ec26752df147 /mllib/src
parentaaf2b735fddbebccd28012006ee4647af3b3624f (diff)
downloadspark-3a69c72e5cbe270b76f6ab6a84a2e334e87cce8c.tar.gz
spark-3a69c72e5cbe270b76f6ab6a84a2e334e87cce8c.tar.bz2
spark-3a69c72e5cbe270b76f6ab6a84a2e334e87cce8c.zip
[SPARK-2679] [MLLib] Ser/De for Double
Added a set of serializer/deserializer for Double in _common.py and PythonMLLibAPI in MLLib. Author: Doris Xin <doris.s.xin@gmail.com> Closes #1581 from dorx/doubleSerDe and squashes the following commits: 86a85b3 [Doris Xin] Merge branch 'master' into doubleSerDe 2bfe7a4 [Doris Xin] Removed magic byte ad4d0d9 [Doris Xin] removed a space in unit a9020bc [Doris Xin] units passed 7dad9af [Doris Xin] WIP
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala23
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala8
2 files changed, 31 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index c44173793b..954621ee8b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -54,6 +54,13 @@ class PythonMLLibAPI extends Serializable {
}
}
+ private[python] def deserializeDouble(bytes: Array[Byte], offset: Int = 0): Double = {
+ // Everything after `offset` must be exactly one 8-byte Double; it is read in native byte order.
+ val remaining = bytes.length - offset
+ require(remaining == 8, "Wrong size byte array for Double")
+ ByteBuffer.wrap(bytes, offset, remaining).order(ByteOrder.nativeOrder()).getDouble
+ }
+
private def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
val packetLength = bytes.length - offset
require(packetLength >= 5, "Byte array too short")
@@ -89,6 +96,22 @@ class PythonMLLibAPI extends Serializable {
Vectors.sparse(size, indices, values)
}
+ /**
+ * Returns an 8-byte array for the input Double.
+ *
+ * Note: we currently do not use a magic byte for double for storage efficiency.
+ * This should be reconsidered when we add Ser/De for other 8-byte types (e.g. Long), for safety.
+ * The corresponding deserializer, deserializeDouble, needs to be modified as well if the
+ * serialization scheme changes.
+ */
+ private[python] def serializeDouble(double: Double): Array[Byte] = {
+ // Heap buffer of exactly 8 bytes, written in native byte order to mirror deserializeDouble.
+ val buffer = ByteBuffer.allocate(8)
+ buffer.order(ByteOrder.nativeOrder())
+ buffer.putDouble(double)
+ buffer.array()
+ }
+
private def serializeDenseVector(doubles: Array[Double]): Array[Byte] = {
val len = doubles.length
val bytes = new Array[Byte](5 + 8 * len)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
index 642843f902..d94cfa2fce 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
@@ -57,4 +57,12 @@ class PythonMLLibAPISuite extends FunSuite {
assert(q.features === p.features)
}
}
+
+ test("double serialization") {
+ // Round-trip ordinary, zero and extreme finite values through the Double Ser/De pair.
+ List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue).foreach { original =>
+ val roundTripped = py.deserializeDouble(py.serializeDouble(original))
+ assert(original === roundTripped)
+ }
+ }
}