path: root/python/pyspark/mllib/tests.py
author     MechCoder <manojkumarsivaraj334@gmail.com>    2015-06-17 11:10:16 -0700
committer  Davies Liu <davies@databricks.com>            2015-06-17 11:10:16 -0700
commit     6765ef98dff070768bbcd585d341ee7664fbe76c (patch)
tree       9e88fed33b78a098e8fb91b276eefe0957644e72 /python/pyspark/mllib/tests.py
parent     104f30c36f3d44b7567f6f77adb92e0a96494541 (diff)
[SPARK-6390] [SQL] [MLlib] Port MatrixUDT to PySpark
MatrixUDT was recently implemented in Scala. This has been ported to PySpark.

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #6354 from MechCoder/spark-6390 and squashes the following commits:

fc4dc1e [MechCoder] Better error message
c940a44 [MechCoder] Added test
aa9c391 [MechCoder] Add pyUDT to MatrixUDT
62a2a7d [MechCoder] [SPARK-6390] Port MatrixUDT to PySpark
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r--  python/pyspark/mllib/tests.py | 34
1 file changed, 33 insertions(+), 1 deletion(-)
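Before the diff itself, a rough standalone sketch of what this port enables: with a pyUDT attached to MatrixUDT, DenseMatrix and SparseMatrix values can be stored in DataFrame columns through schema inference, which is exactly what the new test_infer_schema below exercises. The local SparkContext and app name here are illustrative assumptions only; the test suite instead uses the shared context provided by MLlibTestCase.

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.mllib.linalg import DenseMatrix, SparseMatrix

sc = SparkContext("local[2]", "MatrixUDT example")   # hypothetical driver setup
sqlCtx = SQLContext(sc)                              # creating it enables rdd.toDF()
rdd = sc.parallelize([("dense", DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10])),
                      ("sparse", SparseMatrix(1, 1, [0, 1], [0], [2.0]))])
df = rdd.toDF()
df.printSchema()   # the second column should now report the matrix UDT as its type
sc.stop()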
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 36a4c7a540..f4c997261e 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -39,7 +39,7 @@ else:
from pyspark import SparkContext
from pyspark.mllib.common import _to_java_object_rdd
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
-    DenseMatrix, SparseMatrix, Vectors, Matrices
+    DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
@@ -507,6 +507,38 @@ class VectorUDTTests(MLlibTestCase):
raise TypeError("expecting a vector but got %r of type %r" % (v, type(v)))
+class MatrixUDTTests(MLlibTestCase):
+
+    dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10])
+    dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True)
+    sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0])
+    sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True)
+    udt = MatrixUDT()
+
+    def test_json_schema(self):
+        self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt)
+
+    def test_serialization(self):
+        for m in [self.dm1, self.dm2, self.sm1, self.sm2]:
+            self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m)))
+
+    def test_infer_schema(self):
+        sqlCtx = SQLContext(self.sc)
+        rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)])
+        df = rdd.toDF()
+        schema = df.schema
+        self.assertEqual(schema.fields[1].dataType, self.udt)
+        matrices = df.map(lambda x: x._2).collect()
+        self.assertEqual(len(matrices), 2)
+        for m in matrices:
+            if isinstance(m, DenseMatrix):
+                self.assertEqual(m, self.dm1)
+            elif isinstance(m, SparseMatrix):
+                self.assertEqual(m, self.sm1)
+            else:
+                raise ValueError("Expected a matrix but got type %r" % type(m))
+
+
@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(MLlibTestCase):
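For readers less familiar with the Python UDT interface exercised above, the following minimal sketch walks through the two round trips that test_json_schema and test_serialization cover. It assumes Spark 1.4+, where pyspark.mllib.linalg exposes MatrixUDT, and should not require a running SparkContext since no RDDs or DataFrames are involved.

from pyspark.mllib.linalg import DenseMatrix, MatrixUDT

udt = MatrixUDT()
m = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10])
# value round trip: convert to the SQL-internal representation and back
assert udt.deserialize(udt.serialize(m)) == m
# schema round trip: the JSON form of the UDT reconstructs an equal UDT
assert MatrixUDT.fromJson(udt.jsonValue()) == udt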