author     peng.zhang <peng.zhang@xiaomi.com>      2014-07-22 02:39:07 -0700
committer  Xiangrui Meng <meng@databricks.com>     2014-07-22 02:39:07 -0700
commit     75db1742abf9e08111ddf8f330e6561c5520a86c (patch)
tree       b3da6e8afc7231b2c772e1fdffaaf5f3039c4e53
parent     81fec9922c5a1a44e086fba450c3eea03cddce63 (diff)
[SPARK-2612] [mllib] Fix data skew in ALS
Author: peng.zhang <peng.zhang@xiaomi.com>

Closes #1521 from renozhang/fix-als and squashes the following commits:

b5727a4 [peng.zhang] Remove unneeded argument
1a4f7a0 [peng.zhang] Fix data skew in ALS
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala  11
1 file changed, 5 insertions(+), 6 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index cc56fd6ef2..15e8855db6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -252,14 +252,14 @@ class ALS private (
val YtY = Some(sc.broadcast(computeYtY(users)))
val previousProducts = products
products = updateFeatures(numProductBlocks, users, userOutLinks, productInLinks,
- userPartitioner, rank, lambda, alpha, YtY)
+ rank, lambda, alpha, YtY)
previousProducts.unpersist()
logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations))
products.setName(s"products-$iter").persist()
val XtX = Some(sc.broadcast(computeYtY(products)))
val previousUsers = users
users = updateFeatures(numUserBlocks, products, productOutLinks, userInLinks,
- productPartitioner, rank, lambda, alpha, XtX)
+ rank, lambda, alpha, XtX)
previousUsers.unpersist()
}
} else {
@@ -267,11 +267,11 @@ class ALS private (
// perform ALS update
logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations))
products = updateFeatures(numProductBlocks, users, userOutLinks, productInLinks,
- userPartitioner, rank, lambda, alpha, YtY = None)
+ rank, lambda, alpha, YtY = None)
products.setName(s"products-$iter")
logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations))
users = updateFeatures(numUserBlocks, products, productOutLinks, userInLinks,
- productPartitioner, rank, lambda, alpha, YtY = None)
+ rank, lambda, alpha, YtY = None)
users.setName(s"users-$iter")
}
}
@@ -464,7 +464,6 @@ class ALS private (
products: RDD[(Int, Array[Array[Double]])],
productOutLinks: RDD[(Int, OutLinkBlock)],
userInLinks: RDD[(Int, InLinkBlock)],
- productPartitioner: Partitioner,
rank: Int,
lambda: Double,
alpha: Double,
@@ -477,7 +476,7 @@ class ALS private (
}
}
toSend.zipWithIndex.map{ case (buf, idx) => (idx, (bid, buf.toArray)) }
- }.groupByKey(productPartitioner)
+ }.groupByKey(new HashPartitioner(numUserBlocks))
.join(userInLinks)
.mapValues{ case (messages, inLinkBlock) =>
updateBlock(messages, inLinkBlock, rank, lambda, alpha, YtY)
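
As a quick illustration of the partitioning mismatch behind the skew, the following standalone Scala sketch (not part of the patch; the AlsSkewSketch object, block counts, and data are invented for illustration) groups messages keyed by destination block once with a partitioner sized for the wrong side and once with new HashPartitioner sized to the destination block count, as the patched updateFeatures does, and prints the resulting records per partition.

// AlsSkewSketch.scala -- minimal sketch of the partitioning mismatch, assuming a
// local SparkContext; only HashPartitioner and groupByKey come from the Spark API.
import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.rdd.RDD

object AlsSkewSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("als-skew-sketch"))

    val numUserBlocks = 4     // number of destination (user) blocks for the messages
    val numProductBlocks = 3  // a mismatched block count, standing in for the other side

    // Messages keyed by destination user-block index, analogous to the
    // (idx, (bid, buf.toArray)) pairs built just before groupByKey in updateFeatures.
    val messages = sc.parallelize(0 until 1000).map(i => (i % numUserBlocks, i))

    // Grouping with a partitioner sized for the other side: four destination blocks
    // hashed into three partitions, so one partition receives two blocks' messages.
    val mismatched = messages.groupByKey(new HashPartitioner(numProductBlocks))
    // Grouping as the patched code does: one partition per destination block.
    val matched = messages.groupByKey(new HashPartitioner(numUserBlocks))

    def recordsPerPartition(rdd: RDD[(Int, Iterable[Int])]): String =
      rdd.mapPartitions(it => Iterator(it.map(_._2.size).sum)).collect().mkString(", ")

    println("mismatched partitioner, records per partition: " + recordsPerPartition(mismatched))
    println("matched partitioner,    records per partition: " + recordsPerPartition(matched))

    sc.stop()
  }
}

In this toy run the mismatched grouping piles two blocks' messages into a single partition while the matched one spreads them evenly, which is the kind of imbalance the patch avoids by dropping the caller-supplied partitioner and sizing the HashPartitioner to the number of destination blocks.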