diff options
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r-- | python/pyspark/mllib/tests.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 9fa4d6f6a2..8332f8e061 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -33,7 +33,8 @@ if sys.version_info[:2] <= (2, 6): else: import unittest -from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector +from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ + DenseMatrix from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics @@ -62,6 +63,7 @@ def _squared_distance(a, b): class VectorTests(PySparkTestCase): def _test_serialize(self, v): + self.assertEqual(v, ser.loads(ser.dumps(v))) jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v))) nv = ser.loads(str(self.sc._jvm.SerDe.dumps(jvec))) self.assertEqual(v, nv) @@ -75,6 +77,8 @@ class VectorTests(PySparkTestCase): self._test_serialize(DenseVector(array([1., 2., 3., 4.]))) self._test_serialize(DenseVector(pyarray.array('d', range(10)))) self._test_serialize(SparseVector(4, {1: 1, 3: 2})) + self._test_serialize(SparseVector(3, {})) + self._test_serialize(DenseMatrix(2, 3, range(6))) def test_dot(self): sv = SparseVector(4, {1: 1, 3: 2}) |