diff options
author | MechCoder <manojkumarsivaraj334@gmail.com> | 2015-04-21 14:36:50 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-04-21 14:36:50 -0700 |
commit | 45c47fa4176ea75886a58f5d73c44afcb29aa629 (patch) | |
tree | 68601b1683fe06ababd86884d5a92d406097c553 /python/pyspark/mllib/tests.py | |
parent | c25ca7c5a1f2a4f88f40b0c5cdbfa927c186cfa8 (diff) | |
download | spark-45c47fa4176ea75886a58f5d73c44afcb29aa629.tar.gz spark-45c47fa4176ea75886a58f5d73c44afcb29aa629.tar.bz2 spark-45c47fa4176ea75886a58f5d73c44afcb29aa629.zip |
[SPARK-6845] [MLlib] [PySpark] Add isTranposed flag to DenseMatrix
Since sparse matrices now support a isTransposed flag for row major data, DenseMatrices should do the same.
Author: MechCoder <manojkumarsivaraj334@gmail.com>
Closes #5455 from MechCoder/spark-6845 and squashes the following commits:
525c370 [MechCoder] minor
004a37f [MechCoder] Cast boolean to int
151f3b6 [MechCoder] [WIP] Add isTransposed to pickle DenseMatrix
cc0b90a [MechCoder] [SPARK-6845] Add isTranposed flag to DenseMatrix
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r-- | python/pyspark/mllib/tests.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 849c88341a..8f89e2cee0 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -195,6 +195,22 @@ class VectorTests(PySparkTestCase): self.assertEquals(expected[i][j], sm1t[i, j]) self.assertTrue(array_equal(sm1t.toArray(), expected)) + def test_dense_matrix_is_transposed(self): + mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) + mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) + self.assertEquals(mat1, mat) + + expected = [[0, 4], [1, 6], [3, 9]] + for i in range(3): + for j in range(2): + self.assertEquals(mat1[i, j], expected[i][j]) + self.assertTrue(array_equal(mat1.toArray(), expected)) + + sm = mat1.toSparse() + self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2])) + self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5])) + self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) + class ListTests(PySparkTestCase): |