aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-05-07 14:02:05 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-05-07 14:02:05 -0700
commit347a329a36c94ff37363e4dffcbd5a24dc6a6714 (patch)
treed762543329780cfff1f996a0e69f1de2ac925eb6 /python
parent88717ee4e7542ac8d5d2e5756c912dd390b37e88 (diff)
downloadspark-347a329a36c94ff37363e4dffcbd5a24dc6a6714.tar.gz
spark-347a329a36c94ff37363e4dffcbd5a24dc6a6714.tar.bz2
spark-347a329a36c94ff37363e4dffcbd5a24dc6a6714.zip
[SPARK-7328] [MLLIB] [PYSPARK] Pyspark.mllib.linalg.Vectors: Missing items
Add 1. Class methods squared_dist 3. parse 4. norm 5. numNonzeros 6. copy I made a few vectorizations wrt squared_dist and dot as well. I have added support for SparseMatrix serialization in a separate PR (https://github.com/apache/spark/pull/5775) and plan to complete support for Matrices in another PR. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #5872 from MechCoder/local_linalg_api and squashes the following commits: a8ff1e0 [MechCoder] minor ce3e53e [MechCoder] Add error message for parser 1bd3c04 [MechCoder] Robust parser and removed unnecessary methods f779561 [MechCoder] [SPARK-7328] Pyspark.mllib.linalg.Vectors: Missing items
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/linalg.py148
-rw-r--r--python/pyspark/mllib/tests.py25
2 files changed, 171 insertions, 2 deletions
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index 9f3b0baf9f..23d1a79ffe 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -208,9 +208,46 @@ class DenseVector(Vector):
ar = ar.astype(np.float64)
self.array = ar
+    @staticmethod
+    def parse(s):
+        """
+        Parse string representation back into the DenseVector.
+
+        The text between the first '[' and the first ']' after it is
+        split on commas and every piece is converted with float().
+        Anything outside the brackets is ignored. If nothing but
+        whitespace lies between the brackets, an empty list is passed
+        to DenseVector.
+
+        :param s: string such as '[0.0,1.0,2.0]'.
+        :return: DenseVector built from the parsed values.
+        :raises ValueError: if a bracket is missing or a value cannot
+                            be converted to float.
+
+        >>> DenseVector.parse(' [ 0.0,1.0,2.0, 3.0]')
+        DenseVector([0.0, 1.0, 2.0, 3.0])
+        """
+        start = s.find('[')
+        if start == -1:
+            raise ValueError("Array should start with '['.")
+        # Look for the closing bracket only after the opening one, so a
+        # stray ']' earlier in the string cannot yield an inverted slice.
+        end = s.find(']', start + 1)
+        if end == -1:
+            raise ValueError("Array should end with ']'.")
+        s = s[start + 1: end]
+
+        # ''.split(',') is [''], which float() rejects; handle empty
+        # brackets explicitly instead of reporting a parse failure.
+        if not s.strip():
+            return DenseVector([])
+        try:
+            values = [float(val) for val in s.split(',')]
+        except ValueError:
+            raise ValueError("Unable to parse values from %s" % s)
+        return DenseVector(values)
+
 def __reduce__(self):
+ # Pickle hook: returns the DenseVector class together with a 1-tuple
+ # holding the result of self.array.tostring().
 return DenseVector, (self.array.tostring(),)
+ def numNonzeros(self):
+ """
+ Number of nonzero elements, as counted by np.count_nonzero
+ on the underlying array.
+ """
+ return np.count_nonzero(self.array)
+
+ def norm(self, p):
+ """
+ Calculate the norm of a DenseVector.
+
+ `p` is forwarded unchanged to np.linalg.norm as its second
+ argument, which selects the order of the norm.
+
+ >>> a = DenseVector([0, -1, 2, -3])
+ >>> a.norm(2)
+ 3.7...
+ >>> a.norm(1)
+ 6.0
+ """
+ return np.linalg.norm(self.array, p)
+
def dot(self, other):
"""
Compute the dot product of two Vectors. We support
@@ -387,8 +424,74 @@ class SparseVector(Vector):
if self.indices[i] >= self.indices[i + 1]:
raise TypeError("indices array must be sorted")
+ def numNonzeros(self):
+ """
+ Number of nonzero elements, as counted by np.count_nonzero
+ on the stored values (a stored value may itself be zero).
+ """
+ return np.count_nonzero(self.values)
+
+ def norm(self, p):
+ """
+ Calculate the norm of a SparseVector.
+
+ The norm is computed over the stored values only; `p` is
+ forwarded unchanged to np.linalg.norm as its second argument.
+
+ >>> a = SparseVector(4, [0, 1], [3., -4.])
+ >>> a.norm(1)
+ 7.0
+ >>> a.norm(2)
+ 5.0
+ """
+ return np.linalg.norm(self.values, p)
+
 def __reduce__(self):
- return (SparseVector, (self.size, self.indices.tostring(), self.values.tostring()))
+ # Pickle hook: returns the SparseVector class together with the
+ # size and the tostring() output of the indices and values arrays.
+ return (
+ SparseVector,
+ (self.size, self.indices.tostring(), self.values.tostring()))
+
+    @staticmethod
+    def parse(s):
+        """
+        Parse string representation back into the SparseVector.
+
+        Expected layout is '(size, [indices], [values])'. The text
+        between the first '(' and the first ')' is examined: the size
+        is the text before the first comma, the indices are taken from
+        the first bracketed list and the values from the next one.
+        A bracketed list holding only whitespace is read as empty.
+
+        :param s: string such as '(4,[0,1],[4.0,5.0])'.
+        :return: SparseVector built from the parsed size, indices
+                 and values.
+        :raises ValueError: if a delimiter is missing or a size, index
+                            or value cannot be converted.
+
+        >>> SparseVector.parse(' (4, [0,1 ],[ 4.0,5.0] )')
+        SparseVector(4, {0: 4.0, 1: 5.0})
+        """
+        start = s.find('(')
+        if start == -1:
+            raise ValueError("Tuple should start with '('")
+        end = s.find(')')
+        # This check must test `end`; testing `start` again would let a
+        # missing ')' through and slice off the last character instead.
+        if end == -1:
+            raise ValueError("Tuple should end with ')'")
+        s = s[start + 1: end].strip()
+
+        size = s[: s.find(',')]
+        try:
+            size = int(size)
+        except ValueError:
+            raise ValueError("Cannot parse size %s." % size)
+
+        ind_start = s.find('[')
+        if ind_start == -1:
+            raise ValueError("Indices array should start with '['.")
+        ind_end = s.find(']')
+        if ind_end == -1:
+            raise ValueError("Indices array should end with ']'")
+        new_s = s[ind_start + 1: ind_end]
+        try:
+            # ''.split(',') is [''], which int() rejects; an empty
+            # bracket pair means "no indices".
+            if new_s.strip():
+                indices = [int(ind) for ind in new_s.split(',')]
+            else:
+                indices = []
+        except ValueError:
+            raise ValueError("Unable to parse indices from %s." % new_s)
+        # Drop everything up to the end of the indices list so the next
+        # bracket search finds the values list.
+        s = s[ind_end + 1:].strip()
+
+        val_start = s.find('[')
+        if val_start == -1:
+            raise ValueError("Values array should start with '['.")
+        val_end = s.find(']')
+        if val_end == -1:
+            raise ValueError("Values array should end with ']'.")
+        val_s = s[val_start + 1: val_end]
+        try:
+            if val_s.strip():
+                values = [float(val) for val in val_s.split(',')]
+            else:
+                values = []
+        except ValueError:
+            raise ValueError("Unable to parse values from %s." % s)
+        return SparseVector(size, indices, values)
def dot(self, other):
"""
@@ -633,6 +736,49 @@ class Vectors(object):
"""
return str(vector)
+ @staticmethod
+ def squared_distance(v1, v2):
+ """
+ Squared distance between two vectors.
+ v1 and v2 can be of type SparseVector, DenseVector, np.ndarray
+ or array.array.
+
+ Both arguments are passed through _convert_to_vector, then the
+ result is delegated to v1.squared_distance(v2).
+
+ >>> a = Vectors.sparse(4, [(0, 1), (3, 4)])
+ >>> b = Vectors.dense([2, 5, 4, 1])
+ >>> a.squared_distance(b)
+ 51.0
+ """
+ v1, v2 = _convert_to_vector(v1), _convert_to_vector(v2)
+ return v1.squared_distance(v2)
+
+ @staticmethod
+ def norm(vector, p):
+ """
+ Find norm of the given vector.
+
+ The vector is passed through _convert_to_vector and the call
+ is delegated to its norm(p) method.
+ """
+ return _convert_to_vector(vector).norm(p)
+
+    @staticmethod
+    def parse(s):
+        """Parse a string representation back into the Vector.
+
+        A string containing '(' is handed to SparseVector.parse;
+        otherwise a string containing '[' is handed to
+        DenseVector.parse. A string with neither raises ValueError.
+
+        >>> Vectors.parse('[2,1,2 ]')
+        DenseVector([2.0, 1.0, 2.0])
+        >>> Vectors.parse(' ( 100, [0], [2])')
+        SparseVector(100, {0: 2.0})
+        """
+        # A parenthesis anywhere marks the sparse form, whether or not
+        # brackets are present as well.
+        paren_at = s.find('(')
+        if paren_at != -1:
+            return SparseVector.parse(s)
+        if s.find('[') != -1:
+            return DenseVector.parse(s)
+        raise ValueError(
+            "Cannot find tokens '[' or '(' from the input string.")
+
+ @staticmethod
+ def zeros(size):
+ """
+ Return a DenseVector wrapping np.zeros(size).
+ """
+ return DenseVector(np.zeros(size))
+
class Matrix(object):
"""
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index d05cfe2af0..36a4c7a540 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -24,7 +24,7 @@ import sys
import tempfile
import array as pyarray
-from numpy import array, array_equal, zeros
+from numpy import array, array_equal, zeros, inf
from py4j.protocol import Py4JJavaError
if sys.version_info[:2] <= (2, 6):
@@ -220,6 +220,29 @@ class VectorTests(MLlibTestCase):
self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5]))
self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9]))
+    def test_parse_vector(self):
+        """Check str() output and that it round-trips through parse."""
+        # assertEqual is required here: assertTrue(x, y) treats y as the
+        # failure message and passes for any truthy x, so it would
+        # never compare the two values.
+        a = DenseVector([3, 4, 6, 7])
+        self.assertEqual(str(a), '[3.0,4.0,6.0,7.0]')
+        self.assertEqual(Vectors.parse(str(a)), a)
+        a = SparseVector(4, [0, 2], [3, 4])
+        self.assertEqual(str(a), '(4,[0,2],[3.0,4.0])')
+        self.assertEqual(Vectors.parse(str(a)), a)
+        a = SparseVector(10, [0, 1], [4, 5])
+        # Extra whitespace around the delimiters must be tolerated.
+        self.assertEqual(
+            SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a)
+
+    def test_norms(self):
+        """Check the 1-, 2- and inf-norms of dense and sparse vectors."""
+        # assertEqual rather than assertTrue: assertTrue(x, y) only
+        # checks that x is truthy and uses y as the failure message.
+        a = DenseVector([0, 2, 3, -1])
+        self.assertAlmostEqual(a.norm(2), 3.742, 3)
+        self.assertEqual(a.norm(1), 6)
+        self.assertEqual(a.norm(inf), 3)
+        a = SparseVector(4, [0, 2], [3, -4])
+        self.assertAlmostEqual(a.norm(2), 5)
+        self.assertEqual(a.norm(1), 7)
+        self.assertEqual(a.norm(inf), 4)
+
+        # An explicitly stored zero must not be counted as nonzero.
+        tmp = SparseVector(4, [0, 2], [3, 0])
+        self.assertEqual(tmp.numNonzeros(), 1)
+
class ListTests(MLlibTestCase):