aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorHossein Falaki <falaki@gmail.com>2014-01-06 12:18:03 -0800
committerHossein Falaki <falaki@gmail.com>2014-01-06 12:18:03 -0800
commit11a93fb5a8fafa940db27b652e4c21f6713ed8d1 (patch)
treef5f7eba800acbd9f1d04bdf3ffe1cdcd0d3dfad5 /mllib
parent8d0c2f7399ebf7a38346a60cf84d7020c0b1dba1 (diff)
downloadspark-11a93fb5a8fafa940db27b652e4c21f6713ed8d1.tar.gz
spark-11a93fb5a8fafa940db27b652e4c21f6713ed8d1.tar.bz2
spark-11a93fb5a8fafa940db27b652e4c21f6713ed8d1.zip
Added serializing method for Rating object
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala20
1 files changed, 16 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index be2628fac5..2d8623392e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -197,6 +197,7 @@ class PythonMLLibAPI extends Serializable {
return ret
}
+ /** Unpack a Rating object from an array of bytes */
private def unpackRating(ratingBytes: Array[Byte]): Rating = {
val bb = ByteBuffer.wrap(ratingBytes)
bb.order(ByteOrder.nativeOrder())
@@ -206,6 +207,7 @@ class PythonMLLibAPI extends Serializable {
return new Rating(user, product, rating)
}
+ /** Unpack a tuple of Ints from an array of bytes */
private[spark] def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = {
val bb = ByteBuffer.wrap(tupleBytes)
bb.order(ByteOrder.nativeOrder())
@@ -214,13 +216,23 @@ class PythonMLLibAPI extends Serializable {
(v1, v2)
}
+ /**
+ * Serialize a Rating object into an array of bytes.
+ * It can be deserialized using RatingDeserializer().
+ *
+ * @param rate
+ * @return
+ */
private[spark] def serializeRating(rate: Rating): Array[Byte] = {
- val bytes = new Array[Byte](24)
+ val len = 3
+ val bytes = new Array[Byte](4 + 8 * len)
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
- bb.putDouble(rate.user.toDouble)
- bb.putDouble(rate.product.toDouble)
- bb.putDouble(rate.rating)
+ bb.putInt(len)
+ val db = bb.asDoubleBuffer()
+ db.put(rate.user.toDouble)
+ db.put(rate.product.toDouble)
+ db.put(rate.rating)
bytes
}