aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/linalg
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-09-14 21:37:43 -0700
committerXiangrui Meng <meng@databricks.com>2015-09-14 21:37:43 -0700
commit4ae4d54794778042b2cc983e52757edac02412ab (patch)
tree095520b9ef546740fd0625de7c7572d29a79c743 /python/pyspark/mllib/linalg
parent55204181004c105c7a3e8c31a099b37e48bfd953 (diff)
downloadspark-4ae4d54794778042b2cc983e52757edac02412ab.tar.gz
spark-4ae4d54794778042b2cc983e52757edac02412ab.tar.bz2
spark-4ae4d54794778042b2cc983e52757edac02412ab.zip
[SPARK-9793] [MLLIB] [PYSPARK] PySpark DenseVector, SparseVector implement __eq__ and __hash__ correctly
PySpark DenseVector, SparseVector ```__eq__``` method should use semantics equality, and DenseVector can compared with SparseVector. Implement PySpark DenseVector, SparseVector ```__hash__``` method based on the first 16 entries. That will make PySpark Vector objects can be used in collections. Author: Yanbo Liang <ybliang8@gmail.com> Closes #8166 from yanboliang/spark-9793.
Diffstat (limited to 'python/pyspark/mllib/linalg')
-rw-r--r--python/pyspark/mllib/linalg/__init__.py90
1 files changed, 75 insertions, 15 deletions
diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py
index 334dc8e38b..380f86e9b4 100644
--- a/python/pyspark/mllib/linalg/__init__.py
+++ b/python/pyspark/mllib/linalg/__init__.py
@@ -25,6 +25,7 @@ SciPy is available in their environment.
import sys
import array
+import struct
if sys.version >= '3':
basestring = str
@@ -122,6 +123,13 @@ def _format_float_list(l):
return [_format_float(x) for x in l]
+def _double_to_long_bits(value):
+ if np.isnan(value):
+ value = float('nan')
+ # pack double into 64 bits, then unpack as long int
+ return struct.unpack('Q', struct.pack('d', value))[0]
+
+
class VectorUDT(UserDefinedType):
"""
SQL user-defined type (UDT) for Vector.
@@ -404,11 +412,31 @@ class DenseVector(Vector):
return "DenseVector([%s])" % (', '.join(_format_float(i) for i in self.array))
def __eq__(self, other):
- return isinstance(other, DenseVector) and np.array_equal(self.array, other.array)
+ if isinstance(other, DenseVector):
+ return np.array_equal(self.array, other.array)
+ elif isinstance(other, SparseVector):
+ if len(self) != other.size:
+ return False
+ return Vectors._equals(list(xrange(len(self))), self.array, other.indices, other.values)
+ return False
def __ne__(self, other):
return not self == other
+ def __hash__(self):
+ size = len(self)
+ result = 31 + size
+ nnz = 0
+ i = 0
+ while i < size and nnz < 128:
+ if self.array[i] != 0:
+ result = 31 * result + i
+ bits = _double_to_long_bits(self.array[i])
+ result = 31 * result + (bits ^ (bits >> 32))
+ nnz += 1
+ i += 1
+ return result
+
def __getattr__(self, item):
return getattr(self.array, item)
@@ -704,20 +732,14 @@ class SparseVector(Vector):
return "SparseVector({0}, {{{1}}})".format(self.size, entries)
def __eq__(self, other):
- """
- Test SparseVectors for equality.
-
- >>> v1 = SparseVector(4, [(1, 1.0), (3, 5.5)])
- >>> v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
- >>> v1 == v2
- True
- >>> v1 != v2
- False
- """
- return (isinstance(other, self.__class__)
- and other.size == self.size
- and np.array_equal(other.indices, self.indices)
- and np.array_equal(other.values, self.values))
+ if isinstance(other, SparseVector):
+ return other.size == self.size and np.array_equal(other.indices, self.indices) \
+ and np.array_equal(other.values, self.values)
+ elif isinstance(other, DenseVector):
+ if self.size != len(other):
+ return False
+ return Vectors._equals(self.indices, self.values, list(xrange(len(other))), other.array)
+ return False
def __getitem__(self, index):
inds = self.indices
@@ -739,6 +761,19 @@ class SparseVector(Vector):
def __ne__(self, other):
return not self.__eq__(other)
+ def __hash__(self):
+ result = 31 + self.size
+ nnz = 0
+ i = 0
+ while i < len(self.values) and nnz < 128:
+ if self.values[i] != 0:
+ result = 31 * result + int(self.indices[i])
+ bits = _double_to_long_bits(self.values[i])
+ result = 31 * result + (bits ^ (bits >> 32))
+ nnz += 1
+ i += 1
+ return result
+
class Vectors(object):
@@ -841,6 +876,31 @@ class Vectors(object):
def zeros(size):
return DenseVector(np.zeros(size))
+ @staticmethod
+ def _equals(v1_indices, v1_values, v2_indices, v2_values):
+ """
+ Check equality between sparse/dense vectors,
+ v1_indices and v2_indices assume to be strictly increasing.
+ """
+ v1_size = len(v1_values)
+ v2_size = len(v2_values)
+ k1 = 0
+ k2 = 0
+ all_equal = True
+ while all_equal:
+ while k1 < v1_size and v1_values[k1] == 0:
+ k1 += 1
+ while k2 < v2_size and v2_values[k2] == 0:
+ k2 += 1
+
+ if k1 >= v1_size or k2 >= v2_size:
+ return k1 >= v1_size and k2 >= v2_size
+
+ all_equal = v1_indices[k1] == v2_indices[k2] and v1_values[k1] == v2_values[k2]
+ k1 += 1
+ k2 += 1
+ return all_equal
+
class Matrix(object):