author     Dongjoon Hyun <dongjoon@apache.org>  2016-06-16 23:02:46 +0200
committer  Sean Owen <sowen@cloudera.com>       2016-06-16 23:02:46 +0200
commit     36110a8306608186696c536028d2776e022d305a (patch)
tree       36fe552ac1aeadb2ec8d7561bf3590407d64e279 /mllib
parent     f9bf15d9bde4df2178f7a8f932c883bb77c46149 (diff)
[SPARK-15922][MLLIB] `toIndexedRowMatrix` should consider the case `cols < offset+colsPerBlock`
## What changes were proposed in this pull request?

SPARK-15922 reports the following scenario throwing an exception due to mismatched vector sizes. This PR handles the exceptional case, `cols < (offset + colsPerBlock)`.

**Before**
```scala
scala> import org.apache.spark.mllib.linalg.distributed._
scala> import org.apache.spark.mllib.linalg._
scala> val rows = IndexedRow(0L, new DenseVector(Array(1,2,3))) :: IndexedRow(1L, new DenseVector(Array(1,2,3))) :: IndexedRow(2L, new DenseVector(Array(1,2,3))) :: Nil
scala> val rdd = sc.parallelize(rows)
scala> val matrix = new IndexedRowMatrix(rdd, 3, 3)
scala> val bmat = matrix.toBlockMatrix
scala> val imat = bmat.toIndexedRowMatrix
scala> imat.rows.collect
...
// java.lang.IllegalArgumentException: requirement failed: Vectors must be the same length!
```

**After**
```scala
...
scala> imat.rows.collect
res0: Array[org.apache.spark.mllib.linalg.distributed.IndexedRow] = Array(IndexedRow(0,[1.0,2.0,3.0]), IndexedRow(1,[1.0,2.0,3.0]), IndexedRow(2,[1.0,2.0,3.0]))
```

## How was this patch tested?

Pass the Jenkins tests (including the above case).

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #13643 from dongjoon-hyun/SPARK-15922.
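For readers who want to rerun the report outside the REPL, the transcript above can be packaged as a small local-mode program. This is an editor's sketch, not part of the patch: the object name, the `local[2]` master, and the app name are illustrative, and it assumes a Spark 2.x build with `spark-mllib` on the classpath.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix}

// Hypothetical standalone version of the REPL reproduction above.
object Spark15922Repro {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("SPARK-15922-repro"))

    val rows = sc.parallelize(Seq(
      IndexedRow(0L, Vectors.dense(1.0, 2.0, 3.0)),
      IndexedRow(1L, Vectors.dense(1.0, 2.0, 3.0)),
      IndexedRow(2L, Vectors.dense(1.0, 2.0, 3.0))))

    // toBlockMatrix() uses a default block width much larger than the 3 columns
    // here, which is exactly the `cols < offset + colsPerBlock` case this patch handles.
    val imat = new IndexedRowMatrix(rows, 3, 3).toBlockMatrix().toIndexedRowMatrix()
    imat.rows.collect().foreach(println)  // succeeds with the patch applied

    sc.stop()
  }
}
```

In the unpatched code path, the `collect()` call is where the failure surfaces, matching the `Vectors must be the same length!` message in the report above.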
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala       2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala  5
2 files changed, 6 insertions(+), 1 deletion(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
index 7a24617781..639295c695 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
@@ -288,7 +288,7 @@ class BlockMatrix @Since("1.3.0") (
vectors.foreach { case (blockColIdx: Int, vec: BV[Double]) =>
val offset = colsPerBlock * blockColIdx
- wholeVector(offset until offset + colsPerBlock) := vec
+ wholeVector(offset until Math.min(cols, offset + colsPerBlock)) := vec
}
new IndexedRow(rowIdx, Vectors.fromBreeze(wholeVector))
}
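To see why the one-line change above is needed, here is a minimal Breeze-only sketch from the editor (illustrative names and sizes, not code from the patch): `wholeVector` only has `cols` entries, so an unclamped slice of width `colsPerBlock` overruns it whenever the last column block is narrower than a full block, which is exactly the `cols < offset + colsPerBlock` case.

```scala
import breeze.linalg.{DenseVector => BDV}

// Minimal sketch of the slice clamp in toIndexedRowMatrix (illustrative values).
object SliceClampSketch {
  def main(args: Array[String]): Unit = {
    val cols = 3               // true width of the matrix
    val colsPerBlock = 1024    // block width, wider than the matrix itself
    val offset = 0             // first (and only) column block
    val vec = BDV(1.0, 2.0, 3.0)              // the block's row, only `cols` long
    val wholeVector = BDV.zeros[Double](cols) // destination row, also `cols` long

    // Unclamped, as before the patch: the slice 0 until 1024 runs past the
    // length-3 destination, so the assignment would fail.
    // wholeVector(offset until offset + colsPerBlock) := vec

    // Clamped, as after the patch: 0 until min(3, 0 + 1024) = 0 until 3.
    wholeVector(offset until Math.min(cols, offset + colsPerBlock)) := vec
    println(wholeVector)       // DenseVector(1.0, 2.0, 3.0)
  }
}
```

Clamping the upper bound keeps the slice the same length as the block's row vector, so the Breeze assignment succeeds.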
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
index e5a2cbbb58..61266f3c78 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
@@ -135,6 +135,11 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(rowMat.numCols() === n)
assert(rowMat.toBreeze() === gridBasedMat.toBreeze())
+ // SPARK-15922: BlockMatrix to IndexedRowMatrix throws an error
+ val bmat = rowMat.toBlockMatrix
+ val imat = bmat.toIndexedRowMatrix
+ imat.rows.collect
+
val rows = 1
val cols = 10