From 56096dbaa8cb3ab39bfc2ce5827192313613b010 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 14 Oct 2014 14:42:09 -0700 Subject: SPARK-3803 [MLLIB] ArrayIndexOutOfBoundsException found in executing computePrincipalComponents Avoid overflow in computing n*(n+1)/2 as much as possible; throw explicit error when Gramian computation will fail due to negative array size; warn about large result when computing Gramian too Author: Sean Owen Closes #2801 from srowen/SPARK-3803 and squashes the following commits: b4e6d92 [Sean Owen] Avoid overflow in computing n*(n+1)/2 as much as possible; throw explicit error when Gramian computation will fail due to negative array size; warn about large result when computing Gramian too --- .../spark/mllib/linalg/distributed/RowMatrix.scala | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 8380058cf9..ec2d481dcc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -111,7 +111,10 @@ class RowMatrix( */ def computeGramianMatrix(): Matrix = { val n = numCols().toInt - val nt: Int = n * (n + 1) / 2 + checkNumColumns(n) + // Computes n*(n+1)/2, avoiding overflow in the multiplication. + // This succeeds when n <= 65535, which is checked above + val nt: Int = if (n % 2 == 0) ((n / 2) * (n + 1)) else (n * ((n + 1) / 2)) // Compute the upper triangular part of the gram matrix. val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))( @@ -123,6 +126,16 @@ class RowMatrix( RowMatrix.triuToFull(n, GU.data) } + private def checkNumColumns(cols: Int): Unit = { + if (cols > 65535) { + throw new IllegalArgumentException(s"Argument with more than 65535 cols: $cols") + } + if (cols > 10000) { + val mem = cols * cols * 8 + logWarning(s"$cols columns will require at least $mem bytes of memory!") + } + } + /** * Computes singular value decomposition of this matrix. Denote this matrix by A (m x n). This * will compute matrices U, S, V such that A ~= U * S * V', where S contains the leading k @@ -301,12 +314,7 @@ class RowMatrix( */ def computeCovariance(): Matrix = { val n = numCols().toInt - - if (n > 10000) { - val mem = n * n * java.lang.Double.SIZE / java.lang.Byte.SIZE - logWarning(s"The number of columns $n is greater than 10000! " + - s"We need at least $mem bytes of memory.") - } + checkNumColumns(n) val (m, mean) = rows.treeAggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))( seqOp = (s: (Long, BDV[Double]), v: Vector) => (s._1 + 1L, s._2 += v.toBreeze), -- cgit v1.2.3