aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/python/als.py
diff options
context:
space:
mode:
authorSandeep <sandeep@techaddict.me>2014-05-06 17:27:52 -0700
committerMatei Zaharia <matei@databricks.com>2014-05-06 17:27:52 -0700
commita000b5c3b0438c17e9973df4832c320210c29c27 (patch)
tree446bbe902ecd6de05072357f7ef25aaeb0687b73 /examples/src/main/python/als.py
parent39b8b1489ff92697e4aeec997cdc436c7079d6f8 (diff)
downloadspark-a000b5c3b0438c17e9973df4832c320210c29c27.tar.gz
spark-a000b5c3b0438c17e9973df4832c320210c29c27.tar.bz2
spark-a000b5c3b0438c17e9973df4832c320210c29c27.zip
SPARK-1637: Clean up examples for 1.0
- [x] Move all of them into subpackages of org.apache.spark.examples (right now some are in org.apache.spark.streaming.examples, for instance, and others are in org.apache.spark.examples.mllib) - [x] Move Python examples into examples/src/main/python - [x] Update docs to reflect these changes Author: Sandeep <sandeep@techaddict.me> This patch had conflicts when merged, resolved by Committer: Matei Zaharia <matei@databricks.com> Closes #571 from techaddict/SPARK-1637 and squashes the following commits: 47ef86c [Sandeep] Changes based on Discussions on PR, removing use of RawTextHelper from examples 8ed2d3f [Sandeep] Docs Updated for changes, Change for java examples 5f96121 [Sandeep] Move Python examples into examples/src/main/python 0a8dd77 [Sandeep] Move all Scala Examples to org.apache.spark.examples (some are in org.apache.spark.streaming.examples, for instance, and others are in org.apache.spark.examples.mllib)
Diffstat (limited to 'examples/src/main/python/als.py')
-rwxr-xr-xexamples/src/main/python/als.py87
1 files changed, 87 insertions, 0 deletions
diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py
new file mode 100755
index 0000000000..a77dfb2577
--- /dev/null
+++ b/examples/src/main/python/als.py
@@ -0,0 +1,87 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+This example requires numpy (http://www.numpy.org/)
+"""
+from os.path import realpath
+import sys
+
+import numpy as np
+from numpy.random import rand
+from numpy import matrix
+from pyspark import SparkContext
+
+LAMBDA = 0.01 # regularization
+np.random.seed(42)
+
+def rmse(R, ms, us):
+ diff = R - ms * us.T
+ return np.sqrt(np.sum(np.power(diff, 2)) / M * U)
+
+def update(i, vec, mat, ratings):
+ uu = mat.shape[0]
+ ff = mat.shape[1]
+ XtX = matrix(np.zeros((ff, ff)))
+ Xty = np.zeros((ff, 1))
+
+ for j in range(uu):
+ v = mat[j, :]
+ XtX += v.T * v
+ Xty += v.T * ratings[i, j]
+ XtX += np.eye(ff, ff) * LAMBDA * uu
+ return np.linalg.solve(XtX, Xty)
+
+if __name__ == "__main__":
+ if len(sys.argv) < 2:
+ print >> sys.stderr, "Usage: als <master> <M> <U> <F> <iters> <slices>"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonALS", pyFiles=[realpath(__file__)])
+ M = int(sys.argv[2]) if len(sys.argv) > 2 else 100
+ U = int(sys.argv[3]) if len(sys.argv) > 3 else 500
+ F = int(sys.argv[4]) if len(sys.argv) > 4 else 10
+ ITERATIONS = int(sys.argv[5]) if len(sys.argv) > 5 else 5
+ slices = int(sys.argv[6]) if len(sys.argv) > 6 else 2
+
+ print "Running ALS with M=%d, U=%d, F=%d, iters=%d, slices=%d\n" % \
+ (M, U, F, ITERATIONS, slices)
+
+ R = matrix(rand(M, F)) * matrix(rand(U, F).T)
+ ms = matrix(rand(M ,F))
+ us = matrix(rand(U, F))
+
+ Rb = sc.broadcast(R)
+ msb = sc.broadcast(ms)
+ usb = sc.broadcast(us)
+
+ for i in range(ITERATIONS):
+ ms = sc.parallelize(range(M), slices) \
+ .map(lambda x: update(x, msb.value[x, :], usb.value, Rb.value)) \
+ .collect()
+ ms = matrix(np.array(ms)[:, :, 0]) # collect() returns a list, so array ends up being
+ # a 3-d array, we take the first 2 dims for the matrix
+ msb = sc.broadcast(ms)
+
+ us = sc.parallelize(range(U), slices) \
+ .map(lambda x: update(x, usb.value[x, :], msb.value, Rb.value.T)) \
+ .collect()
+ us = matrix(np.array(us)[:, :, 0])
+ usb = sc.broadcast(us)
+
+ error = rmse(R, ms, us)
+ print "Iteration %d:" % i
+ print "\nRMSE: %5.4f\n" % error