aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSean Owen <sowen@cloudera.com>2014-10-14 14:42:09 -0700
committerXiangrui Meng <meng@databricks.com>2014-10-14 14:42:09 -0700
commit56096dbaa8cb3ab39bfc2ce5827192313613b010 (patch)
treeaa6615e61ab5780b7d2c3e48ce8bb4372114307d
parent24b818b971ba715b6796518e4c6afdecb1b16f15 (diff)
downloadspark-56096dbaa8cb3ab39bfc2ce5827192313613b010.tar.gz
spark-56096dbaa8cb3ab39bfc2ce5827192313613b010.tar.bz2
spark-56096dbaa8cb3ab39bfc2ce5827192313613b010.zip
SPARK-3803 [MLLIB] ArrayIndexOutOfBoundsException found in executing computePrincipalComponents
Avoid overflow in computing n*(n+1)/2 as much as possible; throw explicit error when Gramian computation will fail due to negative array size; warn about large result when computing Gramian too Author: Sean Owen <sowen@cloudera.com> Closes #2801 from srowen/SPARK-3803 and squashes the following commits: b4e6d92 [Sean Owen] Avoid overflow in computing n*(n+1)/2 as much as possible; throw explicit error when Gramian computation will fail due to negative array size; warn about large result when computing Gramian too
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala22
1 files changed, 15 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index 8380058cf9..ec2d481dcc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -111,7 +111,10 @@ class RowMatrix(
*/
def computeGramianMatrix(): Matrix = {
val n = numCols().toInt
- val nt: Int = n * (n + 1) / 2
+ checkNumColumns(n)
+ // Computes n*(n+1)/2, avoiding overflow in the multiplication.
+ // This succeeds when n <= 65535, which is checked above
+ val nt: Int = if (n % 2 == 0) ((n / 2) * (n + 1)) else (n * ((n + 1) / 2))
// Compute the upper triangular part of the gram matrix.
val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))(
@@ -123,6 +126,16 @@ class RowMatrix(
RowMatrix.triuToFull(n, GU.data)
}
+ private def checkNumColumns(cols: Int): Unit = {
+ if (cols > 65535) {
+ throw new IllegalArgumentException(s"Argument with more than 65535 cols: $cols")
+ }
+ if (cols > 10000) {
+ val mem = cols * cols * 8
+ logWarning(s"$cols columns will require at least $mem bytes of memory!")
+ }
+ }
+
/**
* Computes singular value decomposition of this matrix. Denote this matrix by A (m x n). This
* will compute matrices U, S, V such that A ~= U * S * V', where S contains the leading k
@@ -301,12 +314,7 @@ class RowMatrix(
*/
def computeCovariance(): Matrix = {
val n = numCols().toInt
-
- if (n > 10000) {
- val mem = n * n * java.lang.Double.SIZE / java.lang.Byte.SIZE
- logWarning(s"The number of columns $n is greater than 10000! " +
- s"We need at least $mem bytes of memory.")
- }
+ checkNumColumns(n)
val (m, mean) = rows.treeAggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))(
seqOp = (s: (Long, BDV[Double]), v: Vector) => (s._1 + 1L, s._2 += v.toBreeze),