aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala3
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala7
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala8
4 files changed, 19 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
index 36d8cadd2b..181f507516 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
@@ -102,6 +102,9 @@ class IndexedRowMatrix(
k: Int,
computeU: Boolean = false,
rCond: Double = 1e-9): SingularValueDecomposition[IndexedRowMatrix, Matrix] = {
+
+ val n = numCols().toInt
+ require(k > 0 && k <= n, s"Requested k singular values but got k=$k and numCols=$n.")
val indices = rows.map(_.index)
val svd = toRowMatrix().computeSVD(k, computeU, rCond)
val U = if (computeU) {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index fbd35e372f..d5abba6a4b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -212,7 +212,7 @@ class RowMatrix(
tol: Double,
mode: String): SingularValueDecomposition[RowMatrix, Matrix] = {
val n = numCols().toInt
- require(k > 0 && k <= n, s"Request up to n singular values but got k=$k and n=$n.")
+ require(k > 0 && k <= n, s"Requested k singular values but got k=$k and numCols=$n.")
object SVDMode extends Enumeration {
val LocalARPACK, LocalLAPACK, DistARPACK = Value
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
index e25bc02b06..741cd4997b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
@@ -113,6 +113,13 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext {
assert(closeToZero(U * brzDiag(s) * V.t - localA))
}
+ test("validate k in svd") {
+ val A = new IndexedRowMatrix(indexedRows)
+ intercept[IllegalArgumentException] {
+ A.computeSVD(-1)
+ }
+ }
+
def closeToZero(G: BDM[Double]): Boolean = {
G.valuesIterator.map(math.abs).sum < 1e-6
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
index dbf55ff81c..3309713e91 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
@@ -171,6 +171,14 @@ class RowMatrixSuite extends FunSuite with MLlibTestSparkContext {
}
}
+ test("validate k in svd") {
+ for (mat <- Seq(denseMat, sparseMat)) {
+ intercept[IllegalArgumentException] {
+ mat.computeSVD(-1)
+ }
+ }
+ }
+
def closeToZero(G: BDM[Double]): Boolean = {
G.valuesIterator.map(math.abs).sum < 1e-6
}