diff options
author | Matei Zaharia <matei.zaharia@gmail.com> | 2013-07-15 00:30:10 +0000 |
---|---|---|
committer | Matei Zaharia <matei.zaharia@gmail.com> | 2013-07-15 00:30:10 +0000 |
commit | ed7fd501cf7ece730cbdee6c152b917cf6bfb16a (patch) | |
tree | 23ece1875975f9eb4e92287110a5157ce475036e | |
parent | 10c05937bdf59abee2b9c6c3ee45ea747c5f6ee4 (diff) | |
download | spark-ed7fd501cf7ece730cbdee6c152b917cf6bfb16a.tar.gz spark-ed7fd501cf7ece730cbdee6c152b917cf6bfb16a.tar.bz2 spark-ed7fd501cf7ece730cbdee6c152b917cf6bfb16a.zip |
Make number of blocks in ALS configurable and lower the default
-rw-r--r-- | mllib/src/main/scala/spark/mllib/recommendation/ALS.scala | 9 |
1 file changed, 5 insertions, 4 deletions
```diff
diff --git a/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
index 21eb21276e..2abaf2f2dd 100644
--- a/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
@@ -91,7 +91,7 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
    */
   def train(ratings: RDD[(Int, Int, Double)]): MatrixFactorizationModel = {
     val numBlocks = if (this.numBlocks == -1) {
-      math.max(ratings.context.defaultParallelism, ratings.partitions.size)
+      math.max(ratings.context.defaultParallelism, ratings.partitions.size / 2)
     } else {
       this.numBlocks
     }
@@ -384,12 +384,13 @@ object ALS {
   }
 
   def main(args: Array[String]) {
-    if (args.length != 5) {
-      println("Usage: ALS <master> <ratings_file> <rank> <iterations> <output_dir>")
+    if (args.length != 5 && args.length != 6) {
+      println("Usage: ALS <master> <ratings_file> <rank> <iterations> <output_dir> [<blocks>]")
       System.exit(1)
     }
     val (master, ratingsFile, rank, iters, outputDir) =
       (args(0), args(1), args(2).toInt, args(3).toInt, args(4))
+    val blocks = if (args.length == 6) args(5).toInt else -1
     System.setProperty("spark.serializer", "spark.KryoSerializer")
     System.setProperty("spark.locality.wait", "10000")
     val sc = new SparkContext(master, "ALS")
@@ -397,7 +398,7 @@ object ALS {
       val fields = line.split(',')
       (fields(0).toInt, fields(1).toInt, fields(2).toDouble)
     }
-    val model = ALS.train(ratings, rank, iters)
+    val model = ALS.train(ratings, rank, iters, 0.01, blocks)
     model.userFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") }
          .saveAsTextFile(outputDir + "/userFeatures")
     model.productFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") }
```