aboutsummaryrefslogtreecommitdiff
path: root/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala')
-rw-r--r--mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala11
1 files changed, 7 insertions, 4 deletions
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala
index 0ea687bbcc..f1ecc65af1 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala
@@ -454,10 +454,13 @@ class SparseMatrix @Since("2.0.0") (
require(values.length == rowIndices.length, "The number of row indices and values don't match! " +
s"values.length: ${values.length}, rowIndices.length: ${rowIndices.length}")
- // The Or statement is for the case when the matrix is transposed
- require(colPtrs.length == numCols + 1 || colPtrs.length == numRows + 1, "The length of the " +
- "column indices should be the number of columns + 1. Currently, colPointers.length: " +
- s"${colPtrs.length}, numCols: $numCols")
+ if (isTransposed) {
+ require(colPtrs.length == numRows + 1,
+ s"Expecting ${numRows + 1} colPtrs when numRows = $numRows but got ${colPtrs.length}")
+ } else {
+ require(colPtrs.length == numCols + 1,
+ s"Expecting ${numCols + 1} colPtrs when numCols = $numCols but got ${colPtrs.length}")
+ }
require(values.length == colPtrs.last, "The last value of colPtrs must equal the number of " +
s"elements. values.length: ${values.length}, colPtrs.last: ${colPtrs.last}")