From 2b40365d76b7d9d382ad5077cdf979906bca17f2 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Wed, 8 Jul 2015 13:19:27 -0700 Subject: [SPARK-7785] [MLLIB] [PYSPARK] Add __str__ and __repr__ to Matrices Adding __str__ and __repr__ to DenseMatrix and SparseMatrix Author: MechCoder Closes #6342 from MechCoder/spark-7785 and squashes the following commits: 7b9a82c [MechCoder] Add tests for greater than 16 elements b88e9dd [MechCoder] Increment limit to 16 1425a01 [MechCoder] Change tests 36bd166 [MechCoder] Change str and repr representation 97f0da9 [MechCoder] zip is same as izip in python3 94ca4b2 [MechCoder] Added doctests and iterate over values instead of colPtrs b26fa89 [MechCoder] minor 394dde9 [MechCoder] [SPARK-7785] Add __str__ and __repr__ to Matrices --- python/pyspark/mllib/linalg.py | 127 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 127 insertions(+) (limited to 'python/pyspark/mllib/linalg.py') diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 12d8dbbb92..51ac198305 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -31,6 +31,7 @@ if sys.version >= '3': xrange = range import copyreg as copy_reg else: + from itertools import izip as zip import copy_reg import numpy as np @@ -116,6 +117,10 @@ def _format_float(f, digits=4): return s +def _format_float_list(l): + return [_format_float(x) for x in l] + + class VectorUDT(UserDefinedType): """ SQL user-defined type (UDT) for Vector. @@ -870,6 +875,50 @@ class DenseMatrix(Matrix): self.numRows, self.numCols, self.values.tostring(), int(self.isTransposed)) + def __str__(self): + """ + Pretty printing of a DenseMatrix + + >>> dm = DenseMatrix(2, 2, range(4)) + >>> print(dm) + DenseMatrix([[ 0., 2.], + [ 1., 3.]]) + >>> dm = DenseMatrix(2, 2, range(4), isTransposed=True) + >>> print(dm) + DenseMatrix([[ 0., 1.], + [ 2., 3.]]) + """ + # Inspired by __repr__ in scipy matrices. + array_lines = repr(self.toArray()).splitlines() + + # We need to adjust six spaces which is the difference in number + # of letters between "DenseMatrix" and "array" + x = '\n'.join([(" " * 6 + line) for line in array_lines[1:]]) + return array_lines[0].replace("array", "DenseMatrix") + "\n" + x + + def __repr__(self): + """ + Representation of a DenseMatrix + + >>> dm = DenseMatrix(2, 2, range(4)) + >>> dm + DenseMatrix(2, 2, [0.0, 1.0, 2.0, 3.0], False) + """ + # If the number of values are less than seventeen then return as it is. + # Else return first eight values and last eight values. + if len(self.values) < 17: + entries = _format_float_list(self.values) + else: + entries = ( + _format_float_list(self.values[:8]) + + ["..."] + + _format_float_list(self.values[-8:]) + ) + + entries = ", ".join(entries) + return "DenseMatrix({0}, {1}, [{2}], {3})".format( + self.numRows, self.numCols, entries, self.isTransposed) + def toArray(self): """ Return an numpy.ndarray @@ -946,6 +995,84 @@ class SparseMatrix(Matrix): raise ValueError("Expected rowIndices of length %d, got %d." % (self.rowIndices.size, self.values.size)) + def __str__(self): + """ + Pretty printing of a SparseMatrix + + >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) + >>> print(sm1) + 2 X 2 CSCMatrix + (0,0) 2.0 + (1,0) 3.0 + (1,1) 4.0 + >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) + >>> print(sm1) + 2 X 2 CSRMatrix + (0,0) 2.0 + (0,1) 3.0 + (1,1) 4.0 + """ + spstr = "{0} X {1} ".format(self.numRows, self.numCols) + if self.isTransposed: + spstr += "CSRMatrix\n" + else: + spstr += "CSCMatrix\n" + + cur_col = 0 + smlist = [] + + # Display first 16 values. + if len(self.values) <= 16: + zipindval = zip(self.rowIndices, self.values) + else: + zipindval = zip(self.rowIndices[:16], self.values[:16]) + for i, (rowInd, value) in enumerate(zipindval): + if self.colPtrs[cur_col + 1] <= i: + cur_col += 1 + if self.isTransposed: + smlist.append('({0},{1}) {2}'.format( + cur_col, rowInd, _format_float(value))) + else: + smlist.append('({0},{1}) {2}'.format( + rowInd, cur_col, _format_float(value))) + spstr += "\n".join(smlist) + + if len(self.values) > 16: + spstr += "\n.." * 2 + return spstr + + def __repr__(self): + """ + Representation of a SparseMatrix + + >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) + >>> sm1 + SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2.0, 3.0, 4.0], False) + """ + rowIndices = list(self.rowIndices) + colPtrs = list(self.colPtrs) + + if len(self.values) <= 16: + values = _format_float_list(self.values) + + else: + values = ( + _format_float_list(self.values[:8]) + + ["..."] + + _format_float_list(self.values[-8:]) + ) + rowIndices = rowIndices[:8] + ["..."] + rowIndices[-8:] + + if len(self.colPtrs) > 16: + colPtrs = colPtrs[:8] + ["..."] + colPtrs[-8:] + + values = ", ".join(values) + rowIndices = ", ".join([str(ind) for ind in rowIndices]) + colPtrs = ", ".join([str(ptr) for ptr in colPtrs]) + return "SparseMatrix({0}, {1}, [{2}], [{3}], [{4}], {5})".format( + self.numRows, self.numCols, colPtrs, rowIndices, + values, self.isTransposed) + def __reduce__(self): return SparseMatrix, ( self.numRows, self.numCols, self.colPtrs.tostring(), -- cgit v1.2.3