diff options
author | MechCoder <manojkumarsivaraj334@gmail.com> | 2015-04-01 17:03:39 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-04-01 17:03:39 -0700 |
commit | 2fa3b47dbf38aae58514473932c69bbd35de4e4c (patch) | |
tree | 7eb93451435a07fdbdde65469cb3c5291a0c3655 | |
parent | ccafd757eda478913f783f3127be715bf6413740 (diff) | |
download | spark-2fa3b47dbf38aae58514473932c69bbd35de4e4c.tar.gz spark-2fa3b47dbf38aae58514473932c69bbd35de4e4c.tar.bz2 spark-2fa3b47dbf38aae58514473932c69bbd35de4e4c.zip |
[SPARK-6576] [MLlib] [PySpark] DenseMatrix in PySpark should support indexing
Support indexing in DenseMatrices in PySpark
Author: MechCoder <manojkumarsivaraj334@gmail.com>
Closes #5232 from MechCoder/SPARK-6576 and squashes the following commits:
a735078 [MechCoder] Change bounds
a062025 [MechCoder] Matrices are stored in column order
7917bc1 [MechCoder] [SPARK-6576] DenseMatrix in PySpark should support indexing
-rw-r--r-- | python/pyspark/mllib/linalg.py | 10 | ||||
-rw-r--r-- | python/pyspark/mllib/tests.py | 7 |
2 files changed, 17 insertions, 0 deletions
```diff
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index 8b791ff6a7..51c1490b16 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -670,6 +670,16 @@ class DenseMatrix(Matrix):
         """
         return self.values.reshape((self.numRows, self.numCols), order='F')
 
+    def __getitem__(self, indices):
+        i, j = indices
+        if i < 0 or i >= self.numRows:
+            raise ValueError("Row index %d is out of range [0, %d)"
+                             % (i, self.numRows))
+        if j >= self.numCols or j < 0:
+            raise ValueError("Column index %d is out of range [0, %d)"
+                             % (j, self.numCols))
+        return self.values[i + j * self.numRows]
+
     def __eq__(self, other):
         return (isinstance(other, DenseMatrix) and
                 self.numRows == other.numRows and
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 3bb0f0ca68..893fc6f491 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -135,6 +135,13 @@ class VectorTests(PySparkTestCase):
         for ind in [4, -5, 7.8]:
             self.assertRaises(ValueError, sv.__getitem__, ind)
 
+    def test_matrix_indexing(self):
+        mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10])
+        expected = [[0, 6], [1, 8], [4, 10]]
+        for i in range(3):
+            for j in range(2):
+                self.assertEquals(mat[i, j], expected[i][j])
+
 
 class ListTests(PySparkTestCase):
 
```