From 6765ef98dff070768bbcd585d341ee7664fbe76c Mon Sep 17 00:00:00 2001
From: MechCoder
Date: Wed, 17 Jun 2015 11:10:16 -0700
Subject: [SPARK-6390] [SQL] [MLlib] Port MatrixUDT to PySpark

MatrixUDT was recently coded in scala. This has been ported to PySpark

Author: MechCoder

Closes #6354 from MechCoder/spark-6390 and squashes the following commits:

fc4dc1e [MechCoder] Better error message
c940a44 [MechCoder] Added test
aa9c391 [MechCoder] Add pyUDT to MatrixUDT
62a2a7d [MechCoder] [SPARK-6390] Port MatrixUDT to PySpark
---
 python/pyspark/mllib/tests.py | 34 +++++++++++++++++++++++++++++++++-
 1 file changed, 33 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 36a4c7a540..f4c997261e 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -39,7 +39,7 @@ else:
 from pyspark import SparkContext
 from pyspark.mllib.common import _to_java_object_rdd
 from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
-    DenseMatrix, SparseMatrix, Vectors, Matrices
+    DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
 from pyspark.mllib.regression import LabeledPoint
 from pyspark.mllib.random import RandomRDDs
 from pyspark.mllib.stat import Statistics
@@ -507,6 +507,38 @@ class VectorUDTTests(MLlibTestCase):
                 raise TypeError("expecting a vector but got %r of type %r" % (v, type(v)))
 
 
+class MatrixUDTTests(MLlibTestCase):
+
+    dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10])
+    dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True)
+    sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0])
+    sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True)
+    udt = MatrixUDT()
+
+    def test_json_schema(self):
+        self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt)
+
+    def test_serialization(self):
+        for m in [self.dm1, self.dm2, self.sm1, self.sm2]:
+            self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m)))
+
+    def test_infer_schema(self):
+        sqlCtx = SQLContext(self.sc)
+        rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)])
+        df = rdd.toDF()
+        schema = df.schema
+        self.assertTrue(schema.fields[1].dataType, self.udt)
+        matrices = df.map(lambda x: x._2).collect()
+        self.assertEqual(len(matrices), 2)
+        for m in matrices:
+            if isinstance(m, DenseMatrix):
+                self.assertTrue(m, self.dm1)
+            elif isinstance(m, SparseMatrix):
+                self.assertTrue(m, self.sm1)
+            else:
+                raise ValueError("Expected a matrix but got type %r" % type(m))
+
+
 @unittest.skipIf(not _have_scipy, "SciPy not installed")
 class SciPyTests(MLlibTestCase):
 
-- 
cgit v1.2.3
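
A minimal sketch of what the new tests exercise, assuming a PySpark build that includes this patch; every call mirrors the test code added above, run from an interactive PySpark session:

    from pyspark.mllib.linalg import DenseMatrix, SparseMatrix, MatrixUDT

    # Round-trip dense and sparse matrices through the UDT's SQL representation.
    udt = MatrixUDT()
    dm = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10])
    sm = SparseMatrix(1, 1, [0, 1], [0], [2.0])
    for m in (dm, sm):
        assert udt.deserialize(udt.serialize(m)) == m

    # The UDT's JSON schema also round-trips, which is what allows a DataFrame
    # to carry and re-infer a matrix-typed column.
    assert MatrixUDT.fromJson(udt.jsonValue()) == udt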