aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala13
1 files changed, 13 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index c098b5458f..96f677db3f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -424,4 +424,17 @@ class MatricesSuite extends FunSuite {
assert(mat.rowIndices.toSeq === Seq(3, 0, 2, 1))
assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0))
}
+
+ test("MatrixUDT") {
+ val dm1 = new DenseMatrix(2, 2, Array(0.9, 1.2, 2.3, 9.8))
+ val dm2 = new DenseMatrix(3, 2, Array(0.0, 1.21, 2.3, 9.8, 9.0, 0.0))
+ val dm3 = new DenseMatrix(0, 0, Array())
+ val sm1 = dm1.toSparse
+ val sm2 = dm2.toSparse
+ val sm3 = dm3.toSparse
+ val mUDT = new MatrixUDT()
+ Seq(dm1, dm2, dm3, sm1, sm2, sm3).foreach {
+ mat => assert(mat.toArray === mUDT.deserialize(mUDT.serialize(mat)).toArray)
+ }
+ }
}