diff options
-rw-r--r-- | mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala | 11 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala | 11 |
2 files changed, 14 insertions, 8 deletions
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala index 0ea687bbcc..f1ecc65af1 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala @@ -454,10 +454,13 @@ class SparseMatrix @Since("2.0.0") ( require(values.length == rowIndices.length, "The number of row indices and values don't match! " + s"values.length: ${values.length}, rowIndices.length: ${rowIndices.length}") - // The Or statement is for the case when the matrix is transposed - require(colPtrs.length == numCols + 1 || colPtrs.length == numRows + 1, "The length of the " + - "column indices should be the number of columns + 1. Currently, colPointers.length: " + - s"${colPtrs.length}, numCols: $numCols") + if (isTransposed) { + require(colPtrs.length == numRows + 1, + s"Expecting ${numRows + 1} colPtrs when numRows = $numRows but got ${colPtrs.length}") + } else { + require(colPtrs.length == numCols + 1, + s"Expecting ${numCols + 1} colPtrs when numCols = $numCols but got ${colPtrs.length}") + } require(values.length == colPtrs.last, "The last value of colPtrs must equal the number of " + s"elements. values.length: ${values.length}, colPtrs.last: ${colPtrs.last}") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index e8f34388cd..4c39cf17f4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -572,10 +572,13 @@ class SparseMatrix @Since("1.3.0") ( require(values.length == rowIndices.length, "The number of row indices and values don't match! " + s"values.length: ${values.length}, rowIndices.length: ${rowIndices.length}") - // The Or statement is for the case when the matrix is transposed - require(colPtrs.length == numCols + 1 || colPtrs.length == numRows + 1, "The length of the " + - "column indices should be the number of columns + 1. Currently, colPointers.length: " + - s"${colPtrs.length}, numCols: $numCols") + if (isTransposed) { + require(colPtrs.length == numRows + 1, + s"Expecting ${numRows + 1} colPtrs when numRows = $numRows but got ${colPtrs.length}") + } else { + require(colPtrs.length == numCols + 1, + s"Expecting ${numCols + 1} colPtrs when numCols = $numCols but got ${colPtrs.length}") + } require(values.length == colPtrs.last, "The last value of colPtrs must equal the number of " + s"elements. values.length: ${values.length}, colPtrs.last: ${colPtrs.last}") |