aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/linalg.py
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-07-08 13:19:27 -0700
committerXiangrui Meng <meng@databricks.com>2015-07-08 13:19:27 -0700
commit2b40365d76b7d9d382ad5077cdf979906bca17f2 (patch)
tree858363ee47b81a0f5a048fe66995a41d664ee76b /python/pyspark/mllib/linalg.py
parent374c8a8a4a8ac4171d312a6c31080a6724e55c60 (diff)
downloadspark-2b40365d76b7d9d382ad5077cdf979906bca17f2.tar.gz
spark-2b40365d76b7d9d382ad5077cdf979906bca17f2.tar.bz2
spark-2b40365d76b7d9d382ad5077cdf979906bca17f2.zip
[SPARK-7785] [MLLIB] [PYSPARK] Add __str__ and __repr__ to Matrices
Adding __str__ and __repr__ to DenseMatrix and SparseMatrix Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #6342 from MechCoder/spark-7785 and squashes the following commits: 7b9a82c [MechCoder] Add tests for greater than 16 elements b88e9dd [MechCoder] Increment limit to 16 1425a01 [MechCoder] Change tests 36bd166 [MechCoder] Change str and repr representation 97f0da9 [MechCoder] zip is same as izip in python3 94ca4b2 [MechCoder] Added doctests and iterate over values instead of colPtrs b26fa89 [MechCoder] minor 394dde9 [MechCoder] [SPARK-7785] Add __str__ and __repr__ to Matrices
Diffstat (limited to 'python/pyspark/mllib/linalg.py')
-rw-r--r--python/pyspark/mllib/linalg.py127
1 files changed, 127 insertions, 0 deletions
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index 12d8dbbb92..51ac198305 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -31,6 +31,7 @@ if sys.version >= '3':
xrange = range
import copyreg as copy_reg
else:
+ from itertools import izip as zip
import copy_reg
import numpy as np
@@ -116,6 +117,10 @@ def _format_float(f, digits=4):
return s
+def _format_float_list(l):
+ return [_format_float(x) for x in l]
+
+
class VectorUDT(UserDefinedType):
"""
SQL user-defined type (UDT) for Vector.
@@ -870,6 +875,50 @@ class DenseMatrix(Matrix):
self.numRows, self.numCols, self.values.tostring(),
int(self.isTransposed))
+ def __str__(self):
+ """
+ Pretty printing of a DenseMatrix
+
+ >>> dm = DenseMatrix(2, 2, range(4))
+ >>> print(dm)
+ DenseMatrix([[ 0., 2.],
+ [ 1., 3.]])
+ >>> dm = DenseMatrix(2, 2, range(4), isTransposed=True)
+ >>> print(dm)
+ DenseMatrix([[ 0., 1.],
+ [ 2., 3.]])
+ """
+ # Inspired by __repr__ in scipy matrices.
+ array_lines = repr(self.toArray()).splitlines()
+
+ # We need to adjust six spaces which is the difference in number
+ # of letters between "DenseMatrix" and "array"
+ x = '\n'.join([(" " * 6 + line) for line in array_lines[1:]])
+ return array_lines[0].replace("array", "DenseMatrix") + "\n" + x
+
+ def __repr__(self):
+ """
+ Representation of a DenseMatrix
+
+ >>> dm = DenseMatrix(2, 2, range(4))
+ >>> dm
+ DenseMatrix(2, 2, [0.0, 1.0, 2.0, 3.0], False)
+ """
+ # If the number of values are less than seventeen then return as it is.
+ # Else return first eight values and last eight values.
+ if len(self.values) < 17:
+ entries = _format_float_list(self.values)
+ else:
+ entries = (
+ _format_float_list(self.values[:8]) +
+ ["..."] +
+ _format_float_list(self.values[-8:])
+ )
+
+ entries = ", ".join(entries)
+ return "DenseMatrix({0}, {1}, [{2}], {3})".format(
+ self.numRows, self.numCols, entries, self.isTransposed)
+
def toArray(self):
"""
Return an numpy.ndarray
@@ -946,6 +995,84 @@ class SparseMatrix(Matrix):
raise ValueError("Expected rowIndices of length %d, got %d."
% (self.rowIndices.size, self.values.size))
+ def __str__(self):
+ """
+ Pretty printing of a SparseMatrix
+
+ >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
+ >>> print(sm1)
+ 2 X 2 CSCMatrix
+ (0,0) 2.0
+ (1,0) 3.0
+ (1,1) 4.0
+ >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True)
+ >>> print(sm1)
+ 2 X 2 CSRMatrix
+ (0,0) 2.0
+ (0,1) 3.0
+ (1,1) 4.0
+ """
+ spstr = "{0} X {1} ".format(self.numRows, self.numCols)
+ if self.isTransposed:
+ spstr += "CSRMatrix\n"
+ else:
+ spstr += "CSCMatrix\n"
+
+ cur_col = 0
+ smlist = []
+
+ # Display first 16 values.
+ if len(self.values) <= 16:
+ zipindval = zip(self.rowIndices, self.values)
+ else:
+ zipindval = zip(self.rowIndices[:16], self.values[:16])
+ for i, (rowInd, value) in enumerate(zipindval):
+ if self.colPtrs[cur_col + 1] <= i:
+ cur_col += 1
+ if self.isTransposed:
+ smlist.append('({0},{1}) {2}'.format(
+ cur_col, rowInd, _format_float(value)))
+ else:
+ smlist.append('({0},{1}) {2}'.format(
+ rowInd, cur_col, _format_float(value)))
+ spstr += "\n".join(smlist)
+
+ if len(self.values) > 16:
+ spstr += "\n.." * 2
+ return spstr
+
+ def __repr__(self):
+ """
+ Representation of a SparseMatrix
+
+ >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
+ >>> sm1
+ SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2.0, 3.0, 4.0], False)
+ """
+ rowIndices = list(self.rowIndices)
+ colPtrs = list(self.colPtrs)
+
+ if len(self.values) <= 16:
+ values = _format_float_list(self.values)
+
+ else:
+ values = (
+ _format_float_list(self.values[:8]) +
+ ["..."] +
+ _format_float_list(self.values[-8:])
+ )
+ rowIndices = rowIndices[:8] + ["..."] + rowIndices[-8:]
+
+ if len(self.colPtrs) > 16:
+ colPtrs = colPtrs[:8] + ["..."] + colPtrs[-8:]
+
+ values = ", ".join(values)
+ rowIndices = ", ".join([str(ind) for ind in rowIndices])
+ colPtrs = ", ".join([str(ptr) for ptr in colPtrs])
+ return "SparseMatrix({0}, {1}, [{2}], [{3}], [{4}], {5})".format(
+ self.numRows, self.numCols, colPtrs, rowIndices,
+ values, self.isTransposed)
+
def __reduce__(self):
return SparseMatrix, (
self.numRows, self.numCols, self.colPtrs.tostring(),