From 3bdbbc6c972567861044dd6a6dc82f35cd12442d Mon Sep 17 00:00:00 2001 From: Mike Dusenberry Date: Tue, 27 Oct 2015 11:05:14 -0700 Subject: [SPARK-6488][MLLIB][PYTHON] Support addition/multiplication in PySpark's BlockMatrix This PR adds addition and multiplication to PySpark's `BlockMatrix` class via `add` and `multiply` functions. Author: Mike Dusenberry Closes #9139 from dusenberrymw/SPARK-6488_Add_Addition_and_Multiplication_to_PySpark_BlockMatrix. --- python/pyspark/mllib/linalg/distributed.py | 68 ++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) (limited to 'python') diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index aec407de90..0e76050788 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -775,6 +775,74 @@ class BlockMatrix(DistributedMatrix): """ return self._java_matrix_wrapper.call("numCols") + def add(self, other): + """ + Adds two block matrices together. The matrices must have the + same size and matching `rowsPerBlock` and `colsPerBlock` values. + If one of the sub matrix blocks that are being added is a + SparseMatrix, the resulting sub matrix block will also be a + SparseMatrix, even if it is being added to a DenseMatrix. If + two dense sub matrix blocks are added, the output block will + also be a DenseMatrix. + + >>> dm1 = Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6]) + >>> dm2 = Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]) + >>> sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 1, 2], [7, 11, 12]) + >>> blocks1 = sc.parallelize([((0, 0), dm1), ((1, 0), dm2)]) + >>> blocks2 = sc.parallelize([((0, 0), dm1), ((1, 0), dm2)]) + >>> blocks3 = sc.parallelize([((0, 0), sm), ((1, 0), dm2)]) + >>> mat1 = BlockMatrix(blocks1, 3, 2) + >>> mat2 = BlockMatrix(blocks2, 3, 2) + >>> mat3 = BlockMatrix(blocks3, 3, 2) + + >>> mat1.add(mat2).toLocalMatrix() + DenseMatrix(6, 2, [2.0, 4.0, 6.0, 14.0, 16.0, 18.0, 8.0, 10.0, 12.0, 20.0, 22.0, 24.0], 0) + + >>> mat1.add(mat3).toLocalMatrix() + DenseMatrix(6, 2, [8.0, 2.0, 3.0, 14.0, 16.0, 18.0, 4.0, 16.0, 18.0, 20.0, 22.0, 24.0], 0) + """ + if not isinstance(other, BlockMatrix): + raise TypeError("Other should be a BlockMatrix, got %s" % type(other)) + + other_java_block_matrix = other._java_matrix_wrapper._java_model + java_block_matrix = self._java_matrix_wrapper.call("add", other_java_block_matrix) + return BlockMatrix(java_block_matrix, self.rowsPerBlock, self.colsPerBlock) + + def multiply(self, other): + """ + Left multiplies this BlockMatrix by `other`, another + BlockMatrix. The `colsPerBlock` of this matrix must equal the + `rowsPerBlock` of `other`. If `other` contains any SparseMatrix + blocks, they will have to be converted to DenseMatrix blocks. + The output BlockMatrix will only consist of DenseMatrix blocks. + This may cause some performance issues until support for + multiplying two sparse matrices is added. + + >>> dm1 = Matrices.dense(2, 3, [1, 2, 3, 4, 5, 6]) + >>> dm2 = Matrices.dense(2, 3, [7, 8, 9, 10, 11, 12]) + >>> dm3 = Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6]) + >>> dm4 = Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]) + >>> sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 1, 2], [7, 11, 12]) + >>> blocks1 = sc.parallelize([((0, 0), dm1), ((0, 1), dm2)]) + >>> blocks2 = sc.parallelize([((0, 0), dm3), ((1, 0), dm4)]) + >>> blocks3 = sc.parallelize([((0, 0), sm), ((1, 0), dm4)]) + >>> mat1 = BlockMatrix(blocks1, 2, 3) + >>> mat2 = BlockMatrix(blocks2, 3, 2) + >>> mat3 = BlockMatrix(blocks3, 3, 2) + + >>> mat1.multiply(mat2).toLocalMatrix() + DenseMatrix(2, 2, [242.0, 272.0, 350.0, 398.0], 0) + + >>> mat1.multiply(mat3).toLocalMatrix() + DenseMatrix(2, 2, [227.0, 258.0, 394.0, 450.0], 0) + """ + if not isinstance(other, BlockMatrix): + raise TypeError("Other should be a BlockMatrix, got %s" % type(other)) + + other_java_block_matrix = other._java_matrix_wrapper._java_model + java_block_matrix = self._java_matrix_wrapper.call("multiply", other_java_block_matrix) + return BlockMatrix(java_block_matrix, self.rowsPerBlock, self.colsPerBlock) + def toLocalMatrix(self): """ Collect the distributed matrix on the driver as a DenseMatrix. -- cgit v1.2.3