aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-09-14 21:37:43 -0700
committerXiangrui Meng <meng@databricks.com>2015-09-14 21:37:43 -0700
commit4ae4d54794778042b2cc983e52757edac02412ab (patch)
tree095520b9ef546740fd0625de7c7572d29a79c743 /python
parent55204181004c105c7a3e8c31a099b37e48bfd953 (diff)
downloadspark-4ae4d54794778042b2cc983e52757edac02412ab.tar.gz
spark-4ae4d54794778042b2cc983e52757edac02412ab.tar.bz2
spark-4ae4d54794778042b2cc983e52757edac02412ab.zip
[SPARK-9793] [MLLIB] [PYSPARK] PySpark DenseVector, SparseVector implement __eq__ and __hash__ correctly
PySpark DenseVector, SparseVector ```__eq__``` method should use semantics equality, and DenseVector can compared with SparseVector. Implement PySpark DenseVector, SparseVector ```__hash__``` method based on the first 16 entries. That will make PySpark Vector objects can be used in collections. Author: Yanbo Liang <ybliang8@gmail.com> Closes #8166 from yanboliang/spark-9793.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/linalg/__init__.py90
-rw-r--r--python/pyspark/mllib/tests.py32
2 files changed, 107 insertions, 15 deletions
diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py
index 334dc8e38b..380f86e9b4 100644
--- a/python/pyspark/mllib/linalg/__init__.py
+++ b/python/pyspark/mllib/linalg/__init__.py
@@ -25,6 +25,7 @@ SciPy is available in their environment.
import sys
import array
+import struct
if sys.version >= '3':
basestring = str
@@ -122,6 +123,13 @@ def _format_float_list(l):
return [_format_float(x) for x in l]
+def _double_to_long_bits(value):
+ if np.isnan(value):
+ value = float('nan')
+ # pack double into 64 bits, then unpack as long int
+ return struct.unpack('Q', struct.pack('d', value))[0]
+
+
class VectorUDT(UserDefinedType):
"""
SQL user-defined type (UDT) for Vector.
@@ -404,11 +412,31 @@ class DenseVector(Vector):
return "DenseVector([%s])" % (', '.join(_format_float(i) for i in self.array))
def __eq__(self, other):
- return isinstance(other, DenseVector) and np.array_equal(self.array, other.array)
+ if isinstance(other, DenseVector):
+ return np.array_equal(self.array, other.array)
+ elif isinstance(other, SparseVector):
+ if len(self) != other.size:
+ return False
+ return Vectors._equals(list(xrange(len(self))), self.array, other.indices, other.values)
+ return False
def __ne__(self, other):
return not self == other
+ def __hash__(self):
+ size = len(self)
+ result = 31 + size
+ nnz = 0
+ i = 0
+ while i < size and nnz < 128:
+ if self.array[i] != 0:
+ result = 31 * result + i
+ bits = _double_to_long_bits(self.array[i])
+ result = 31 * result + (bits ^ (bits >> 32))
+ nnz += 1
+ i += 1
+ return result
+
def __getattr__(self, item):
return getattr(self.array, item)
@@ -704,20 +732,14 @@ class SparseVector(Vector):
return "SparseVector({0}, {{{1}}})".format(self.size, entries)
def __eq__(self, other):
- """
- Test SparseVectors for equality.
-
- >>> v1 = SparseVector(4, [(1, 1.0), (3, 5.5)])
- >>> v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
- >>> v1 == v2
- True
- >>> v1 != v2
- False
- """
- return (isinstance(other, self.__class__)
- and other.size == self.size
- and np.array_equal(other.indices, self.indices)
- and np.array_equal(other.values, self.values))
+ if isinstance(other, SparseVector):
+ return other.size == self.size and np.array_equal(other.indices, self.indices) \
+ and np.array_equal(other.values, self.values)
+ elif isinstance(other, DenseVector):
+ if self.size != len(other):
+ return False
+ return Vectors._equals(self.indices, self.values, list(xrange(len(other))), other.array)
+ return False
def __getitem__(self, index):
inds = self.indices
@@ -739,6 +761,19 @@ class SparseVector(Vector):
def __ne__(self, other):
return not self.__eq__(other)
+ def __hash__(self):
+ result = 31 + self.size
+ nnz = 0
+ i = 0
+ while i < len(self.values) and nnz < 128:
+ if self.values[i] != 0:
+ result = 31 * result + int(self.indices[i])
+ bits = _double_to_long_bits(self.values[i])
+ result = 31 * result + (bits ^ (bits >> 32))
+ nnz += 1
+ i += 1
+ return result
+
class Vectors(object):
@@ -841,6 +876,31 @@ class Vectors(object):
def zeros(size):
return DenseVector(np.zeros(size))
+ @staticmethod
+ def _equals(v1_indices, v1_values, v2_indices, v2_values):
+ """
+ Check equality between sparse/dense vectors,
+ v1_indices and v2_indices assume to be strictly increasing.
+ """
+ v1_size = len(v1_values)
+ v2_size = len(v2_values)
+ k1 = 0
+ k2 = 0
+ all_equal = True
+ while all_equal:
+ while k1 < v1_size and v1_values[k1] == 0:
+ k1 += 1
+ while k2 < v2_size and v2_values[k2] == 0:
+ k2 += 1
+
+ if k1 >= v1_size or k2 >= v2_size:
+ return k1 >= v1_size and k2 >= v2_size
+
+ all_equal = v1_indices[k1] == v2_indices[k2] and v1_values[k1] == v2_values[k2]
+ k1 += 1
+ k2 += 1
+ return all_equal
+
class Matrix(object):
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 5097c5e8ba..636f9a06ca 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -194,6 +194,38 @@ class VectorTests(MLlibTestCase):
self.assertEquals(3.0, _squared_distance(sv, arr))
self.assertEquals(3.0, _squared_distance(sv, narr))
+ def test_hash(self):
+ v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
+ v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
+ v3 = DenseVector([0.0, 1.0, 0.0, 5.5])
+ v4 = SparseVector(4, [(1, 1.0), (3, 2.5)])
+ self.assertEquals(hash(v1), hash(v2))
+ self.assertEquals(hash(v1), hash(v3))
+ self.assertEquals(hash(v2), hash(v3))
+ self.assertFalse(hash(v1) == hash(v4))
+ self.assertFalse(hash(v2) == hash(v4))
+
+ def test_eq(self):
+ v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
+ v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
+ v3 = DenseVector([0.0, 1.0, 0.0, 5.5])
+ v4 = SparseVector(6, [(1, 1.0), (3, 5.5)])
+ v5 = DenseVector([0.0, 1.0, 0.0, 2.5])
+ v6 = SparseVector(4, [(1, 1.0), (3, 2.5)])
+ self.assertEquals(v1, v2)
+ self.assertEquals(v1, v3)
+ self.assertFalse(v2 == v4)
+ self.assertFalse(v1 == v5)
+ self.assertFalse(v1 == v6)
+
+ def test_equals(self):
+ indices = [1, 2, 4]
+ values = [1., 3., 2.]
+ self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.]))
+ self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.]))
+ self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.]))
+ self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.]))
+
def test_conversion(self):
# numpy arrays should be automatically upcast to float64
# tests for fix of [SPARK-5089]