aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/tests.py
diff options
context:
space:
mode:
author: MechCoder <manojkumarsivaraj334@gmail.com> 2015-05-07 14:02:05 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2015-05-07 14:02:05 -0700
commit: 347a329a36c94ff37363e4dffcbd5a24dc6a6714 (patch)
tree: d762543329780cfff1f996a0e69f1de2ac925eb6 /python/pyspark/mllib/tests.py
parent: 88717ee4e7542ac8d5d2e5756c912dd390b37e88 (diff)
download: spark-347a329a36c94ff37363e4dffcbd5a24dc6a6714.tar.gz
spark-347a329a36c94ff37363e4dffcbd5a24dc6a6714.tar.bz2
spark-347a329a36c94ff37363e4dffcbd5a24dc6a6714.zip
[SPARK-7328] [MLLIB] [PYSPARK] Pyspark.mllib.linalg.Vectors: Missing items
Add: 1. Class methods squared_dist; 2. parse; 3. norm; 4. numNonzeros; 5. copy. I also vectorized parts of squared_dist and dot. I have added support for SparseMatrix serialization in a separate PR (https://github.com/apache/spark/pull/5775) and plan to complete support for Matrices in another PR. Author: MechCoder <manojkumarsivaraj334@gmail.com>. Closes #5872 from MechCoder/local_linalg_api and squashes the following commits: a8ff1e0 [MechCoder] minor; ce3e53e [MechCoder] Add error message for parser; 1bd3c04 [MechCoder] Robust parser and removed unnecessary methods; f779561 [MechCoder] [SPARK-7328] Pyspark.mllib.linalg.Vectors: Missing items
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r-- python/pyspark/mllib/tests.py 25
1 files changed, 24 insertions, 1 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index d05cfe2af0..36a4c7a540 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -24,7 +24,7 @@ import sys
import tempfile
import array as pyarray
-from numpy import array, array_equal, zeros
+from numpy import array, array_equal, zeros, inf
from py4j.protocol import Py4JJavaError
if sys.version_info[:2] <= (2, 6):
@@ -220,6 +220,29 @@ class VectorTests(MLlibTestCase):
self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5]))
self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9]))
+    def test_parse_vector(self):
+        """Check that str() and parse() round-trip dense and sparse vectors."""
+        # assertEqual, not assertTrue: assertTrue(x, y) only checks that x is
+        # truthy and uses y as the failure message, so it never compares them.
+        a = DenseVector([3, 4, 6, 7])
+        self.assertEqual(str(a), '[3.0,4.0,6.0,7.0]')
+        self.assertEqual(Vectors.parse(str(a)), a)
+        a = SparseVector(4, [0, 2], [3, 4])
+        self.assertEqual(str(a), '(4,[0,2],[3.0,4.0])')
+        self.assertEqual(Vectors.parse(str(a)), a)
+        # Input with stray whitespace; the parser is expected to tolerate it.
+        a = SparseVector(10, [0, 1], [4, 5])
+        self.assertEqual(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a)
+
+    def test_norms(self):
+        """Check 1-, 2- and inf-norms of dense and sparse vectors, and numNonzeros."""
+        a = DenseVector([0, 2, 3, -1])
+        self.assertAlmostEqual(a.norm(2), 3.742, 3)
+        # assertEqual, not assertTrue: assertTrue(x, y) only checks that x is
+        # truthy and uses y as the failure message, so it never compares them.
+        self.assertEqual(a.norm(1), 6)
+        self.assertEqual(a.norm(inf), 3)
+        a = SparseVector(4, [0, 2], [3, -4])
+        self.assertAlmostEqual(a.norm(2), 5)
+        self.assertEqual(a.norm(1), 7)
+        self.assertEqual(a.norm(inf), 4)
+
+        # One of the two stored values is an explicit zero; only one is non-zero.
+        tmp = SparseVector(4, [0, 2], [3, 0])
+        self.assertEqual(tmp.numNonzeros(), 1)
+
class ListTests(MLlibTestCase):