author     Matei Zaharia <matei@eecs.berkeley.edu>    2013-08-22 16:13:46 -0700
committer  Matei Zaharia <matei@eecs.berkeley.edu>    2013-08-22 16:13:46 -0700
commit     215c13dd41d8500835ef00624a0b4ced2253554e (patch)
tree       ff5e0daf28e67e826ade49a5ccd63ecc340c7310 /mllib
parent     46ea0c1b47022b84372396256971157afc07c814 (diff)
Fix code style and a nondeterministic RDD issue in ALS
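A note on the RDD issue: the old code seeded each partition's Random from the clock, so if a partition of the initial factor RDDs was lost and recomputed during fault recovery, it came back with different values than the rest of the computation had already seen. The fix derives each partition's seed deterministically from a global seed and the partition index, making recomputation reproducible. A minimal plain-Scala illustration of the difference (the helper name below is made up for this note, not code from the tree):

    import scala.util.Random

    object SeedingSketch {
      // Stand-in for one partition's factor initialization.
      def initPartition(seed: Option[Long], n: Int = 3): Array[Double] = {
        val rand = seed.map(new Random(_)).getOrElse(new Random())
        Array.fill(n)(rand.nextDouble())
      }

      def main(args: Array[String]): Unit = {
        // Unseeded: two "recomputations" of the same partition almost surely disagree.
        println(initPartition(None).sameElements(initPartition(None)))
        // Seeded from (global seed XOR partition index): recomputation reproduces the same values.
        println(initPartition(Some(42L ^ 3)).sameElements(initPartition(Some(42L ^ 3))))
      }
    }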
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/spark/mllib/recommendation/ALS.scala  31
1 file changed, 20 insertions(+), 11 deletions(-)
diff --git a/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
index 9097f46db9..dbfbf59975 100644
--- a/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
@@ -123,18 +123,27 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
val (userInLinks, userOutLinks) = makeLinkRDDs(numBlocks, ratingsByUserBlock)
val (productInLinks, productOutLinks) = makeLinkRDDs(numBlocks, ratingsByProductBlock)
- // Initialize user and product factors randomly
- var users = userOutLinks.mapPartitions {itr =>
- val rand = new Random()
- itr.map({case (x, y) =>
- (x, y.elementIds.map(u => randomFactor(rank, rand)))
- })
+ // Initialize user and product factors randomly, but use a deterministic seed for each partition
+ // so that fault recovery works
+ val seedGen = new Random()
+ val seed1 = seedGen.nextInt()
+ val seed2 = seedGen.nextInt()
+ // Hash an integer to propagate random bits at all positions, similar to java.util.HashTable
+ def hash(x: Int): Int = {
+ val r = x ^ (x >>> 20) ^ (x >>> 12)
+ r ^ (r >>> 7) ^ (r >>> 4)
}
- var products = productOutLinks.mapPartitions {itr =>
- val rand = new Random()
- itr.map({case (x, y) =>
- (x, y.elementIds.map(u => randomFactor(rank, rand)))
- })
+ var users = userOutLinks.mapPartitionsWithIndex { (index, itr) =>
+ val rand = new Random(hash(seed1 ^ index))
+ itr.map { case (x, y) =>
+ (x, y.elementIds.map(_ => randomFactor(rank, rand)))
+ }
+ }
+ var products = productOutLinks.mapPartitionsWithIndex { (index, itr) =>
+ val rand = new Random(hash(seed2 ^ index))
+ itr.map { case (x, y) =>
+ (x, y.elementIds.map(_ => randomFactor(rank, rand)))
+ }
}
for (iter <- 0 until iterations) {
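For completeness, here is a small end-to-end sketch of the pattern the patch adopts: draw one global seed, mix it with the partition index through the same supplemental hash (the bit-spreading step the in-diff comment attributes to java.util.HashTable), and seed a per-partition Random inside mapPartitionsWithIndex. Everything below is illustrative rather than code from this tree; it assumes the pre-0.8 package layout (spark.SparkContext) used by this branch and a local master, and randomFactor is a made-up stand-in for the private helper in ALS.

    import scala.util.Random
    import spark.SparkContext

    object DeterministicInitSketch {
      // Same bit-mixing as in the patch: spread differences in low bits across the whole word.
      def hash(x: Int): Int = {
        val r = x ^ (x >>> 20) ^ (x >>> 12)
        r ^ (r >>> 7) ^ (r >>> 4)
      }

      // Illustrative stand-in for ALS.randomFactor: a dense factor vector of the given rank.
      def randomFactor(rank: Int, rand: Random): Array[Double] =
        Array.fill(rank)(rand.nextDouble())

      def main(args: Array[String]): Unit = {
        val sc = new SparkContext("local[2]", "DeterministicInitSketch")
        val seed = new Random().nextInt()
        val ids = sc.parallelize(0 until 8, 4)
        // Each partition derives its RNG from (global seed, partition index), so a lost
        // partition that gets recomputed regenerates exactly the same factors.
        val factors = ids.mapPartitionsWithIndex { (index, itr) =>
          val rand = new Random(hash(seed ^ index))
          itr.map(id => (id, randomFactor(10, rand)))
        }
        factors.collect().foreach { case (id, f) =>
          println(id + " -> " + f.take(3).mkString(", "))
        }
        sc.stop()
      }
    }

XORing the raw partition index into the seed without the hash would leave neighbouring partitions with seeds differing only in their low bits; spreading those bits across all positions is exactly what the patch's comment says the hash is for.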