aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-07-08 13:19:27 -0700
committerXiangrui Meng <meng@databricks.com>2015-07-08 13:19:27 -0700
commit2b40365d76b7d9d382ad5077cdf979906bca17f2 (patch)
tree858363ee47b81a0f5a048fe66995a41d664ee76b /python/pyspark
parent374c8a8a4a8ac4171d312a6c31080a6724e55c60 (diff)
downloadspark-2b40365d76b7d9d382ad5077cdf979906bca17f2.tar.gz
spark-2b40365d76b7d9d382ad5077cdf979906bca17f2.tar.bz2
spark-2b40365d76b7d9d382ad5077cdf979906bca17f2.zip
[SPARK-7785] [MLLIB] [PYSPARK] Add __str__ and __repr__ to Matrices
Adding __str__ and __repr__ to DenseMatrix and SparseMatrix.

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #6342 from MechCoder/spark-7785 and squashes the following commits:

7b9a82c [MechCoder] Add tests for greater than 16 elements
b88e9dd [MechCoder] Increment limit to 16
1425a01 [MechCoder] Change tests
36bd166 [MechCoder] Change str and repr representation
97f0da9 [MechCoder] zip is same as izip in python3
94ca4b2 [MechCoder] Added doctests and iterate over values instead of colPtrs
b26fa89 [MechCoder] minor
394dde9 [MechCoder] [SPARK-7785] Add __str__ and __repr__ to Matrices
Diffstat (limited to 'python/pyspark')
-rw-r--r-- python/pyspark/mllib/linalg.py 127
-rw-r--r-- python/pyspark/mllib/tests.py 52
2 files changed, 178 insertions, 1 deletions
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index 12d8dbbb92..51ac198305 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -31,6 +31,7 @@ if sys.version >= '3':
xrange = range
import copyreg as copy_reg
else:
+ from itertools import izip as zip
import copy_reg
import numpy as np
@@ -116,6 +117,10 @@ def _format_float(f, digits=4):
return s
+def _format_float_list(l):
+    """Apply ``_format_float`` to every element of ``l``; return the results as a list."""
+    formatted = []
+    for number in l:
+        formatted.append(_format_float(number))
+    return formatted
+
+
class VectorUDT(UserDefinedType):
"""
SQL user-defined type (UDT) for Vector.
@@ -870,6 +875,50 @@ class DenseMatrix(Matrix):
self.numRows, self.numCols, self.values.tostring(),
int(self.isTransposed))
+    def __str__(self):
+        """
+        Pretty printing of a DenseMatrix
+
+        The text is the ``repr`` of ``self.toArray()`` with the leading
+        "array" replaced by "DenseMatrix" and every continuation line
+        shifted right so that the rows stay aligned.
+
+        >>> dm = DenseMatrix(2, 2, range(4))
+        >>> print(dm)
+        DenseMatrix([[ 0., 2.],
+                     [ 1., 3.]])
+        >>> dm = DenseMatrix(2, 2, range(4), isTransposed=True)
+        >>> print(dm)
+        DenseMatrix([[ 0., 1.],
+                     [ 2., 3.]])
+        """
+        # Inspired by __repr__ in scipy matrices.
+        array_lines = repr(self.toArray()).splitlines()
+
+        # "DenseMatrix" is six letters longer than "array", so every
+        # continuation line needs six extra leading spaces to stay aligned.
+        lines = [array_lines[0].replace("array", "DenseMatrix")]
+        lines.extend(" " * 6 + line for line in array_lines[1:])
+
+        # Joining everything in one go avoids the stray trailing newline
+        # that was produced when the array repr fits on a single line.
+        return "\n".join(lines)
+
+    def __repr__(self):
+        """
+        Representation of a DenseMatrix
+
+        >>> dm = DenseMatrix(2, 2, range(4))
+        >>> dm
+        DenseMatrix(2, 2, [0.0, 1.0, 2.0, 3.0], False)
+        """
+        vals = self.values
+        if len(vals) >= 17:
+            # Too many values to list in full: keep the first eight and
+            # the last eight, with "..." marking the gap between them.
+            shown = _format_float_list(vals[:8])
+            shown.append("...")
+            shown.extend(_format_float_list(vals[-8:]))
+        else:
+            # Sixteen values or fewer are shown as they are.
+            shown = _format_float_list(vals)
+
+        return "DenseMatrix({0}, {1}, [{2}], {3})".format(
+            self.numRows, self.numCols, ", ".join(shown), self.isTransposed)
+
def toArray(self):
"""
Return an numpy.ndarray
@@ -946,6 +995,84 @@ class SparseMatrix(Matrix):
raise ValueError("Expected rowIndices of length %d, got %d."
% (self.rowIndices.size, self.values.size))
+    def __str__(self):
+        """
+        Pretty printing of a SparseMatrix
+
+        The first line gives the dimensions and the storage layout
+        ("CSRMatrix" when isTransposed is set, "CSCMatrix" otherwise).
+        Each following line shows one stored entry as "(i,j) value".
+        At most 16 entries are listed; when more are stored, two ".."
+        lines are appended.
+
+        >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
+        >>> print(sm1)
+        2 X 2 CSCMatrix
+        (0,0) 2.0
+        (1,0) 3.0
+        (1,1) 4.0
+        >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True)
+        >>> print(sm1)
+        2 X 2 CSRMatrix
+        (0,0) 2.0
+        (0,1) 3.0
+        (1,1) 4.0
+        """
+        spstr = "{0} X {1} ".format(self.numRows, self.numCols)
+        if self.isTransposed:
+            spstr += "CSRMatrix\n"
+        else:
+            spstr += "CSCMatrix\n"
+
+        cur_col = 0
+        smlist = []
+
+        # Display first 16 values; slicing is a no-op on shorter inputs.
+        zipindval = zip(self.rowIndices[:16], self.values[:16])
+        for i, (rowInd, value) in enumerate(zipindval):
+            # Advance past every pointer range that ends at or before
+            # entry i. This must be a loop rather than a single step:
+            # with an empty column (two equal consecutive pointers) a
+            # single step leaves cur_col one behind and the entry is
+            # printed under the wrong column.
+            while self.colPtrs[cur_col + 1] <= i:
+                cur_col += 1
+            if self.isTransposed:
+                smlist.append('({0},{1}) {2}'.format(
+                    cur_col, rowInd, _format_float(value)))
+            else:
+                smlist.append('({0},{1}) {2}'.format(
+                    rowInd, cur_col, _format_float(value)))
+        spstr += "\n".join(smlist)
+
+        # Signal that entries beyond the first 16 were left out.
+        if len(self.values) > 16:
+            spstr += "\n.." * 2
+        return spstr
+
+    def __repr__(self):
+        """
+        Representation of a SparseMatrix
+
+        >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
+        >>> sm1
+        SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2.0, 3.0, 4.0], False)
+        """
+        index_strs = [str(ind) for ind in self.rowIndices]
+        ptr_strs = [str(ptr) for ptr in self.colPtrs]
+
+        if len(self.values) > 16:
+            # Values and row indices are shortened together: first eight,
+            # a "..." marker, then the last eight.
+            value_strs = _format_float_list(self.values[:8])
+            value_strs.append("...")
+            value_strs.extend(_format_float_list(self.values[-8:]))
+            index_strs = index_strs[:8] + ["..."] + index_strs[-8:]
+        else:
+            value_strs = _format_float_list(self.values)
+
+        # The pointer array is shortened independently of the values.
+        if len(self.colPtrs) > 16:
+            ptr_strs = ptr_strs[:8] + ["..."] + ptr_strs[-8:]
+
+        return "SparseMatrix({0}, {1}, [{2}], [{3}], [{4}], {5})".format(
+            self.numRows, self.numCols, ", ".join(ptr_strs),
+            ", ".join(index_strs), ", ".join(value_strs), self.isTransposed)
+
def __reduce__(self):
return SparseMatrix, (
self.numRows, self.numCols, self.colPtrs.tostring(),
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index d9f9874d50..f2eab5b18f 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -27,7 +27,7 @@ from time import time, sleep
from shutil import rmtree
from numpy import (
- array, array_equal, zeros, inf, random, exp, dot, all, mean, abs)
+ array, array_equal, zeros, inf, random, exp, dot, all, mean, abs, arange, tile, ones)
from numpy import sum as array_sum
from py4j.protocol import Py4JJavaError
@@ -189,6 +189,53 @@ class VectorTests(MLlibTestCase):
for j in range(2):
self.assertEquals(mat[i, j], expected[i][j])
+    def test_repr_dense_matrix(self):
+        """Check repr() of DenseMatrix for short, transposed and truncated input."""
+        # assertEqual is required here: assertTrue(a, b) treats b as the
+        # failure message and passes for any non-empty repr string.
+        mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10])
+        self.assertEqual(
+            repr(mat),
+            'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)')
+
+        # The last field of the repr is the isTransposed flag, so a matrix
+        # built with isTransposed=True must end in "True".
+        mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True)
+        self.assertEqual(
+            repr(mat),
+            'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], True)')
+
+        # More than 16 values: first eight, "...", last eight. Adjacent
+        # literals are used instead of a backslash continuation so that no
+        # source indentation leaks into the expected string.
+        mat = DenseMatrix(6, 3, zeros(18))
+        self.assertEqual(
+            repr(mat),
+            'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., '
+            '0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)')
+
+    def test_repr_sparse_matrix(self):
+        """Check repr() and str() of SparseMatrix, including truncated output."""
+        # assertEqual is required here: assertTrue(a, b) treats b as the
+        # failure message and passes for any non-empty string. Expected
+        # values are built from adjacent literals so that no source
+        # indentation leaks into them.
+        sm1t = SparseMatrix(
+            3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0],
+            isTransposed=True)
+        self.assertEqual(
+            repr(sm1t),
+            'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], '
+            '[3.0, 2.0, 4.0, 9.0, 8.0], True)')
+
+        # 18 stored values: row indices and values are both shortened.
+        indices = tile(arange(6), 3)
+        values = ones(18)
+        sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values)
+        self.assertEqual(
+            repr(sm),
+            "SparseMatrix(6, 3, [0, 6, 12, 18], "
+            "[0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], "
+            "[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., "
+            "1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)")
+
+        # str() lists only the first 16 entries, then two ".." lines.
+        self.assertEqual(
+            str(sm),
+            "6 X 3 CSCMatrix\n"
+            "(0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n"
+            "(0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n"
+            "(0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..")
+
+        # No stored values but 19 column pointers: only colPtrs is shortened.
+        sm = SparseMatrix(1, 18, zeros(19), [], [])
+        self.assertEqual(
+            repr(sm),
+            'SparseMatrix(1, 18, '
+            '[0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)')
+
def test_sparse_matrix(self):
# Test sparse matrix creation.
sm1 = SparseMatrix(
@@ -198,6 +245,9 @@ class VectorTests(MLlibTestCase):
self.assertEquals(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4])
self.assertEquals(sm1.rowIndices.tolist(), [1, 2, 1, 2])
self.assertEquals(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0])
+ self.assertTrue(
+ repr(sm1),
+ 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)')
# Test indexing
expected = [