aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r-- python/pyspark/mllib/tests.py 34
1 files changed, 33 insertions, 1 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 36a4c7a540..f4c997261e 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -39,7 +39,7 @@ else:
from pyspark import SparkContext
from pyspark.mllib.common import _to_java_object_rdd
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
- DenseMatrix, SparseMatrix, Vectors, Matrices
+ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
@@ -507,6 +507,38 @@ class VectorUDTTests(MLlibTestCase):
raise TypeError("expecting a vector but got %r of type %r" % (v, type(v)))
+class MatrixUDTTests(MLlibTestCase):
+    """Tests for MatrixUDT: JSON schema round trip, serialization and schema inference."""
+
+    # Fixtures: one dense and one sparse matrix in the default layout, plus one of
+    # each built with isTransposed=True.
+    dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10])
+    dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True)
+    sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0])
+    sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True)
+    udt = MatrixUDT()
+
+    def test_json_schema(self):
+        """MatrixUDT compares equal to itself after jsonValue() / fromJson()."""
+        self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt)
+
+    def test_serialization(self):
+        """serialize() followed by deserialize() yields a matrix equal to the input."""
+        for m in [self.dm1, self.dm2, self.sm1, self.sm2]:
+            self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m)))
+
+    def test_infer_schema(self):
+        """A DataFrame built from matrices infers MatrixUDT and returns equal matrices."""
+        # NOTE(review): sqlCtx is never read, but constructing a SQLContext presumably
+        # is what makes rdd.toDF() available -- confirm before removing it.
+        sqlCtx = SQLContext(self.sc)
+        rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)])
+        df = rdd.toDF()
+        schema = df.schema
+        # assertEqual, not assertTrue: assertTrue(a, b) only checks that `a` is truthy
+        # and treats `b` as the failure message, so it would never compare the values.
+        self.assertEqual(schema.fields[1].dataType, self.udt)
+        matrices = df.map(lambda x: x._2).collect()
+        self.assertEqual(len(matrices), 2)
+        for m in matrices:
+            if isinstance(m, DenseMatrix):
+                self.assertEqual(m, self.dm1)
+            elif isinstance(m, SparseMatrix):
+                self.assertEqual(m, self.sm1)
+            else:
+                raise ValueError("Expected a matrix but got type %r" % type(m))
+
+
@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(MLlibTestCase):