From 215c13dd41d8500835ef00624a0b4ced2253554e Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 22 Aug 2013 16:13:46 -0700 Subject: Fix code style and a nondeterministic RDD issue in ALS --- .../scala/spark/mllib/recommendation/ALS.scala | 31 ++++++++++++++-------- 1 file changed, 20 insertions(+), 11 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala index 9097f46db9..dbfbf59975 100644 --- a/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala @@ -123,18 +123,27 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l val (userInLinks, userOutLinks) = makeLinkRDDs(numBlocks, ratingsByUserBlock) val (productInLinks, productOutLinks) = makeLinkRDDs(numBlocks, ratingsByProductBlock) - // Initialize user and product factors randomly - var users = userOutLinks.mapPartitions {itr => - val rand = new Random() - itr.map({case (x, y) => - (x, y.elementIds.map(u => randomFactor(rank, rand))) - }) + // Initialize user and product factors randomly, but use a deterministic seed for each partition + // so that fault recovery works + val seedGen = new Random() + val seed1 = seedGen.nextInt() + val seed2 = seedGen.nextInt() + // Hash an integer to propagate random bits at all positions, similar to java.util.HashMap + def hash(x: Int): Int = { + val r = x ^ (x >>> 20) ^ (x >>> 12) + r ^ (r >>> 7) ^ (r >>> 4) } - var products = productOutLinks.mapPartitions {itr => - val rand = new Random() - itr.map({case (x, y) => - (x, y.elementIds.map(u => randomFactor(rank, rand))) - }) + var users = userOutLinks.mapPartitionsWithIndex { (index, itr) => + val rand = new Random(hash(seed1 ^ index)) + itr.map { case (x, y) => + (x, y.elementIds.map(_ => randomFactor(rank, rand))) + } + } + var products = productOutLinks.mapPartitionsWithIndex { (index, itr) => + val rand = new 
Random(hash(seed2 ^ index)) + itr.map { case (x, y) => + (x, y.elementIds.map(_ => randomFactor(rank, rand))) + } } for (iter <- 0 until iterations) { -- cgit v1.2.3