diff options
author | BenFradet <benjamin.fradet@gmail.com> | 2016-02-16 13:03:28 +0000 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2016-02-16 13:03:28 +0000 |
commit | 00c72d27bf2e3591c4068fb344fa3edf1662ad81 (patch) | |
tree | b32ed039fd5f4e3775622a9918173df53b943e30 /examples/src/main/java | |
parent | 827ed1c06785692d14857bd41f1fd94a0853874a (diff) | |
download | spark-00c72d27bf2e3591c4068fb344fa3edf1662ad81.tar.gz spark-00c72d27bf2e3591c4068fb344fa3edf1662ad81.tar.bz2 spark-00c72d27bf2e3591c4068fb344fa3edf1662ad81.zip |
[SPARK-12247][ML][DOC] Documentation for spark.ml's ALS and collaborative filtering in general
This documents the implementation of ALS in `spark.ml` with example code in scala, java and python.
Author: BenFradet <benjamin.fradet@gmail.com>
Closes #10411 from BenFradet/SPARK-12247.
Diffstat (limited to 'examples/src/main/java')
-rw-r--r-- | examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java | 125 |
1 files changed, 125 insertions, 0 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java new file mode 100644 index 0000000000..90d2ac2b13 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.io.Serializable; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.recommendation.ALS; +import org.apache.spark.ml.recommendation.ALSModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.types.DataTypes; +// $example off$ + +public class JavaALSExample { + + // $example on$ + public static class Rating implements Serializable { + private int userId; + private int movieId; + private float rating; + private long timestamp; + + public Rating() {} + + public Rating(int userId, int movieId, float rating, long timestamp) { + this.userId = userId; + this.movieId = movieId; + this.rating = rating; + this.timestamp = timestamp; + } + + public int getUserId() { + return userId; + } + + public int getMovieId() { + return movieId; + } + + public float getRating() { + return rating; + } + + public long getTimestamp() { + return timestamp; + } + + public static Rating parseRating(String str) { + String[] fields = str.split("::"); + if (fields.length != 4) { + throw new IllegalArgumentException("Each line must contain 4 fields"); + } + int userId = Integer.parseInt(fields[0]); + int movieId = Integer.parseInt(fields[1]); + float rating = Float.parseFloat(fields[2]); + long timestamp = Long.parseLong(fields[3]); + return new Rating(userId, movieId, rating, timestamp); + } + } + // $example off$ + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaALSExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD<Rating> ratingsRDD = jsc.textFile("data/mllib/als/sample_movielens_ratings.txt") + .map(new Function<String, Rating>() { + public Rating call(String str) { + return Rating.parseRating(str); + } + }); + DataFrame ratings = sqlContext.createDataFrame(ratingsRDD, Rating.class); + DataFrame[] splits = ratings.randomSplit(new double[]{0.8, 0.2}); + DataFrame training = splits[0]; + DataFrame test = splits[1]; + + // Build the recommendation model using ALS on the training data + ALS als = new ALS() + .setMaxIter(5) + .setRegParam(0.01) + .setUserCol("userId") + .setItemCol("movieId") + .setRatingCol("rating"); + ALSModel model = als.fit(training); + + // Evaluate the model by computing the RMSE on the test data + DataFrame rawPredictions = model.transform(test); + DataFrame predictions = rawPredictions + .withColumn("rating", rawPredictions.col("rating").cast(DataTypes.DoubleType)) + .withColumn("prediction", rawPredictions.col("prediction").cast(DataTypes.DoubleType)); + + RegressionEvaluator evaluator = new RegressionEvaluator() + .setMetricName("rmse") + .setLabelCol("rating") + .setPredictionCol("prediction"); + Double rmse = evaluator.evaluate(predictions); + System.out.println("Root-mean-square error = " + rmse); + // $example off$ + jsc.stop(); + } +} |