aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala31
1 files changed, 29 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
index f737d2c51a..f37eaf225a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
@@ -19,10 +19,10 @@ package org.apache.spark.mllib.linalg.distributed
import java.{util => ju}
-import breeze.linalg.{DenseMatrix => BDM}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV}
import org.apache.spark.{SparkException, SparkFunSuite}
-import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix}
+import org.apache.spark.mllib.linalg.{DenseMatrix, DenseVector, Matrices, Matrix, SparseMatrix, SparseVector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -134,6 +134,33 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(rowMat.numRows() === m)
assert(rowMat.numCols() === n)
assert(rowMat.toBreeze() === gridBasedMat.toBreeze())
+
+ val rows = 1
+ val cols = 10
+
+ val matDense = new DenseMatrix(rows, cols,
+ Array(1.0, 1.0, 3.0, 2.0, 5.0, 6.0, 7.0, 1.0, 2.0, 3.0))
+ val matSparse = new SparseMatrix(rows, cols,
+ Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1), Array(0), Array(1.0))
+
+ val vectors: Seq[((Int, Int), Matrix)] = Seq(
+ ((0, 0), matDense),
+ ((1, 0), matSparse))
+
+ val rdd = sc.parallelize(vectors)
+ val B = new BlockMatrix(rdd, rows, cols)
+
+ val C = B.toIndexedRowMatrix.rows.collect
+
+ (C(0).vector.toBreeze, C(1).vector.toBreeze) match {
+ case (denseVector: BDV[Double], sparseVector: BSV[Double]) =>
+ assert(denseVector.length === sparseVector.length)
+
+ assert(matDense.toArray === denseVector.toArray)
+ assert(matSparse.toArray === sparseVector.toArray)
+ case _ =>
+ throw new RuntimeException("IndexedRow returns vectors of unexpected type")
+ }
}
test("toBreeze and toLocalMatrix") {