aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r--python/pyspark/mllib/tests.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 9fa4d6f6a2..8332f8e061 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -33,7 +33,8 @@ if sys.version_info[:2] <= (2, 6):
else:
import unittest
-from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector
+from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
+ DenseMatrix
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
@@ -62,6 +63,7 @@ def _squared_distance(a, b):
class VectorTests(PySparkTestCase):
def _test_serialize(self, v):
+ self.assertEqual(v, ser.loads(ser.dumps(v)))
jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v)))
nv = ser.loads(str(self.sc._jvm.SerDe.dumps(jvec)))
self.assertEqual(v, nv)
@@ -75,6 +77,8 @@ class VectorTests(PySparkTestCase):
self._test_serialize(DenseVector(array([1., 2., 3., 4.])))
self._test_serialize(DenseVector(pyarray.array('d', range(10))))
self._test_serialize(SparseVector(4, {1: 1, 3: 2}))
+ self._test_serialize(SparseVector(3, {}))
+ self._test_serialize(DenseMatrix(2, 3, range(6)))
def test_dot(self):
sv = SparseVector(4, {1: 1, 3: 2})