aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/python/als.py
diff options
context:
space:
mode:
authorReynold Xin <rxin@apache.org>2014-05-25 14:48:27 -0700
committerReynold Xin <rxin@apache.org>2014-05-25 14:48:27 -0700
commitd79c2b28e17ec0b15198aaedd2e1f403d81f717e (patch)
tree1917d4285692d387de250f8ee8192f794bb2966c /examples/src/main/python/als.py
parent55fddf9cc0fe420d5396b0e730c8413b2f23d636 (diff)
downloadspark-d79c2b28e17ec0b15198aaedd2e1f403d81f717e.tar.gz
spark-d79c2b28e17ec0b15198aaedd2e1f403d81f717e.tar.bz2
spark-d79c2b28e17ec0b15198aaedd2e1f403d81f717e.zip
Fix PEP8 violations in examples/src/main/python.
Author: Reynold Xin <rxin@apache.org> Closes #870 from rxin/examples-python-pep8 and squashes the following commits: 2829e84 [Reynold Xin] Fix PEP8 violations in examples/src/main/python.
Diffstat (limited to 'examples/src/main/python/als.py')
-rwxr-xr-xexamples/src/main/python/als.py20
1 files changed, 12 insertions, 8 deletions
diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py
index f0b46cd28b..1a7c4c51f4 100755
--- a/examples/src/main/python/als.py
+++ b/examples/src/main/python/als.py
@@ -29,22 +29,25 @@ from pyspark import SparkContext
LAMBDA = 0.01 # regularization
np.random.seed(42)
+
def rmse(R, ms, us):
diff = R - ms * us.T
return np.sqrt(np.sum(np.power(diff, 2)) / M * U)
+
def update(i, vec, mat, ratings):
uu = mat.shape[0]
ff = mat.shape[1]
-
+
XtX = mat.T * mat
Xty = mat.T * ratings[i, :].T
-
+
for j in range(ff):
- XtX[j,j] += LAMBDA * uu
-
+ XtX[j, j] += LAMBDA * uu
+
return np.linalg.solve(XtX, Xty)
+
if __name__ == "__main__":
"""
Usage: als [M] [U] [F] [iterations] [slices]"
@@ -57,10 +60,10 @@ if __name__ == "__main__":
slices = int(sys.argv[5]) if len(sys.argv) > 5 else 2
print "Running ALS with M=%d, U=%d, F=%d, iters=%d, slices=%d\n" % \
- (M, U, F, ITERATIONS, slices)
+ (M, U, F, ITERATIONS, slices)
R = matrix(rand(M, F)) * matrix(rand(U, F).T)
- ms = matrix(rand(M ,F))
+ ms = matrix(rand(M, F))
us = matrix(rand(U, F))
Rb = sc.broadcast(R)
@@ -71,8 +74,9 @@ if __name__ == "__main__":
ms = sc.parallelize(range(M), slices) \
.map(lambda x: update(x, msb.value[x, :], usb.value, Rb.value)) \
.collect()
- ms = matrix(np.array(ms)[:, :, 0]) # collect() returns a list, so array ends up being
- # a 3-d array, we take the first 2 dims for the matrix
+ # collect() returns a list, so array ends up being
+ # a 3-d array, we take the first 2 dims for the matrix
+ ms = matrix(np.array(ms)[:, :, 0])
msb = sc.broadcast(ms)
us = sc.parallelize(range(U), slices) \