author | Evan Sparks <evan.sparks@gmail.com> | 2014-05-08 00:24:36 -0400 |
---|---|---|
committer | Reynold Xin <rxin@apache.org> | 2014-05-08 00:24:36 -0400 |
commit | 6ed7e2cd01955adfbb3960e2986b6d19eaee8717 (patch) | |
tree | 001585b295f006c0b93f0d21b7827b544df7bcd3 /examples/src | |
parent | 108c4c16cc82af2e161d569d2c23849bdbf4aadb (diff) | |
Use numpy directly for matrix multiply.
Using matrix multiply to compute XtX and XtY yields a 5-20x speedup depending on problem size.
For example, the following takes 19s locally after this change vs. 5m21s before it (about a 16x speedup):
bin/pyspark examples/src/main/python/als.py local[8] 1000 1000 50 10 10
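The identity behind the change is that summing the outer products of the rows of X gives XtX, and summing each row scaled by its rating gives XtY, so both can be computed with one matrix product each. A minimal sketch of that equivalence follows; it uses plain NumPy arrays rather than `np.matrix`, and the names `X`, `y` and the sizes are assumptions chosen for illustration, not values from als.py.

```python
import numpy as np

rng = np.random.default_rng(0)
uu, ff = 200, 10                      # hypothetical sizes: rows and features
X = rng.standard_normal((uu, ff))     # stand-in for the factor matrix
y = rng.standard_normal((uu, 1))      # stand-in for one row of ratings

# Row-by-row accumulation, in the style of the code before the change.
XtX_loop = np.zeros((ff, ff))
Xty_loop = np.zeros((ff, 1))
for j in range(uu):
    v = X[j:j + 1, :]                 # 1 x ff row, kept two-dimensional
    XtX_loop += v.T @ v               # outer product of the row with itself
    Xty_loop += v.T * y[j, 0]         # row scaled by its rating

# One matrix product each, in the style of the code after the change.
XtX_mat = X.T @ X
Xty_mat = X.T @ y

loop_matches_matmul = (np.allclose(XtX_loop, XtX_mat)
                       and np.allclose(Xty_loop, Xty_mat))
```

Both forms do the same arithmetic; the second hands the loop to NumPy's compiled matrix multiply instead of running it in the Python interpreter.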
Author: Evan Sparks <evan.sparks@gmail.com>
Closes #687 from etrain/patch-1 and squashes the following commits:
e094dbc [Evan Sparks] Touching only diagonals on update.
d1ab9b6 [Evan Sparks] Use numpy directly for matrix multiply.
Diffstat (limited to 'examples/src')
-rwxr-xr-x | examples/src/main/python/als.py | 15 |
1 files changed, 7 insertions, 8 deletions
diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py
index a77dfb2577..33700ab4f8 100755
--- a/examples/src/main/python/als.py
+++ b/examples/src/main/python/als.py
@@ -36,14 +36,13 @@ def rmse(R, ms, us):
 def update(i, vec, mat, ratings):
     uu = mat.shape[0]
     ff = mat.shape[1]
-    XtX = matrix(np.zeros((ff, ff)))
-    Xty = np.zeros((ff, 1))
-
-    for j in range(uu):
-        v = mat[j, :]
-        XtX += v.T * v
-        Xty += v.T * ratings[i, j]
-    XtX += np.eye(ff, ff) * LAMBDA * uu
+
+    XtX = mat.T * mat
+    XtY = mat.T * ratings[i, :].T
+
+    for j in range(ff):
+        XtX[j,j] += LAMBDA * uu
+
     return np.linalg.solve(XtX, Xty)
 
 if __name__ == "__main__":
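
Note that the added lines assign `XtY`, while the unchanged `return` line still reads `Xty`; as shown in this hunk, that name would be undefined inside `update`. The sketch below is one way the function could look with a single spelling throughout. It is not the committed code: it assumes plain `ndarray` inputs with `@` in place of `np.matrix` and `*`, picks an arbitrary `LAMBDA` value, and uses made-up test data.

```python
import numpy as np

LAMBDA = 0.01  # regularization weight; assumed value for this sketch


def update(i, vec, mat, ratings):
    """Solve the regularized least-squares system for row i of `ratings`.

    `vec` is kept only to mirror the signature in the diff; it is unused here.
    """
    uu, ff = mat.shape

    XtX = mat.T @ mat                # ff x ff Gram matrix
    Xty = mat.T @ ratings[i, :]      # length-ff right-hand side

    # Add the ridge term to the diagonal entries only.
    XtX[np.diag_indices(ff)] += LAMBDA * uu

    return np.linalg.solve(XtX, Xty)


# Usage with hypothetical data: 50 rows, 4 features, 3 rating rows.
rng = np.random.default_rng(0)
mat = rng.standard_normal((50, 4))
ratings = rng.standard_normal((3, 50))
w = update(0, None, mat, ratings)
```

Touching only the `ff` diagonal entries avoids building and adding a full `ff` x `ff` identity matrix, which is what the second squashed commit describes.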