aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/tests.py
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-04-21 14:36:50 -0700
committerXiangrui Meng <meng@databricks.com>2015-04-21 14:36:50 -0700
commit45c47fa4176ea75886a58f5d73c44afcb29aa629 (patch)
tree68601b1683fe06ababd86884d5a92d406097c553 /python/pyspark/mllib/tests.py
parentc25ca7c5a1f2a4f88f40b0c5cdbfa927c186cfa8 (diff)
downloadspark-45c47fa4176ea75886a58f5d73c44afcb29aa629.tar.gz
spark-45c47fa4176ea75886a58f5d73c44afcb29aa629.tar.bz2
spark-45c47fa4176ea75886a58f5d73c44afcb29aa629.zip
[SPARK-6845] [MLlib] [PySpark] Add isTranposed flag to DenseMatrix
Since sparse matrices now support a isTransposed flag for row major data, DenseMatrices should do the same. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #5455 from MechCoder/spark-6845 and squashes the following commits: 525c370 [MechCoder] minor 004a37f [MechCoder] Cast boolean to int 151f3b6 [MechCoder] [WIP] Add isTransposed to pickle DenseMatrix cc0b90a [MechCoder] [SPARK-6845] Add isTranposed flag to DenseMatrix
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r--python/pyspark/mllib/tests.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 849c88341a..8f89e2cee0 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -195,6 +195,22 @@ class VectorTests(PySparkTestCase):
self.assertEquals(expected[i][j], sm1t[i, j])
self.assertTrue(array_equal(sm1t.toArray(), expected))
+ def test_dense_matrix_is_transposed(self):
+ mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True)
+ mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9])
+ self.assertEquals(mat1, mat)
+
+ expected = [[0, 4], [1, 6], [3, 9]]
+ for i in range(3):
+ for j in range(2):
+ self.assertEquals(mat1[i, j], expected[i][j])
+ self.assertTrue(array_equal(mat1.toArray(), expected))
+
+ sm = mat1.toSparse()
+ self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2]))
+ self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5]))
+ self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9]))
+
class ListTests(PySparkTestCase):