diff options
author | Ehsan M.Kermani <ehsanmo1367@gmail.com> | 2016-03-14 19:17:09 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-03-14 19:17:09 -0700 |
commit | 992142b87ed5b507493e4f9fac3f72ba14fafbbc (patch) | |
tree | 366bf4b9c707c0b06dd4dd5bff26148eed5617cb /mllib/src/test | |
parent | 06dec37455c3f800897defee6fad0da623f26050 (diff) | |
download | spark-992142b87ed5b507493e4f9fac3f72ba14fafbbc.tar.gz spark-992142b87ed5b507493e4f9fac3f72ba14fafbbc.tar.bz2 spark-992142b87ed5b507493e4f9fac3f72ba14fafbbc.zip |
[SPARK-11826][MLLIB] Refactor add() and subtract() methods
srowen Could you please check this when you have time?
Author: Ehsan M.Kermani <ehsanmo1367@gmail.com>
Closes #9916 from ehsanmok/JIRA-11826.
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala index d91ba8a6fd..f737d2c51a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala @@ -192,6 +192,49 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { assert(sparseBM.add(sparseBM).toBreeze() === sparseBM.add(denseBM).toBreeze()) } + test("subtract") { + val blocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))), + ((0, 1), new DenseMatrix(2, 2, Array(0.0, 1.0, 0.0, 0.0))), + ((1, 0), new DenseMatrix(2, 2, Array(3.0, 0.0, 1.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 2.0, 0.0, 1.0))), + ((2, 0), new DenseMatrix(1, 2, Array(1.0, 0.0))), // Added block that doesn't exist in A + ((2, 1), new DenseMatrix(1, 2, Array(1.0, 5.0)))) + val rdd = sc.parallelize(blocks, numPartitions) + val B = new BlockMatrix(rdd, rowPerPart, colPerPart) + + val expected = BDM( + (0.0, 0.0, 0.0, 0.0), + (0.0, 0.0, 0.0, 0.0), + (0.0, 0.0, 0.0, 0.0), + (0.0, 0.0, 0.0, 0.0), + (-1.0, 0.0, 0.0, 0.0)) + + val AsubtractB = gridBasedMat.subtract(B) + assert(AsubtractB.numRows() === m) + assert(AsubtractB.numCols() === B.numCols()) + assert(AsubtractB.toBreeze() === expected) + + val C = new BlockMatrix(rdd, rowPerPart, colPerPart, m, n + 1) // columns don't match + intercept[IllegalArgumentException] { + gridBasedMat.subtract(C) + } + val largerBlocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(4, 4, new Array[Double](16))), + ((1, 0), new DenseMatrix(1, 4, Array(1.0, 0.0, 1.0, 5.0)))) + val C2 = new BlockMatrix(sc.parallelize(largerBlocks, numPartitions), 4, 4, m, n) + intercept[SparkException] { // partitioning doesn't match + gridBasedMat.subtract(C2) + } + // subtracting BlockMatrices composed of SparseMatrices + val sparseBlocks = for (i <- 0 until 4) yield ((i / 2, i % 2), SparseMatrix.speye(4)) + val denseBlocks = for (i <- 0 until 4) yield ((i / 2, i % 2), DenseMatrix.eye(4)) + val sparseBM = new BlockMatrix(sc.makeRDD(sparseBlocks, 4), 4, 4, 8, 8) + val denseBM = new BlockMatrix(sc.makeRDD(denseBlocks, 4), 4, 4, 8, 8) + + assert(sparseBM.subtract(sparseBM).toBreeze() === sparseBM.subtract(denseBM).toBreeze()) + } + test("multiply") { // identity matrix val blocks: Seq[((Int, Int), Matrix)] = Seq( |