aboutsummaryrefslogtreecommitdiff
path: root/docs
diff options
context:
space:
mode:
authorHossein Falaki <falaki@gmail.com>2014-01-06 12:43:17 -0800
committerHossein Falaki <falaki@gmail.com>2014-01-06 12:43:17 -0800
commit150089dae12bbba693db4edbfcea360b443637df (patch)
tree20be45a22545d523e9159de5cc22af2813accf65 /docs
parent8b5be0675245e206943574b8c6f6b77018b3561a (diff)
downloadspark-150089dae12bbba693db4edbfcea360b443637df.tar.gz
spark-150089dae12bbba693db4edbfcea360b443637df.tar.bz2
spark-150089dae12bbba693db4edbfcea360b443637df.zip
Added proper evaluation example for collaborative filtering and fixed typo
Diffstat (limited to 'docs')
-rw-r--r--docs/mllib-guide.md12
1 files changed, 8 insertions, 4 deletions
diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md
index 3fd3c91e2a..5f3b676126 100644
--- a/docs/mllib-guide.md
+++ b/docs/mllib-guide.md
@@ -3,7 +3,7 @@ layout: global
title: Machine Learning Library (MLlib)
---
-* Table of contests
+* Table of contents
{:toc}
MLlib is a Spark implementation of some common machine learning (ML)
@@ -403,8 +403,8 @@ Errors.
## Collaborative Filtering
In the following example we load rating data. Each row consists of a user, a product and a rating.
-We use the default ALS.train() method which assumes ratings are explicit. We evaluate the recommendation
-on one example.
+We use the default ALS.train() method which assumes ratings are explicit. We evaluate the
+recommendation by measuring the Mean Squared Error of rating prediction.
{% highlight python %}
from pyspark.mllib.recommendation import ALS
@@ -418,7 +418,11 @@ ratings = data.map(lambda line: array([float(x) for x in line.split(',')]))
model = ALS.train(sc, ratings, 1, 20)
# Evaluate the model on training data
-print("predicted rating of user {0} for item {1} is {2:.6}".format(1, 2, model.predict(1, 2)))
+testdata = ratings.map(lambda p: (int(p[0]), int(p[1])))
+predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2]))
+ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions)
+MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y)/ratesAndPreds.count()
+print("Mean Squared Error = " + str(MSE))
{% endhighlight %}
If the rating matrix is derived from other source of information (i.e., it is inferred from other