aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorBenFradet <benjamin.fradet@gmail.com>2016-02-16 13:03:28 +0000
committerSean Owen <sowen@cloudera.com>2016-02-16 13:03:28 +0000
commit00c72d27bf2e3591c4068fb344fa3edf1662ad81 (patch)
treeb32ed039fd5f4e3775622a9918173df53b943e30 /examples
parent827ed1c06785692d14857bd41f1fd94a0853874a (diff)
downloadspark-00c72d27bf2e3591c4068fb344fa3edf1662ad81.tar.gz
spark-00c72d27bf2e3591c4068fb344fa3edf1662ad81.tar.bz2
spark-00c72d27bf2e3591c4068fb344fa3edf1662ad81.zip
[SPARK-12247][ML][DOC] Documentation for spark.ml's ALS and collaborative filtering in general
This documents the implementation of ALS in `spark.ml` with example code in scala, java and python. Author: BenFradet <benjamin.fradet@gmail.com> Closes #10411 from BenFradet/SPARK-12247.
Diffstat (limited to 'examples')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java125
-rw-r--r--examples/src/main/python/ml/als_example.py57
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala82
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala182
4 files changed, 264 insertions, 182 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java
new file mode 100644
index 0000000000..90d2ac2b13
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SQLContext;
+
+// $example on$
+import java.io.Serializable;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.ml.evaluation.RegressionEvaluator;
+import org.apache.spark.ml.recommendation.ALS;
+import org.apache.spark.ml.recommendation.ALSModel;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.types.DataTypes;
+// $example off$
+
+public class JavaALSExample {
+
+ // $example on$
+ public static class Rating implements Serializable {
+ private int userId;
+ private int movieId;
+ private float rating;
+ private long timestamp;
+
+ public Rating() {}
+
+ public Rating(int userId, int movieId, float rating, long timestamp) {
+ this.userId = userId;
+ this.movieId = movieId;
+ this.rating = rating;
+ this.timestamp = timestamp;
+ }
+
+ public int getUserId() {
+ return userId;
+ }
+
+ public int getMovieId() {
+ return movieId;
+ }
+
+ public float getRating() {
+ return rating;
+ }
+
+ public long getTimestamp() {
+ return timestamp;
+ }
+
+ public static Rating parseRating(String str) {
+ String[] fields = str.split("::");
+ if (fields.length != 4) {
+ throw new IllegalArgumentException("Each line must contain 4 fields");
+ }
+ int userId = Integer.parseInt(fields[0]);
+ int movieId = Integer.parseInt(fields[1]);
+ float rating = Float.parseFloat(fields[2]);
+ long timestamp = Long.parseLong(fields[3]);
+ return new Rating(userId, movieId, rating, timestamp);
+ }
+ }
+ // $example off$
+
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("JavaALSExample");
+ JavaSparkContext jsc = new JavaSparkContext(conf);
+ SQLContext sqlContext = new SQLContext(jsc);
+
+ // $example on$
+ JavaRDD<Rating> ratingsRDD = jsc.textFile("data/mllib/als/sample_movielens_ratings.txt")
+ .map(new Function<String, Rating>() {
+ public Rating call(String str) {
+ return Rating.parseRating(str);
+ }
+ });
+ DataFrame ratings = sqlContext.createDataFrame(ratingsRDD, Rating.class);
+ DataFrame[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
+ DataFrame training = splits[0];
+ DataFrame test = splits[1];
+
+ // Build the recommendation model using ALS on the training data
+ ALS als = new ALS()
+ .setMaxIter(5)
+ .setRegParam(0.01)
+ .setUserCol("userId")
+ .setItemCol("movieId")
+ .setRatingCol("rating");
+ ALSModel model = als.fit(training);
+
+ // Evaluate the model by computing the RMSE on the test data
+ DataFrame rawPredictions = model.transform(test);
+ DataFrame predictions = rawPredictions
+ .withColumn("rating", rawPredictions.col("rating").cast(DataTypes.DoubleType))
+ .withColumn("prediction", rawPredictions.col("prediction").cast(DataTypes.DoubleType));
+
+ RegressionEvaluator evaluator = new RegressionEvaluator()
+ .setMetricName("rmse")
+ .setLabelCol("rating")
+ .setPredictionCol("prediction");
+ Double rmse = evaluator.evaluate(predictions);
+ System.out.println("Root-mean-square error = " + rmse);
+ // $example off$
+ jsc.stop();
+ }
+}
diff --git a/examples/src/main/python/ml/als_example.py b/examples/src/main/python/ml/als_example.py
new file mode 100644
index 0000000000..f61c8ab5d6
--- /dev/null
+++ b/examples/src/main/python/ml/als_example.py
@@ -0,0 +1,57 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+from pyspark import SparkContext
+from pyspark.sql import SQLContext
+
+# $example on$
+import math
+
+from pyspark.ml.evaluation import RegressionEvaluator
+from pyspark.ml.recommendation import ALS
+from pyspark.sql import Row
+# $example off$
+
+if __name__ == "__main__":
+ sc = SparkContext(appName="ALSExample")
+ sqlContext = SQLContext(sc)
+
+ # $example on$
+ lines = sc.textFile("data/mllib/als/sample_movielens_ratings.txt")
+ parts = lines.map(lambda l: l.split("::"))
+ ratingsRDD = parts.map(lambda p: Row(userId=int(p[0]), movieId=int(p[1]),
+ rating=float(p[2]), timestamp=long(p[3])))
+ ratings = sqlContext.createDataFrame(ratingsRDD)
+ (training, test) = ratings.randomSplit([0.8, 0.2])
+
+ # Build the recommendation model using ALS on the training data
+ als = ALS(maxIter=5, regParam=0.01, userCol="userId", itemCol="movieId", ratingCol="rating")
+ model = als.fit(training)
+
+ # Evaluate the model by computing the RMSE on the test data
+ rawPredictions = model.transform(test)
+ predictions = rawPredictions\
+ .withColumn("rating", rawPredictions.rating.cast("double"))\
+ .withColumn("prediction", rawPredictions.prediction.cast("double"))
+ evaluator =\
+ RegressionEvaluator(metricName="rmse", labelCol="rating", predictionCol="prediction")
+ rmse = evaluator.evaluate(predictions)
+ print("Root-mean-square error = " + str(rmse))
+ # $example off$
+ sc.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala
new file mode 100644
index 0000000000..a79e15c767
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.ml
+
+import org.apache.spark.{SparkConf, SparkContext}
+// $example on$
+import org.apache.spark.ml.evaluation.RegressionEvaluator
+import org.apache.spark.ml.recommendation.ALS
+// $example off$
+import org.apache.spark.sql.SQLContext
+// $example on$
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
+// $example off$
+
+object ALSExample {
+
+ // $example on$
+ case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long)
+ object Rating {
+ def parseRating(str: String): Rating = {
+ val fields = str.split("::")
+ assert(fields.size == 4)
+ Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong)
+ }
+ }
+ // $example off$
+
+ def main(args: Array[String]) {
+ val conf = new SparkConf().setAppName("ALSExample")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ // $example on$
+ val ratings = sc.textFile("data/mllib/als/sample_movielens_ratings.txt")
+ .map(Rating.parseRating)
+ .toDF()
+ val Array(training, test) = ratings.randomSplit(Array(0.8, 0.2))
+
+ // Build the recommendation model using ALS on the training data
+ val als = new ALS()
+ .setMaxIter(5)
+ .setRegParam(0.01)
+ .setUserCol("userId")
+ .setItemCol("movieId")
+ .setRatingCol("rating")
+ val model = als.fit(training)
+
+ // Evaluate the model by computing the RMSE on the test data
+ val predictions = model.transform(test)
+ .withColumn("rating", col("rating").cast(DoubleType))
+ .withColumn("prediction", col("prediction").cast(DoubleType))
+
+ val evaluator = new RegressionEvaluator()
+ .setMetricName("rmse")
+ .setLabelCol("rating")
+ .setPredictionCol("prediction")
+ val rmse = evaluator.evaluate(predictions)
+ println(s"Root-mean-square error = $rmse")
+ // $example off$
+ sc.stop()
+ }
+}
+// scalastyle:on println
+
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala
deleted file mode 100644
index 02ed746954..0000000000
--- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala
+++ /dev/null
@@ -1,182 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// scalastyle:off println
-package org.apache.spark.examples.ml
-
-import scopt.OptionParser
-
-import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.examples.mllib.AbstractParams
-import org.apache.spark.ml.recommendation.ALS
-import org.apache.spark.sql.{Row, SQLContext}
-
-/**
- * An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/).
- * Run with
- * {{{
- * bin/run-example ml.MovieLensALS
- * }}}
- */
-object MovieLensALS {
-
- case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long)
-
- object Rating {
- def parseRating(str: String): Rating = {
- val fields = str.split("::")
- assert(fields.size == 4)
- Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong)
- }
- }
-
- case class Movie(movieId: Int, title: String, genres: Seq[String])
-
- object Movie {
- def parseMovie(str: String): Movie = {
- val fields = str.split("::")
- assert(fields.size == 3)
- Movie(fields(0).toInt, fields(1), fields(2).split("\\|"))
- }
- }
-
- case class Params(
- ratings: String = null,
- movies: String = null,
- maxIter: Int = 10,
- regParam: Double = 0.1,
- rank: Int = 10,
- numBlocks: Int = 10) extends AbstractParams[Params]
-
- def main(args: Array[String]) {
- val defaultParams = Params()
-
- val parser = new OptionParser[Params]("MovieLensALS") {
- head("MovieLensALS: an example app for ALS on MovieLens data.")
- opt[String]("ratings")
- .required()
- .text("path to a MovieLens dataset of ratings")
- .action((x, c) => c.copy(ratings = x))
- opt[String]("movies")
- .required()
- .text("path to a MovieLens dataset of movies")
- .action((x, c) => c.copy(movies = x))
- opt[Int]("rank")
- .text(s"rank, default: ${defaultParams.rank}")
- .action((x, c) => c.copy(rank = x))
- opt[Int]("maxIter")
- .text(s"max number of iterations, default: ${defaultParams.maxIter}")
- .action((x, c) => c.copy(maxIter = x))
- opt[Double]("regParam")
- .text(s"regularization parameter, default: ${defaultParams.regParam}")
- .action((x, c) => c.copy(regParam = x))
- opt[Int]("numBlocks")
- .text(s"number of blocks, default: ${defaultParams.numBlocks}")
- .action((x, c) => c.copy(numBlocks = x))
- note(
- """
- |Example command line to run this app:
- |
- | bin/spark-submit --class org.apache.spark.examples.ml.MovieLensALS \
- | examples/target/scala-*/spark-examples-*.jar \
- | --rank 10 --maxIter 15 --regParam 0.1 \
- | --movies data/mllib/als/sample_movielens_movies.txt \
- | --ratings data/mllib/als/sample_movielens_ratings.txt
- """.stripMargin)
- }
-
- parser.parse(args, defaultParams).map { params =>
- run(params)
- } getOrElse {
- System.exit(1)
- }
- }
-
- def run(params: Params) {
- val conf = new SparkConf().setAppName(s"MovieLensALS with $params")
- val sc = new SparkContext(conf)
- val sqlContext = new SQLContext(sc)
- import sqlContext.implicits._
-
- val ratings = sc.textFile(params.ratings).map(Rating.parseRating).cache()
-
- val numRatings = ratings.count()
- val numUsers = ratings.map(_.userId).distinct().count()
- val numMovies = ratings.map(_.movieId).distinct().count()
-
- println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.")
-
- val splits = ratings.randomSplit(Array(0.8, 0.2), 0L)
- val training = splits(0).cache()
- val test = splits(1).cache()
-
- val numTraining = training.count()
- val numTest = test.count()
- println(s"Training: $numTraining, test: $numTest.")
-
- ratings.unpersist(blocking = false)
-
- val als = new ALS()
- .setUserCol("userId")
- .setItemCol("movieId")
- .setRank(params.rank)
- .setMaxIter(params.maxIter)
- .setRegParam(params.regParam)
- .setNumBlocks(params.numBlocks)
-
- val model = als.fit(training.toDF())
-
- val predictions = model.transform(test.toDF()).cache()
-
- // Evaluate the model.
- // TODO: Create an evaluator to compute RMSE.
- val mse = predictions.select("rating", "prediction").rdd
- .flatMap { case Row(rating: Float, prediction: Float) =>
- val err = rating.toDouble - prediction
- val err2 = err * err
- if (err2.isNaN) {
- None
- } else {
- Some(err2)
- }
- }.mean()
- val rmse = math.sqrt(mse)
- println(s"Test RMSE = $rmse.")
-
- // Inspect false positives.
- // Note: We reference columns in 2 ways:
- // (1) predictions("movieId") lets us specify the movieId column in the predictions
- // DataFrame, rather than the movieId column in the movies DataFrame.
- // (2) $"userId" specifies the userId column in the predictions DataFrame.
- // We could also write predictions("userId") but do not have to since
- // the movies DataFrame does not have a column "userId."
- val movies = sc.textFile(params.movies).map(Movie.parseMovie).toDF()
- val falsePositives = predictions.join(movies)
- .where((predictions("movieId") === movies("movieId"))
- && ($"rating" <= 1) && ($"prediction" >= 4))
- .select($"userId", predictions("movieId"), $"title", $"rating", $"prediction")
- val numFalsePositives = falsePositives.count()
- println(s"Found $numFalsePositives false positives")
- if (numFalsePositives > 0) {
- println(s"Example false positives:")
- falsePositives.limit(100).collect().foreach(println)
- }
-
- sc.stop()
- }
-}
-// scalastyle:on println