author     MechCoder <manojkumarsivaraj334@gmail.com>    2015-04-21 14:36:50 -0700
committer  Xiangrui Meng <meng@databricks.com>           2015-04-21 14:36:50 -0700
commit     45c47fa4176ea75886a58f5d73c44afcb29aa629 (patch)
tree       68601b1683fe06ababd86884d5a92d406097c553 /mllib/src
parent     c25ca7c5a1f2a4f88f40b0c5cdbfa927c186cfa8 (diff)
[SPARK-6845] [MLlib] [PySpark] Add isTransposed flag to DenseMatrix

Since sparse matrices now support an isTransposed flag for row-major data, DenseMatrices should do the same.

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #5455 from MechCoder/spark-6845 and squashes the following commits:

525c370 [MechCoder] minor
004a37f [MechCoder] Cast boolean to int
151f3b6 [MechCoder] [WIP] Add isTransposed to pickle DenseMatrix
cc0b90a [MechCoder] [SPARK-6845] Add isTransposed flag to DenseMatrix
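
For context, a minimal sketch (plain Spark linalg API, not part of this patch) of what the flag means: a DenseMatrix built with isTransposed = true keeps its values in row-major order, so the same logical matrix can be expressed in either layout.

    import org.apache.spark.mllib.linalg.DenseMatrix

    // The same 2x3 matrix, stored two ways:
    // column-major values (the default layout) ...
    val colMajor = new DenseMatrix(2, 3, Array(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
    // ... and row-major values, flagged with isTransposed = true.
    val rowMajor = new DenseMatrix(2, 3, Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0), isTransposed = true)

    // Both read the same entry at row 0, column 1.
    assert(colMajor(0, 1) == 2.0 && rowMajor(0, 1) == 2.0)

The serializer changed below encodes that flag as an extra integer in the pickled tuple, so a DenseMatrix keeps its layout when passed between the JVM and Python.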
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 13
1 file changed, 9 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index f976d2f97b..6237b64c8f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -985,8 +985,10 @@ private[spark] object SerDe extends Serializable {
       val m: DenseMatrix = obj.asInstanceOf[DenseMatrix]
       val bytes = new Array[Byte](8 * m.values.size)
       val order = ByteOrder.nativeOrder()
+      val isTransposed = if (m.isTransposed) 1 else 0
       ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values)

+      out.write(Opcodes.MARK)
       out.write(Opcodes.BININT)
       out.write(PickleUtils.integer_to_bytes(m.numRows))
       out.write(Opcodes.BININT)
@@ -994,19 +996,22 @@ private[spark] object SerDe extends Serializable {
       out.write(Opcodes.BINSTRING)
       out.write(PickleUtils.integer_to_bytes(bytes.length))
       out.write(bytes)
-      out.write(Opcodes.TUPLE3)
+      out.write(Opcodes.BININT)
+      out.write(PickleUtils.integer_to_bytes(isTransposed))
+      out.write(Opcodes.TUPLE)
     }

     def construct(args: Array[Object]): Object = {
-      if (args.length != 3) {
-        throw new PickleException("should be 3")
+      if (args.length != 4) {
+        throw new PickleException("should be 4")
       }
       val bytes = getBytes(args(2))
       val n = bytes.length / 8
       val values = new Array[Double](n)
       val order = ByteOrder.nativeOrder()
       ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values)
-      new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values)
+      val isTransposed = args(3).asInstanceOf[Int] == 1
+      new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values, isTransposed)
     }
   }
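
For reference, a standalone sketch (plain Scala, no Spark classes; the value data here is illustrative) of the byte round trip that saveState and construct perform on the values array, and of how the four unpickled arguments line up after this patch:

    import java.nio.{ByteBuffer, ByteOrder}

    // Pack doubles into native-order bytes, as saveState does ...
    val values = Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
    val bytes = new Array[Byte](8 * values.length)
    ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()).asDoubleBuffer().put(values)

    // ... and unpack them again, as construct does.
    val decoded = new Array[Double](bytes.length / 8)
    ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()).asDoubleBuffer().get(decoded)
    assert(decoded.sameElements(values))

    // With this patch the unpickled tuple carries four elements:
    //   args(0) = numRows, args(1) = numCols,
    //   args(2) = raw value bytes, args(3) = isTransposed encoded as 0 or 1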