aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/tests.py
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-09-14 21:37:43 -0700
committerXiangrui Meng <meng@databricks.com>2015-09-14 21:37:43 -0700
commit4ae4d54794778042b2cc983e52757edac02412ab (patch)
tree095520b9ef546740fd0625de7c7572d29a79c743 /python/pyspark/mllib/tests.py
parent55204181004c105c7a3e8c31a099b37e48bfd953 (diff)
downloadspark-4ae4d54794778042b2cc983e52757edac02412ab.tar.gz
spark-4ae4d54794778042b2cc983e52757edac02412ab.tar.bz2
spark-4ae4d54794778042b2cc983e52757edac02412ab.zip
[SPARK-9793] [MLLIB] [PYSPARK] PySpark DenseVector, SparseVector implement __eq__ and __hash__ correctly
PySpark DenseVector, SparseVector ```__eq__``` method should use semantics equality, and DenseVector can compared with SparseVector. Implement PySpark DenseVector, SparseVector ```__hash__``` method based on the first 16 entries. That will make PySpark Vector objects can be used in collections. Author: Yanbo Liang <ybliang8@gmail.com> Closes #8166 from yanboliang/spark-9793.
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r--python/pyspark/mllib/tests.py32
1 files changed, 32 insertions, 0 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 5097c5e8ba..636f9a06ca 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -194,6 +194,38 @@ class VectorTests(MLlibTestCase):
self.assertEquals(3.0, _squared_distance(sv, arr))
self.assertEquals(3.0, _squared_distance(sv, narr))
+ def test_hash(self):
+ v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
+ v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
+ v3 = DenseVector([0.0, 1.0, 0.0, 5.5])
+ v4 = SparseVector(4, [(1, 1.0), (3, 2.5)])
+ self.assertEquals(hash(v1), hash(v2))
+ self.assertEquals(hash(v1), hash(v3))
+ self.assertEquals(hash(v2), hash(v3))
+ self.assertFalse(hash(v1) == hash(v4))
+ self.assertFalse(hash(v2) == hash(v4))
+
+ def test_eq(self):
+ v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
+ v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
+ v3 = DenseVector([0.0, 1.0, 0.0, 5.5])
+ v4 = SparseVector(6, [(1, 1.0), (3, 5.5)])
+ v5 = DenseVector([0.0, 1.0, 0.0, 2.5])
+ v6 = SparseVector(4, [(1, 1.0), (3, 2.5)])
+ self.assertEquals(v1, v2)
+ self.assertEquals(v1, v3)
+ self.assertFalse(v2 == v4)
+ self.assertFalse(v1 == v5)
+ self.assertFalse(v1 == v6)
+
+ def test_equals(self):
+ indices = [1, 2, 4]
+ values = [1., 3., 2.]
+ self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.]))
+ self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.]))
+ self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.]))
+ self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.]))
+
def test_conversion(self):
# numpy arrays should be automatically upcast to float64
# tests for fix of [SPARK-5089]