aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorHolden Karau <holden@pigscanfly.ca>2013-08-12 22:08:36 -0700
committerHolden Karau <holden@pigscanfly.ca>2013-08-12 22:08:36 -0700
commit705c9ace2a893168aadfca7d80749f3597d9a24a (patch)
tree33718d2119892a07d6f3e8955bb925c807b6bb65 /mllib
parente2fdac60da8cb9b0ff0191631bf7e37ad3a47c76 (diff)
downloadspark-705c9ace2a893168aadfca7d80749f3597d9a24a.tar.gz
spark-705c9ace2a893168aadfca7d80749f3597d9a24a.tar.bz2
spark-705c9ace2a893168aadfca7d80749f3597d9a24a.zip
Use less instances of the random class during ALS setup
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/spark/mllib/recommendation/ALS.scala21
1 files changed, 14 insertions, 7 deletions
diff --git a/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
index 6c71dc1f32..974046d260 100644
--- a/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
@@ -124,9 +124,18 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
val (productInLinks, productOutLinks) = makeLinkRDDs(numBlocks, ratingsByProductBlock)
// Initialize user and product factors randomly
- val seed = new Random().nextInt()
- var users = userOutLinks.mapValues(_.elementIds.map(u => randomFactor(rank, seed ^ u)))
- var products = productOutLinks.mapValues(_.elementIds.map(p => randomFactor(rank, seed ^ ~p)))
+ var users = userOutLinks.mapPartitions(itr => {
+ val rand = new Random()
+ itr.map({case (x,y) =>
+ (x,y.elementIds.map(u => randomFactor(rank, rand)))
+ })
+ })
+ var products = productOutLinks.mapPartitions(itr => {
+ val rand = new Random()
+ itr.map({case (x,y) =>
+ (x,y.elementIds.map(u => randomFactor(rank, rand)))
+ })
+ })
for (iter <- 0 until iterations) {
// perform ALS update
@@ -213,11 +222,9 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
}
/**
- * Make a random factor vector with the given seed.
- * TODO: Initialize things using mapPartitionsWithIndex to make it faster?
+ * Make a random factor vector with the given random.
*/
- private def randomFactor(rank: Int, seed: Int): Array[Double] = {
- val rand = new Random(seed)
+ private def randomFactor(rank: Int, rand: Random): Array[Double] = {
Array.fill(rank)(rand.nextDouble)
}