author     Ehsan M.Kermani <ehsanmo1367@gmail.com>    2016-03-14 19:17:09 -0700
committer  Joseph K. Bradley <joseph@databricks.com>  2016-03-14 19:17:09 -0700
commit     992142b87ed5b507493e4f9fac3f72ba14fafbbc (patch)
tree       366bf4b9c707c0b06dd4dd5bff26148eed5617cb /mllib/src/test/scala/org/apache
parent     06dec37455c3f800897defee6fad0da623f26050 (diff)
[SPARK-11826][MLLIB] Refactor add() and subtract() methods
srowen Could you please check this when you have time?

Author: Ehsan M.Kermani <ehsanmo1367@gmail.com>

Closes #9916 from ehsanmok/JIRA-11826.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala | 43
1 file changed, 43 insertions(+), 0 deletions(-)
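For orientation, here is a minimal usage sketch of the BlockMatrix add()/subtract() API that the new test exercises. It is not part of the patch: it assumes a live SparkContext named sc, and the block values are illustrative.

// Sketch only, not part of the patch; assumes a running SparkContext `sc`.
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix}
import org.apache.spark.mllib.linalg.distributed.BlockMatrix

// A 2 x 4 matrix stored as two 2 x 2 blocks.
val blocks: Seq[((Int, Int), Matrix)] = Seq(
  ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))),
  ((0, 1), new DenseMatrix(2, 2, Array(0.0, 1.0, 0.0, 0.0))))
val A = new BlockMatrix(sc.parallelize(blocks, 2), 2, 2)

// As with add(), dimensions and block sizes must match; mismatches throw,
// which is what the new test asserts. Subtracting a matrix from itself
// yields an all-zero matrix of the same shape.
val diff = A.subtract(A)
assert(diff.numRows() == 2 && diff.numCols() == 4)
assert(diff.toLocalMatrix().toArray.forall(_ == 0.0))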
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
index d91ba8a6fd..f737d2c51a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
@@ -192,6 +192,49 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(sparseBM.add(sparseBM).toBreeze() === sparseBM.add(denseBM).toBreeze())
   }
 
+  test("subtract") {
+    val blocks: Seq[((Int, Int), Matrix)] = Seq(
+      ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))),
+      ((0, 1), new DenseMatrix(2, 2, Array(0.0, 1.0, 0.0, 0.0))),
+      ((1, 0), new DenseMatrix(2, 2, Array(3.0, 0.0, 1.0, 1.0))),
+      ((1, 1), new DenseMatrix(2, 2, Array(1.0, 2.0, 0.0, 1.0))),
+      ((2, 0), new DenseMatrix(1, 2, Array(1.0, 0.0))), // Added block that doesn't exist in A
+      ((2, 1), new DenseMatrix(1, 2, Array(1.0, 5.0))))
+    val rdd = sc.parallelize(blocks, numPartitions)
+    val B = new BlockMatrix(rdd, rowPerPart, colPerPart)
+
+    val expected = BDM(
+      (0.0, 0.0, 0.0, 0.0),
+      (0.0, 0.0, 0.0, 0.0),
+      (0.0, 0.0, 0.0, 0.0),
+      (0.0, 0.0, 0.0, 0.0),
+      (-1.0, 0.0, 0.0, 0.0))
+
+    val AsubtractB = gridBasedMat.subtract(B)
+    assert(AsubtractB.numRows() === m)
+    assert(AsubtractB.numCols() === B.numCols())
+    assert(AsubtractB.toBreeze() === expected)
+
+    val C = new BlockMatrix(rdd, rowPerPart, colPerPart, m, n + 1) // columns don't match
+    intercept[IllegalArgumentException] {
+      gridBasedMat.subtract(C)
+    }
+    val largerBlocks: Seq[((Int, Int), Matrix)] = Seq(
+      ((0, 0), new DenseMatrix(4, 4, new Array[Double](16))),
+      ((1, 0), new DenseMatrix(1, 4, Array(1.0, 0.0, 1.0, 5.0))))
+    val C2 = new BlockMatrix(sc.parallelize(largerBlocks, numPartitions), 4, 4, m, n)
+    intercept[SparkException] { // partitioning doesn't match
+      gridBasedMat.subtract(C2)
+    }
+    // subtracting BlockMatrices composed of SparseMatrices
+    val sparseBlocks = for (i <- 0 until 4) yield ((i / 2, i % 2), SparseMatrix.speye(4))
+    val denseBlocks = for (i <- 0 until 4) yield ((i / 2, i % 2), DenseMatrix.eye(4))
+    val sparseBM = new BlockMatrix(sc.makeRDD(sparseBlocks, 4), 4, 4, 8, 8)
+    val denseBM = new BlockMatrix(sc.makeRDD(denseBlocks, 4), 4, 4, 8, 8)
+
+    assert(sparseBM.subtract(sparseBM).toBreeze() === sparseBM.subtract(denseBM).toBreeze())
+  }
+
   test("multiply") {
     // identity matrix
     val blocks: Seq[((Int, Int), Matrix)] = Seq(