aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r--python/pyspark/mllib/tests.py25
1 files changed, 24 insertions, 1 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index d05cfe2af0..36a4c7a540 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -24,7 +24,7 @@ import sys
import tempfile
import array as pyarray
-from numpy import array, array_equal, zeros
+from numpy import array, array_equal, zeros, inf
from py4j.protocol import Py4JJavaError
if sys.version_info[:2] <= (2, 6):
@@ -220,6 +220,29 @@ class VectorTests(MLlibTestCase):
self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5]))
self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9]))
+ def test_parse_vector(self):
+ a = DenseVector([3, 4, 6, 7])
+ self.assertTrue(str(a), '[3.0,4.0,6.0,7.0]')
+ self.assertTrue(Vectors.parse(str(a)), a)
+ a = SparseVector(4, [0, 2], [3, 4])
+ self.assertTrue(str(a), '(4,[0,2],[3.0,4.0])')
+ self.assertTrue(Vectors.parse(str(a)), a)
+ a = SparseVector(10, [0, 1], [4, 5])
+ self.assertTrue(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a)
+
+ def test_norms(self):
+ a = DenseVector([0, 2, 3, -1])
+ self.assertAlmostEqual(a.norm(2), 3.742, 3)
+ self.assertTrue(a.norm(1), 6)
+ self.assertTrue(a.norm(inf), 3)
+ a = SparseVector(4, [0, 2], [3, -4])
+ self.assertAlmostEqual(a.norm(2), 5)
+ self.assertTrue(a.norm(1), 7)
+ self.assertTrue(a.norm(inf), 4)
+
+ tmp = SparseVector(4, [0, 2], [3, 0])
+ self.assertEqual(tmp.numNonzeros(), 1)
+
class ListTests(MLlibTestCase):