aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorMike Dusenberry <mwdusenb@us.ibm.com>2015-10-27 11:05:14 -0700
committerXiangrui Meng <meng@databricks.com>2015-10-27 11:05:14 -0700
commit3bdbbc6c972567861044dd6a6dc82f35cd12442d (patch)
treeb49f4ac45d3ecbf2258b0aecb960d2e6d8b19ee1 /python
parent9fc16a82adb5f3db2a250765c11393794404a51b (diff)
downloadspark-3bdbbc6c972567861044dd6a6dc82f35cd12442d.tar.gz
spark-3bdbbc6c972567861044dd6a6dc82f35cd12442d.tar.bz2
spark-3bdbbc6c972567861044dd6a6dc82f35cd12442d.zip
[SPARK-6488][MLLIB][PYTHON] Support addition/multiplication in PySpark's BlockMatrix
This PR adds addition and multiplication to PySpark's `BlockMatrix` class via `add` and `multiply` functions. Author: Mike Dusenberry <mwdusenb@us.ibm.com> Closes #9139 from dusenberrymw/SPARK-6488_Add_Addition_and_Multiplication_to_PySpark_BlockMatrix.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/linalg/distributed.py68
1 files changed, 68 insertions, 0 deletions
diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py
index aec407de90..0e76050788 100644
--- a/python/pyspark/mllib/linalg/distributed.py
+++ b/python/pyspark/mllib/linalg/distributed.py
@@ -775,6 +775,74 @@ class BlockMatrix(DistributedMatrix):
"""
return self._java_matrix_wrapper.call("numCols")
+ def add(self, other):
+ """
+ Adds two block matrices together. The matrices must have the
+ same size and matching `rowsPerBlock` and `colsPerBlock` values.
+ If one of the sub matrix blocks that are being added is a
+ SparseMatrix, the resulting sub matrix block will also be a
+ SparseMatrix, even if it is being added to a DenseMatrix. If
+ two dense sub matrix blocks are added, the output block will
+ also be a DenseMatrix.
+
+ >>> dm1 = Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])
+ >>> dm2 = Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12])
+ >>> sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 1, 2], [7, 11, 12])
+ >>> blocks1 = sc.parallelize([((0, 0), dm1), ((1, 0), dm2)])
+ >>> blocks2 = sc.parallelize([((0, 0), dm1), ((1, 0), dm2)])
+ >>> blocks3 = sc.parallelize([((0, 0), sm), ((1, 0), dm2)])
+ >>> mat1 = BlockMatrix(blocks1, 3, 2)
+ >>> mat2 = BlockMatrix(blocks2, 3, 2)
+ >>> mat3 = BlockMatrix(blocks3, 3, 2)
+
+ >>> mat1.add(mat2).toLocalMatrix()
+ DenseMatrix(6, 2, [2.0, 4.0, 6.0, 14.0, 16.0, 18.0, 8.0, 10.0, 12.0, 20.0, 22.0, 24.0], 0)
+
+ >>> mat1.add(mat3).toLocalMatrix()
+ DenseMatrix(6, 2, [8.0, 2.0, 3.0, 14.0, 16.0, 18.0, 4.0, 16.0, 18.0, 20.0, 22.0, 24.0], 0)
+ """
+ if not isinstance(other, BlockMatrix):
+ raise TypeError("Other should be a BlockMatrix, got %s" % type(other))
+
+ other_java_block_matrix = other._java_matrix_wrapper._java_model
+ java_block_matrix = self._java_matrix_wrapper.call("add", other_java_block_matrix)
+ return BlockMatrix(java_block_matrix, self.rowsPerBlock, self.colsPerBlock)
+
+ def multiply(self, other):
+ """
+ Left multiplies this BlockMatrix by `other`, another
+ BlockMatrix. The `colsPerBlock` of this matrix must equal the
+ `rowsPerBlock` of `other`. If `other` contains any SparseMatrix
+ blocks, they will have to be converted to DenseMatrix blocks.
+ The output BlockMatrix will only consist of DenseMatrix blocks.
+ This may cause some performance issues until support for
+ multiplying two sparse matrices is added.
+
+ >>> dm1 = Matrices.dense(2, 3, [1, 2, 3, 4, 5, 6])
+ >>> dm2 = Matrices.dense(2, 3, [7, 8, 9, 10, 11, 12])
+ >>> dm3 = Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])
+ >>> dm4 = Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12])
+ >>> sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 1, 2], [7, 11, 12])
+ >>> blocks1 = sc.parallelize([((0, 0), dm1), ((0, 1), dm2)])
+ >>> blocks2 = sc.parallelize([((0, 0), dm3), ((1, 0), dm4)])
+ >>> blocks3 = sc.parallelize([((0, 0), sm), ((1, 0), dm4)])
+ >>> mat1 = BlockMatrix(blocks1, 2, 3)
+ >>> mat2 = BlockMatrix(blocks2, 3, 2)
+ >>> mat3 = BlockMatrix(blocks3, 3, 2)
+
+ >>> mat1.multiply(mat2).toLocalMatrix()
+ DenseMatrix(2, 2, [242.0, 272.0, 350.0, 398.0], 0)
+
+ >>> mat1.multiply(mat3).toLocalMatrix()
+ DenseMatrix(2, 2, [227.0, 258.0, 394.0, 450.0], 0)
+ """
+ if not isinstance(other, BlockMatrix):
+ raise TypeError("Other should be a BlockMatrix, got %s" % type(other))
+
+ other_java_block_matrix = other._java_matrix_wrapper._java_model
+ java_block_matrix = self._java_matrix_wrapper.call("multiply", other_java_block_matrix)
+ return BlockMatrix(java_block_matrix, self.rowsPerBlock, self.colsPerBlock)
+
def toLocalMatrix(self):
"""
Collect the distributed matrix on the driver as a DenseMatrix.