aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-04-01 17:03:39 -0700
committerXiangrui Meng <meng@databricks.com>2015-04-01 17:03:39 -0700
commit2fa3b47dbf38aae58514473932c69bbd35de4e4c (patch)
tree7eb93451435a07fdbdde65469cb3c5291a0c3655 /python
parentccafd757eda478913f783f3127be715bf6413740 (diff)
downloadspark-2fa3b47dbf38aae58514473932c69bbd35de4e4c.tar.gz
spark-2fa3b47dbf38aae58514473932c69bbd35de4e4c.tar.bz2
spark-2fa3b47dbf38aae58514473932c69bbd35de4e4c.zip
[SPARK-6576] [MLlib] [PySpark] DenseMatrix in PySpark should support indexing
Support indexing in DenseMatrices in PySpark Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #5232 from MechCoder/SPARK-6576 and squashes the following commits: a735078 [MechCoder] Change bounds a062025 [MechCoder] Matrices are stored in column order 7917bc1 [MechCoder] [SPARK-6576] DenseMatrix in PySpark should support indexing
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/linalg.py10
-rw-r--r--python/pyspark/mllib/tests.py7
2 files changed, 17 insertions, 0 deletions
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index 8b791ff6a7..51c1490b16 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -670,6 +670,16 @@ class DenseMatrix(Matrix):
"""
return self.values.reshape((self.numRows, self.numCols), order='F')
+ def __getitem__(self, indices):
+ i, j = indices
+ if i < 0 or i >= self.numRows:
+ raise ValueError("Row index %d is out of range [0, %d)"
+ % (i, self.numRows))
+ if j >= self.numCols or j < 0:
+ raise ValueError("Column index %d is out of range [0, %d)"
+ % (j, self.numCols))
+ return self.values[i + j * self.numRows]
+
def __eq__(self, other):
return (isinstance(other, DenseMatrix) and
self.numRows == other.numRows and
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 3bb0f0ca68..893fc6f491 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -135,6 +135,13 @@ class VectorTests(PySparkTestCase):
for ind in [4, -5, 7.8]:
self.assertRaises(ValueError, sv.__getitem__, ind)
+ def test_matrix_indexing(self):
+ mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10])
+ expected = [[0, 6], [1, 8], [4, 10]]
+ for i in range(3):
+ for j in range(2):
+ self.assertEquals(mat[i, j], expected[i][j])
+
class ListTests(PySparkTestCase):