path: root/examples
author	Xiangrui Meng <meng@databricks.com>	2015-01-22 22:09:13 -0800
committer	Xiangrui Meng <meng@databricks.com>	2015-01-22 22:09:13 -0800
commit	ea74365b7c5a3ac29cae9ba66f140f1fa5e8d312 (patch)
tree	88fbd65face2e0ea7545d60e614083e6a93ec5ad /examples
parent	e0f7fb7f9f497b34d42f9ba147197cf9ffc51607 (diff)
[SPARK-3541][MLLIB] New ALS implementation with improved storage
This PR adds a new ALS implementation to `spark.ml` using the pipeline API, which should be able to scale to billions of ratings. Compared with the ALS under `spark.mllib`, the new implementation

1. uses the same algorithm,
2. uses float type for ratings,
3. uses primitive arrays to avoid GC,
4. sorts and compresses ratings on each block so that we can solve least squares subproblems one by one using only one normal equation instance (see the sketch below).

The following figure shows a performance comparison on copies of the Amazon Reviews dataset using a 16-node (m3.2xlarge) EC2 cluster (the same setup as in http://databricks.com/blog/2014/07/23/scalable-collaborative-filtering-with-spark-mllib.html):

![als-wip](https://cloud.githubusercontent.com/assets/829644/5659447/4c4ff8e0-96c7-11e4-87a9-73c1c63d07f3.png)

I kept `spark.mllib`'s ALS untouched for easy comparison. If the new implementation works well, I'm going to match the features of the ALS under `spark.mllib` and then make it a wrapper of the new implementation, in a separate PR.

TODO:

- [X] Add unit tests for implicit preferences.

Author: Xiangrui Meng <meng@databricks.com>

Closes #3720 from mengxr/SPARK-3541 and squashes the following commits:

1b9e852 [Xiangrui Meng] fix compile
5129be9 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-3541
dd0d0e8 [Xiangrui Meng] simplify test code
c627de3 [Xiangrui Meng] add tests for implicit feedback
b84f41c [Xiangrui Meng] address comments
a76da7b [Xiangrui Meng] update ALS tests
2a8deb3 [Xiangrui Meng] add some ALS tests
857e876 [Xiangrui Meng] add tests for rating block and encoded block
d3c1ac4 [Xiangrui Meng] rename some classes for better code readability; add more doc and comments
213d163 [Xiangrui Meng] org imports
771baf3 [Xiangrui Meng] chol doc update
ca9ad9d [Xiangrui Meng] add unit tests for chol
b4fd17c [Xiangrui Meng] add unit tests for NormalEquation
d0f99d3 [Xiangrui Meng] add tests for LocalIndexEncoder
80b8e61 [Xiangrui Meng] fix imports
4937fd4 [Xiangrui Meng] update ALS example
56c253c [Xiangrui Meng] rename product to item
bce8692 [Xiangrui Meng] doc for parameters and project the output columns
3f2d81a [Xiangrui Meng] add doc
1efaecf [Xiangrui Meng] add example code
8ae86b5 [Xiangrui Meng] add a working copy of the new ALS implementation
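The fourth point is the core storage trick: within each block the ratings are sorted so that every least-squares subproblem is a contiguous run, and a single preallocated normal-equation accumulator can be reused across all of them. Below is a minimal, self-contained sketch of that idea. It is not the actual `spark.ml` internals; the class name, the packed-triangle layout, and the omitted Cholesky solve are illustrative assumptions.

```scala
// Sketch only: accumulate A^T A (packed lower triangle) and A^T b for one
// subproblem in primitive arrays, then reset and reuse the same instance.
object NormalEquationSketch {

  class NormalEquation(val k: Int) {
    val ata = new Array[Double](k * (k + 1) / 2)
    val atb = new Array[Double](k)

    /** Adds one observation: `a` is the other side's factor vector, `b` the rating. */
    def add(a: Array[Float], b: Float): this.type = {
      require(a.length == k)
      var i = 0
      var pos = 0
      while (i < k) {
        var j = 0
        while (j <= i) {
          ata(pos) += a(i).toDouble * a(j)
          pos += 1
          j += 1
        }
        atb(i) += a(i).toDouble * b
        i += 1
      }
      this
    }

    /** Clears the accumulators so the same instance serves the next subproblem. */
    def reset(): Unit = {
      java.util.Arrays.fill(ata, 0.0)
      java.util.Arrays.fill(atb, 0.0)
    }
  }

  def main(args: Array[String]): Unit = {
    val ne = new NormalEquation(2)
    // Ratings are assumed pre-sorted, so one user's ratings form a contiguous run;
    // here a single user rated two items whose current factor vectors are i1 and i2.
    val i1 = Array(1.0f, 0.5f)
    val i2 = Array(0.2f, 1.5f)
    ne.add(i1, 4.0f).add(i2, 3.0f)
    println(ne.ata.mkString("ata = [", ", ", "]"))
    println(ne.atb.mkString("atb = [", ", ", "]"))
    // A real solver would add the regularization term to the diagonal, run a
    // Cholesky solve to obtain this user's factors, then call reset() and move on.
    ne.reset()
  }
}
```

Reusing one instance backed by primitive arrays is what keeps GC pressure low, which is how the description above argues the new implementation can scale to billions of ratings.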
Diffstat (limited to 'examples')
-rw-r--r--	examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala	174
1 file changed, 174 insertions, 0 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala
new file mode 100644
index 0000000000..cf62772b92
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.recommendation.ALS
+import org.apache.spark.sql.{Row, SQLContext}
+
+/**
+ * An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/).
+ * Run with
+ * {{{
+ * bin/run-example ml.MovieLensALS
+ * }}}
+ */
+object MovieLensALS {
+
+ case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long)
+
+ object Rating {
+ def parseRating(str: String): Rating = {
+ val fields = str.split("::")
+ assert(fields.size == 4)
+ Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong)
+ }
+ }
+
+ case class Movie(movieId: Int, title: String, genres: Seq[String])
+
+ object Movie {
+ def parseMovie(str: String): Movie = {
+ val fields = str.split("::")
+ assert(fields.size == 3)
+ // Genres are "|"-separated; escape the pipe so split treats it literally, not as regex alternation.
+ Movie(fields(0).toInt, fields(1), fields(2).split("\\|"))
+ }
+ }
+
+ case class Params(
+ ratings: String = null,
+ movies: String = null,
+ maxIter: Int = 10,
+ regParam: Double = 0.1,
+ rank: Int = 10,
+ numBlocks: Int = 10) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("MovieLensALS") {
+ head("MovieLensALS: an example app for ALS on MovieLens data.")
+ opt[String]("ratings")
+ .required()
+ .text("path to a MovieLens dataset of ratings")
+ .action((x, c) => c.copy(ratings = x))
+ opt[String]("movies")
+ .required()
+ .text("path to a MovieLens dataset of movies")
+ .action((x, c) => c.copy(movies = x))
+ opt[Int]("rank")
+ .text(s"rank, default: ${defaultParams.rank}}")
+ .action((x, c) => c.copy(rank = x))
+ opt[Int]("maxIter")
+ .text(s"max number of iterations, default: ${defaultParams.maxIter}")
+ .action((x, c) => c.copy(maxIter = x))
+ opt[Double]("regParam")
+ .text(s"regularization parameter, default: ${defaultParams.regParam}")
+ .action((x, c) => c.copy(regParam = x))
+ opt[Int]("numBlocks")
+ .text(s"number of blocks, default: ${defaultParams.numBlocks}")
+ .action((x, c) => c.copy(numBlocks = x))
+ note(
+ """
+ |Example command line to run this app:
+ |
+ | bin/spark-submit --class org.apache.spark.examples.ml.MovieLensALS \
+ | examples/target/scala-*/spark-examples-*.jar \
+ | --rank 10 --maxIter 15 --regParam 0.1 \
+ | --movies path/to/movielens/movies.dat \
+ | --ratings path/to/movielens/ratings.dat
+ """.stripMargin)
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ } getOrElse {
+ System.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"MovieLensALS with $params")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+ import sqlContext._
+
+ val ratings = sc.textFile(params.ratings).map(Rating.parseRating).cache()
+
+ val numRatings = ratings.count()
+ val numUsers = ratings.map(_.userId).distinct().count()
+ val numMovies = ratings.map(_.movieId).distinct().count()
+
+ println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.")
+
+ val splits = ratings.randomSplit(Array(0.8, 0.2), 0L)
+ val training = splits(0).cache()
+ val test = splits(1).cache()
+
+ val numTraining = training.count()
+ val numTest = test.count()
+ println(s"Training: $numTraining, test: $numTest.")
+
+ ratings.unpersist(blocking = false)
+
+ val als = new ALS()
+ .setUserCol("userId")
+ .setItemCol("movieId")
+ .setRank(params.rank)
+ .setMaxIter(params.maxIter)
+ .setRegParam(params.regParam)
+ .setNumBlocks(params.numBlocks)
+
+ val model = als.fit(training)
+
+ val predictions = model.transform(test).cache()
+
+ // Evaluate the model.
+ // TODO: Create an evaluator to compute RMSE.
+ val mse = predictions.select('rating, 'prediction)
+ .flatMap { case Row(rating: Float, prediction: Float) =>
+ val err = rating.toDouble - prediction
+ val err2 = err * err
+ if (err2.isNaN) {
+ None
+ } else {
+ Some(err2)
+ }
+ }.mean()
+ val rmse = math.sqrt(mse)
+ println(s"Test RMSE = $rmse.")
+
+ // Inspect false positives.
+ predictions.registerTempTable("prediction")
+ sc.textFile(params.movies).map(Movie.parseMovie).registerTempTable("movie")
+ sqlContext.sql(
+ """
+ |SELECT userId, prediction.movieId, title, rating, prediction
+ | FROM prediction JOIN movie ON prediction.movieId = movie.movieId
+ | WHERE rating <= 1 AND prediction >= 4
+ | LIMIT 100
+ """.stripMargin)
+ .collect()
+ .foreach(println)
+
+ sc.stop()
+ }
+}