diff options
author | MechCoder <manojkumarsivaraj334@gmail.com> | 2015-06-17 11:10:16 -0700 |
---|---|---|
committer | Davies Liu <davies@databricks.com> | 2015-06-17 11:10:16 -0700 |
commit | 6765ef98dff070768bbcd585d341ee7664fbe76c (patch) | |
tree | 9e88fed33b78a098e8fb91b276eefe0957644e72 /python/pyspark/mllib/tests.py | |
parent | 104f30c36f3d44b7567f6f77adb92e0a96494541 (diff) | |
download | spark-6765ef98dff070768bbcd585d341ee7664fbe76c.tar.gz spark-6765ef98dff070768bbcd585d341ee7664fbe76c.tar.bz2 spark-6765ef98dff070768bbcd585d341ee7664fbe76c.zip |
[SPARK-6390] [SQL] [MLlib] Port MatrixUDT to PySpark
MatrixUDT was recently implemented in Scala. This commit ports it to PySpark.
Author: MechCoder <manojkumarsivaraj334@gmail.com>
Closes #6354 from MechCoder/spark-6390 and squashes the following commits:
fc4dc1e [MechCoder] Better error message
c940a44 [MechCoder] Added test
aa9c391 [MechCoder] Add pyUDT to MatrixUDT
62a2a7d [MechCoder] [SPARK-6390] Port MatrixUDT to PySpark
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r-- | python/pyspark/mllib/tests.py | 34 |
1 files changed, 33 insertions, 1 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 36a4c7a540..f4c997261e 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -39,7 +39,7 @@ else:
 from pyspark import SparkContext
 from pyspark.mllib.common import _to_java_object_rdd
 from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
-    DenseMatrix, SparseMatrix, Vectors, Matrices
+    DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
 from pyspark.mllib.regression import LabeledPoint
 from pyspark.mllib.random import RandomRDDs
 from pyspark.mllib.stat import Statistics
@@ -507,6 +507,38 @@ class VectorUDTTests(MLlibTestCase):
                 raise TypeError("expecting a vector but got %r of type %r" % (v, type(v)))
 
 
+class MatrixUDTTests(MLlibTestCase):
+    """Tests for MatrixUDT: JSON schema round-trip, serialization and schema inference."""
+    dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10])
+    dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True)
+    sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0])
+    sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True)
+    udt = MatrixUDT()
+
+    def test_json_schema(self):
+        self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt)
+
+    def test_serialization(self):
+        for m in [self.dm1, self.dm2, self.sm1, self.sm2]:
+            self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m)))
+
+    def test_infer_schema(self):
+        sqlCtx = SQLContext(self.sc)  # NOTE(review): unused; presumably enables rdd.toDF() - confirm
+        rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)])
+        df = rdd.toDF()
+        schema = df.schema
+        self.assertEqual(schema.fields[1].dataType, self.udt)  # compare the type, not truthiness
+        matrices = df.map(lambda x: x._2).collect()
+        self.assertEqual(len(matrices), 2)
+        for m in matrices:
+            if isinstance(m, DenseMatrix):
+                self.assertEqual(m, self.dm1)  # value must survive the DataFrame round-trip
+            elif isinstance(m, SparseMatrix):
+                self.assertEqual(m, self.sm1)  # value must survive the DataFrame round-trip
+            else:
+                raise ValueError("Expected a matrix but got type %r" % type(m))
+
+
 @unittest.skipIf(not _have_scipy, "SciPy not installed")
 class SciPyTests(MLlibTestCase):
 