aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala13
1 files changed, 9 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index f976d2f97b..6237b64c8f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -985,8 +985,10 @@ private[spark] object SerDe extends Serializable {
val m: DenseMatrix = obj.asInstanceOf[DenseMatrix]
val bytes = new Array[Byte](8 * m.values.size)
val order = ByteOrder.nativeOrder()
+ val isTransposed = if (m.isTransposed) 1 else 0
ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values)
+ out.write(Opcodes.MARK)
out.write(Opcodes.BININT)
out.write(PickleUtils.integer_to_bytes(m.numRows))
out.write(Opcodes.BININT)
@@ -994,19 +996,22 @@ private[spark] object SerDe extends Serializable {
out.write(Opcodes.BINSTRING)
out.write(PickleUtils.integer_to_bytes(bytes.length))
out.write(bytes)
- out.write(Opcodes.TUPLE3)
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(isTransposed))
+ out.write(Opcodes.TUPLE)
}
def construct(args: Array[Object]): Object = {
- if (args.length != 3) {
- throw new PickleException("should be 3")
+ if (args.length != 4) {
+ throw new PickleException("should be 4")
}
val bytes = getBytes(args(2))
val n = bytes.length / 8
val values = new Array[Double](n)
val order = ByteOrder.nativeOrder()
ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values)
- new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values)
+ val isTransposed = args(3).asInstanceOf[Int] == 1
+ new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values, isTransposed)
}
}