author     Matei Zaharia <matei.zaharia@gmail.com>    2013-07-15 02:54:11 +0000
committer  Matei Zaharia <matei.zaharia@gmail.com>    2013-07-15 02:54:11 +0000
commit     4698a0d6886905ef21cbd52e108d0dcab3df12df (patch)
tree       08a8d2fe4cf24f2f1d17f5bdb33315e3fd1921e5 /mllib
parent     d47c16f78d5cb935bd4022c9bed8376691371682 (diff)
download   spark-4698a0d6886905ef21cbd52e108d0dcab3df12df.tar.gz
           spark-4698a0d6886905ef21cbd52e108d0dcab3df12df.tar.bz2
           spark-4698a0d6886905ef21cbd52e108d0dcab3df12df.zip
Shuffle ratings in a more efficient way at start of ALS
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/spark/mllib/recommendation/ALS.scala  18
1 file changed, 14 insertions(+), 4 deletions(-)
diff --git a/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
index 2abaf2f2dd..4c18cbdc6b 100644
--- a/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
@@ -6,8 +6,10 @@ import scala.util.Sorting
 import spark.{HashPartitioner, Partitioner, SparkContext, RDD}
 import spark.storage.StorageLevel
+import spark.KryoRegistrator
 import spark.SparkContext._
 
+import com.esotericsoftware.kryo.Kryo
 import org.jblas.{DoubleMatrix, SimpleBlas, Solve}
@@ -98,8 +100,8 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
     val partitioner = new HashPartitioner(numBlocks)
 
-    val ratingsByUserBlock = ratings.map{ case (u, p, r) => (u % numBlocks, (u, p, r)) }
-    val ratingsByProductBlock = ratings.map{ case (u, p, r) => (p % numBlocks, (p, u, r)) }
+    val ratingsByUserBlock = ratings.map{ case (u, p, r) => (u % numBlocks, Rating(u, p, r)) }
+    val ratingsByProductBlock = ratings.map{ case (u, p, r) => (p % numBlocks, Rating(p, u, r)) }
 
     val (userInLinks, userOutLinks) = makeLinkRDDs(numBlocks, ratingsByUserBlock)
     val (productInLinks, productOutLinks) = makeLinkRDDs(numBlocks, ratingsByProductBlock)
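The hunk above keys each rating by its owning block (u % numBlocks for users, p % numBlocks for products) and wraps the fields in the Rating case class already present in ALS.scala, rather than in a raw (Int, Int, Double) tuple, so the shuffled records are all instances of a single Kryo-registered class. A minimal plain-Scala sketch of the blocking scheme, where groupBy stands in for the hash-partitioned shuffle and this Rating is assumed to mirror the one in the file:

    case class Rating(user: Int, product: Int, rating: Double)

    val numBlocks = 4
    val triples = Seq((0, 10, 5.0), (1, 11, 3.0), (5, 10, 4.0))

    // Assign each (user, product, rating) triple to block u % numBlocks.
    val ratingsByUserBlock: Map[Int, Seq[Rating]] =
      triples
        .map { case (u, p, r) => (u % numBlocks, Rating(u, p, r)) }
        .groupBy(_._1)
        .map { case (block, pairs) => (block, pairs.map(_._2)) }
    // Users 1 and 5 land in the same block, since 5 % 4 == 1.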
@@ -179,12 +181,12 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
    * the users (or (blockId, (p, u, r)) for the products). We create these simultaneously to avoid
    * having to shuffle the (blockId, (u, p, r)) RDD twice, or to cache it.
    */
-  private def makeLinkRDDs(numBlocks: Int, ratings: RDD[(Int, (Int, Int, Double))])
+  private def makeLinkRDDs(numBlocks: Int, ratings: RDD[(Int, Rating)])
     : (RDD[(Int, InLinkBlock)], RDD[(Int, OutLinkBlock)]) =
   {
     val grouped = ratings.partitionBy(new HashPartitioner(numBlocks))
     val links = grouped.mapPartitionsWithIndex((blockId, elements) => {
-      val ratings = elements.map{case (k, t) => Rating(t._1, t._2, t._3)}.toArray
+      val ratings = elements.map{_._2}.toArray
       val inLinkBlock = makeInLinkBlock(numBlocks, ratings)
       val outLinkBlock = makeOutLinkBlock(numBlocks, ratings)
       Iterator.single((blockId, (inLinkBlock, outLinkBlock)))
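As the doc comment above notes, a single partitionBy shuffle co-locates a block's ratings, and mapPartitionsWithIndex then materializes them into one array that feeds both makeInLinkBlock and makeOutLinkBlock; because the elements now already carry Rating objects, the old per-element tuple-to-Rating copy disappears. A plain-Scala sketch of that materialize-once, derive-twice pattern (inLinks/outLinks here are simplified stand-ins, not the real InLinkBlock/OutLinkBlock structures):

    case class Rating(user: Int, product: Int, rating: Double)

    // Two views derived from the same materialized array.
    def inLinks(rs: Array[Rating]): Map[Int, Array[Int]] =
      rs.groupBy(_.user).map { case (u, a) => (u, a.map(_.product)) }
    def outLinks(rs: Array[Rating]): Map[Int, Array[Int]] =
      rs.groupBy(_.product).map { case (p, a) => (p, a.map(_.user)) }

    val rs = Array(Rating(0, 10, 5.0), Rating(0, 11, 3.0), Rating(1, 10, 4.0))
    val (in, out) = (inLinks(rs), outLinks(rs))  // one pass to build rs, two uses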
@@ -383,6 +385,12 @@ object ALS {
     train(ratings, rank, iterations, 0.01, -1)
   }
 
+  private class ALSRegistrator extends KryoRegistrator {
+    override def registerClasses(kryo: Kryo) {
+      kryo.register(classOf[Rating])
+    }
+  }
+
   def main(args: Array[String]) {
     if (args.length != 5 && args.length != 6) {
       println("Usage: ALS <master> <ratings_file> <rank> <iterations> <output_dir> [<blocks>]")
@@ -392,6 +400,8 @@ object ALS {
       (args(0), args(1), args(2).toInt, args(3).toInt, args(4))
     val blocks = if (args.length == 6) args(5).toInt else -1
     System.setProperty("spark.serializer", "spark.KryoSerializer")
+    System.setProperty("spark.kryo.registrator", classOf[ALSRegistrator].getName)
+    System.setProperty("spark.kryo.referenceTracking", "false")
     System.setProperty("spark.locality.wait", "10000")
     val sc = new SparkContext(master, "ALS")
     val ratings = sc.textFile(ratingsFile).map { line =>
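Setting spark.kryo.referenceTracking to false disables Kryo's identity-map bookkeeping for shared and cyclic references, which is safe for flat value objects like Rating and removes a per-object lookup on the serialization hot path. A minimal sketch of the underlying Kryo toggle (assuming, as in Spark's KryoSerializer, that the property maps to Kryo's setReferences):

    import com.esotericsoftware.kryo.Kryo

    case class Rating(user: Int, product: Int, rating: Double)

    val kryo = new Kryo()
    kryo.setReferences(false)       // the toggle behind spark.kryo.referenceTracking
    kryo.register(classOf[Rating])

The trade-off is that with tracking off, an object referenced twice is written twice; that is irrelevant for short-lived, non-nested records like these ratings.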