author    oraviv <oraviv@paypal.com>    2016-07-13 14:47:08 +0100
committer Sean Owen <sowen@cloudera.com>    2016-07-13 14:47:08 +0100
commit    ea06e4ef34c860219a9aeec81816ef53ada96253 (patch)
tree      32fe745a7941c76a6044d12933dac5c6a4772cdf /mllib
parent    51ade51a9fd64fc2fe651c505a286e6f29f59d40 (diff)
[SPARK-16469] enhanced simulate multiply
## What changes were proposed in this pull request?

We have a use case of multiplying very big sparse matrices: roughly 1000x1000 distributed block matrices. The simulate multiply step is O(n^4) (with n = 1000) and takes about 1.5 hours. We modified it slightly to use a classical hash map lookup, and it now runs in about 30 seconds, i.e. O(n^2). A standalone sketch of the technique follows below.

## How was this patch tested?

We added a performance test and verified the reduced running time.

Author: oraviv <oraviv@paypal.com>

Closes #14068 from uzadude/master.
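The core of the change is precomputing, once, a map from block row index to the matching block columns with groupBy, instead of filtering the whole block list inside the per-block loop. The following is a minimal, self-contained Scala sketch of that technique, not the actual BlockMatrix code: the grid size, the getPartition stub, and the names leftBlocks/rightBlocks/rightByRow are illustrative assumptions.

```scala
// Sketch of the simulate-multiply optimization on a toy grid of block indices.
object SimulateMultiplySketch {

  val numPartitions = 8

  // Stand-in for GridPartitioner.getPartition: maps a block coordinate to a partition id.
  def getPartition(block: (Int, Int)): Int =
    math.floorMod(block.hashCode, numPartitions)

  def main(args: Array[String]): Unit = {
    val n = 100 // scaled-down grid; the reported use case had n = 1000
    val leftBlocks  = for (i <- 0 until n; j <- 0 until n) yield (i, j)
    val rightBlocks = for (i <- 0 until n; j <- 0 until n) yield (i, j)

    // Old approach: for every left block, linearly scan all right blocks for
    // those whose row index equals the left block's column index. With n^2
    // blocks on each side this is O(n^4) comparisons overall.
    // val leftDestinations = leftBlocks.map { case (rowIndex, colIndex) =>
    //   val rightCounterparts = rightBlocks.filter(_._1 == colIndex)
    //   val partitions = rightCounterparts.map(b => getPartition((rowIndex, b._2)))
    //   ((rowIndex, colIndex), partitions.toSet)
    // }.toMap

    // New approach: group the right blocks by row index once, then each left
    // block resolves its counterparts with a single hash lookup.
    val rightByRow: Map[Int, Seq[Int]] =
      rightBlocks.groupBy(_._1).mapValues(_.map(_._2)).toMap

    val leftDestinations = leftBlocks.map { case (rowIndex, colIndex) =>
      val rightCounterparts = rightByRow.getOrElse(colIndex, Seq.empty)
      val partitions = rightCounterparts.map(b => getPartition((rowIndex, b)))
      ((rowIndex, colIndex), partitions.toSet)
    }.toMap

    println(s"computed destinations for ${leftDestinations.size} left blocks")
  }
}
```

The same idea is applied symmetrically for rightDestinations in the patch below, grouping the left blocks by column index instead.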
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala | 13
1 file changed, 9 insertions(+), 4 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
index 639295c695..9782350587 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
@@ -426,16 +426,21 @@ class BlockMatrix @Since("1.3.0") (
partitioner: GridPartitioner): (BlockDestinations, BlockDestinations) = {
val leftMatrix = blockInfo.keys.collect() // blockInfo should already be cached
val rightMatrix = other.blocks.keys.collect()
+
+ val rightCounterpartsHelper = rightMatrix.groupBy(_._1).mapValues(_.map(_._2))
val leftDestinations = leftMatrix.map { case (rowIndex, colIndex) =>
- val rightCounterparts = rightMatrix.filter(_._1 == colIndex)
- val partitions = rightCounterparts.map(b => partitioner.getPartition((rowIndex, b._2)))
+ val rightCounterparts = rightCounterpartsHelper.getOrElse(colIndex, Array())
+ val partitions = rightCounterparts.map(b => partitioner.getPartition((rowIndex, b)))
((rowIndex, colIndex), partitions.toSet)
}.toMap
+
+ val leftCounterpartsHelper = leftMatrix.groupBy(_._2).mapValues(_.map(_._1))
val rightDestinations = rightMatrix.map { case (rowIndex, colIndex) =>
- val leftCounterparts = leftMatrix.filter(_._2 == rowIndex)
- val partitions = leftCounterparts.map(b => partitioner.getPartition((b._1, colIndex)))
+ val leftCounterparts = leftCounterpartsHelper.getOrElse(rowIndex, Array())
+ val partitions = leftCounterparts.map(b => partitioner.getPartition((b, colIndex)))
((rowIndex, colIndex), partitions.toSet)
}.toMap
+
(leftDestinations, rightDestinations)
}