about summary refs log tree commit diff
path: root/mllib
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-05-05 07:53:11 -0700
committerXiangrui Meng <meng@databricks.com>2015-05-05 07:53:11 -0700
commit5ab652cdb8bef10214edd079502a7f49017579aa (patch)
treec41047fb05c22525d383b758dac0304d53f982c1 /mllib
parentc6d1efba29a4235130024fee9f118e6b2cb89ce1 (diff)
downloadspark-5ab652cdb8bef10214edd079502a7f49017579aa.tar.gz
spark-5ab652cdb8bef10214edd079502a7f49017579aa.tar.bz2
spark-5ab652cdb8bef10214edd079502a7f49017579aa.zip
[SPARK-7202] [MLLIB] [PYSPARK] Add SparseMatrixPickler to SerDe
Utilities for pickling and unpickling SparseMatrices using SerDe Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #5775 from MechCoder/spark-7202 and squashes the following commits: 7e689dc [MechCoder] [SPARK-7202] Add SparseMatrixPickler to SerDe
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala56
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala12
2 files changed, 67 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 6237b64c8f..8e9a208d61 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -1015,6 +1015,61 @@ private[spark] object SerDe extends Serializable {
}
}
+ // Pickler for SparseMatrix
+ private[python] class SparseMatrixPickler extends BasePickler[SparseMatrix] {
+
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+ val s = obj.asInstanceOf[SparseMatrix]
+ val order = ByteOrder.nativeOrder()
+
+ val colPtrsBytes = new Array[Byte](4 * s.colPtrs.length)
+ val indicesBytes = new Array[Byte](4 * s.rowIndices.length)
+ val valuesBytes = new Array[Byte](8 * s.values.length)
+ val isTransposed = if (s.isTransposed) 1 else 0
+ ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().put(s.colPtrs)
+ ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().put(s.rowIndices)
+ ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().put(s.values)
+
+ out.write(Opcodes.MARK)
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(s.numRows))
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(s.numCols))
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(colPtrsBytes.length))
+ out.write(colPtrsBytes)
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(indicesBytes.length))
+ out.write(indicesBytes)
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(valuesBytes.length))
+ out.write(valuesBytes)
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(isTransposed))
+ out.write(Opcodes.TUPLE)
+ }
+
+ def construct(args: Array[Object]): Object = {
+ if (args.length != 6) {
+ throw new PickleException("should be 6")
+ }
+ val order = ByteOrder.nativeOrder()
+ val colPtrsBytes = getBytes(args(2))
+ val indicesBytes = getBytes(args(3))
+ val valuesBytes = getBytes(args(4))
+ val colPtrs = new Array[Int](colPtrsBytes.length / 4)
+ val rowIndices = new Array[Int](indicesBytes.length / 4)
+ val values = new Array[Double](valuesBytes.length / 8)
+ ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().get(colPtrs)
+ ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().get(rowIndices)
+ ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().get(values)
+ val isTransposed = args(5).asInstanceOf[Int] == 1
+ new SparseMatrix(
+ args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], colPtrs, rowIndices, values,
+ isTransposed)
+ }
+ }
+
// Pickler for SparseVector
private[python] class SparseVectorPickler extends BasePickler[SparseVector] {
@@ -1099,6 +1154,7 @@ private[spark] object SerDe extends Serializable {
if (!initialized) {
new DenseVectorPickler().register()
new DenseMatrixPickler().register()
+ new SparseMatrixPickler().register()
new SparseVectorPickler().register()
new LabeledPointPickler().register()
new RatingPickler().register()
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
index db8ed62fa4..a629dba8a4 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.api.python
import org.scalatest.FunSuite
-import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors}
+import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.recommendation.Rating
@@ -77,6 +77,16 @@ class PythonMLLibAPISuite extends FunSuite {
val emptyMatrix = Matrices.dense(0, 0, empty)
val ne = SerDe.loads(SerDe.dumps(emptyMatrix)).asInstanceOf[DenseMatrix]
assert(emptyMatrix == ne)
+
+ val sm = new SparseMatrix(3, 2, Array(0, 1, 3), Array(1, 0, 2), Array(0.9, 1.2, 3.4))
+ val nsm = SerDe.loads(SerDe.dumps(sm)).asInstanceOf[SparseMatrix]
+ assert(sm.toArray === nsm.toArray)
+
+ val smt = new SparseMatrix(
+ 3, 3, Array(0, 2, 3, 5), Array(0, 2, 1, 0, 2), Array(0.9, 1.2, 3.4, 5.7, 8.9),
+ isTransposed=true)
+ val nsmt = SerDe.loads(SerDe.dumps(smt)).asInstanceOf[SparseMatrix]
+ assert(smt.toArray === nsmt.toArray)
}
test("pickle rating") {