aboutsummaryrefslogtreecommitdiff
path: root/examples/src
diff options
context:
space:
mode:
authorEvan Sparks <evan.sparks@gmail.com>2014-05-08 00:24:36 -0400
committerReynold Xin <rxin@apache.org>2014-05-08 00:24:36 -0400
commit6ed7e2cd01955adfbb3960e2986b6d19eaee8717 (patch)
tree001585b295f006c0b93f0d21b7827b544df7bcd3 /examples/src
parent108c4c16cc82af2e161d569d2c23849bdbf4aadb (diff)
downloadspark-6ed7e2cd01955adfbb3960e2986b6d19eaee8717.tar.gz
spark-6ed7e2cd01955adfbb3960e2986b6d19eaee8717.tar.bz2
spark-6ed7e2cd01955adfbb3960e2986b6d19eaee8717.zip
Use numpy directly for matrix multiply.
Using matrix multiply to compute XtX and XtY yields a 5-20x speedup depending on problem size. For example, the following takes 19s locally after this change vs. 5m21s before the change (a 16x speedup): bin/pyspark examples/src/main/python/als.py local[8] 1000 1000 50 10 10 Author: Evan Sparks <evan.sparks@gmail.com> Closes #687 from etrain/patch-1 and squashes the following commits: e094dbc [Evan Sparks] Touching only diagonals on update. d1ab9b6 [Evan Sparks] Use numpy directly for matrix multiply.
Diffstat (limited to 'examples/src')
-rwxr-xr-xexamples/src/main/python/als.py15
1 files changed, 7 insertions, 8 deletions
diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py
index a77dfb2577..33700ab4f8 100755
--- a/examples/src/main/python/als.py
+++ b/examples/src/main/python/als.py
@@ -36,14 +36,13 @@ def rmse(R, ms, us):
def update(i, vec, mat, ratings):
uu = mat.shape[0]
ff = mat.shape[1]
- XtX = matrix(np.zeros((ff, ff)))
- Xty = np.zeros((ff, 1))
-
- for j in range(uu):
- v = mat[j, :]
- XtX += v.T * v
- Xty += v.T * ratings[i, j]
- XtX += np.eye(ff, ff) * LAMBDA * uu
+
+ XtX = mat.T * mat
+ XtY = mat.T * ratings[i, :].T
+
+ for j in range(ff):
+ XtX[j,j] += LAMBDA * uu
+
return np.linalg.solve(XtX, Xty)
if __name__ == "__main__":