aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author Matei Zaharia <matei.zaharia@gmail.com> 2013-08-22 15:57:28 -0700
committer Matei Zaharia <matei.zaharia@gmail.com> 2013-08-22 15:57:28 -0700
commit 46ea0c1b47022b84372396256971157afc07c814 (patch)
tree 3125964b26e8707e7ecbc53957a152cbf3421ab5 /mllib
parent 9ac3d62cacae34e742c70e7006ffaf7e21880802 (diff)
parent 8fc40818d714651c0fb360a26b64a3ab12559961 (diff)
download spark-46ea0c1b47022b84372396256971157afc07c814.tar.gz
spark-46ea0c1b47022b84372396256971157afc07c814.tar.bz2
spark-46ea0c1b47022b84372396256971157afc07c814.zip
Merge pull request #814 from holdenk/master
Create fewer instances of the random class during ALS initialization.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/spark/mllib/recommendation/ALS.scala 21
1 files changed, 14 insertions, 7 deletions
diff --git a/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
index 6c71dc1f32..9097f46db9 100644
--- a/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
@@ -124,9 +124,18 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
val (productInLinks, productOutLinks) = makeLinkRDDs(numBlocks, ratingsByProductBlock)
// Initialize user and product factors randomly
- val seed = new Random().nextInt()
- var users = userOutLinks.mapValues(_.elementIds.map(u => randomFactor(rank, seed ^ u)))
- var products = productOutLinks.mapValues(_.elementIds.map(p => randomFactor(rank, seed ^ ~p)))
+ var users = userOutLinks.mapPartitions {itr =>
+ val rand = new Random()
+ itr.map({case (x, y) =>
+ (x, y.elementIds.map(u => randomFactor(rank, rand)))
+ })
+ }
+ var products = productOutLinks.mapPartitions {itr =>
+ val rand = new Random()
+ itr.map({case (x, y) =>
+ (x, y.elementIds.map(u => randomFactor(rank, rand)))
+ })
+ }
for (iter <- 0 until iterations) {
// perform ALS update
@@ -213,11 +222,9 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
}
/**
- * Make a random factor vector with the given seed.
- * TODO: Initialize things using mapPartitionsWithIndex to make it faster?
+ * Make a random factor vector with the given random.
*/
- private def randomFactor(rank: Int, seed: Int): Array[Double] = {
- val rand = new Random(seed)
+ private def randomFactor(rank: Int, rand: Random): Array[Double] = {
Array.fill(rank)(rand.nextDouble)
}