aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java
diff options
context:
space:
mode:
authorBurak Yavuz <brkyvz@gmail.com>2014-12-29 13:24:26 -0800
committerXiangrui Meng <meng@databricks.com>2014-12-29 13:24:26 -0800
commit02b55de3dce9a1fef806be13e5cefa0f39ea2fcc (patch)
tree0e3e2a60779921eddf03e776442faa4ec5cc9ffa /mllib/src/test/java
parent8d72341ab75a7fb138b056cfb4e21db42aca55fb (diff)
downloadspark-02b55de3dce9a1fef806be13e5cefa0f39ea2fcc.tar.gz
spark-02b55de3dce9a1fef806be13e5cefa0f39ea2fcc.tar.bz2
spark-02b55de3dce9a1fef806be13e5cefa0f39ea2fcc.zip
[SPARK-4409][MLlib] Additional Linear Algebra Utils
Addition of a very limited number of local matrix manipulation and generation methods that would be helpful in the further development for algorithms on top of BlockMatrix (SPARK-3974), such as Randomized SVD, and Multi Model Training (SPARK-1486). The proposed methods for addition are: For `Matrix` - map: maps the values in the matrix with a given function. Produces a new matrix. - update: the values in the matrix are updated with a given function. Occurs in place. Factory methods for `DenseMatrix`: - *zeros: Generate a matrix consisting of zeros - *ones: Generate a matrix consisting of ones - *eye: Generate an identity matrix - *rand: Generate a matrix consisting of i.i.d. uniform random numbers - *randn: Generate a matrix consisting of i.i.d. gaussian random numbers - *diag: Generate a diagonal matrix from a supplied vector *These methods already exist in the factory methods for `Matrices`, however for cases where we require a `DenseMatrix`, you constantly have to add `.asInstanceOf[DenseMatrix]` everywhere, which makes the code "dirtier". I propose moving these functions to factory methods for `DenseMatrix` where the putput will be a `DenseMatrix` and the factory methods for `Matrices` will call these functions directly and output a generic `Matrix`. Factory methods for `SparseMatrix`: - speye: Identity matrix in sparse format. Saves a ton of memory when dimensions are large, especially in Multi Model Training, where each row requires being multiplied by a scalar. - sprand: Generate a sparse matrix with a given density consisting of i.i.d. uniform random numbers. - sprandn: Generate a sparse matrix with a given density consisting of i.i.d. gaussian random numbers. - diag: Generate a diagonal matrix from a supplied vector, but is memory efficient, because it just stores the diagonal. Again, very helpful in Multi Model Training. Factory methods for `Matrices`: - Include all the factory methods given above, but return a generic `Matrix` rather than `SparseMatrix` or `DenseMatrix`. - horzCat: Horizontally concatenate matrices to form one larger matrix. Very useful in both Multi Model Training, and for the repartitioning of BlockMatrix. - vertCat: Vertically concatenate matrices to form one larger matrix. Very useful for the repartitioning of BlockMatrix. The names for these methods were selected from MATLAB Author: Burak Yavuz <brkyvz@gmail.com> Author: Xiangrui Meng <meng@databricks.com> Closes #3319 from brkyvz/SPARK-4409 and squashes the following commits: b0354f6 [Burak Yavuz] [SPARK-4409] Incorporated mengxr's code 04c4829 [Burak Yavuz] Merge pull request #1 from mengxr/SPARK-4409 80cfa29 [Xiangrui Meng] minor changes ecc937a [Xiangrui Meng] update sprand 4e95e24 [Xiangrui Meng] simplify fromCOO implementation 10a63a6 [Burak Yavuz] [SPARK-4409] Fourth pass of code review f62d6c7 [Burak Yavuz] [SPARK-4409] Modified genRandMatrix 3971c93 [Burak Yavuz] [SPARK-4409] Third pass of code review 75239f8 [Burak Yavuz] [SPARK-4409] Second pass of code review e4bd0c0 [Burak Yavuz] [SPARK-4409] Modified horzcat and vertcat 65c562e [Burak Yavuz] [SPARK-4409] Hopefully fixed Java Test d8be7bc [Burak Yavuz] [SPARK-4409] Organized imports 065b531 [Burak Yavuz] [SPARK-4409] First pass after code review a8120d2 [Burak Yavuz] [SPARK-4409] Finished updates to API according to SPARK-4614 f798c82 [Burak Yavuz] [SPARK-4409] Updated API according to SPARK-4614 c75f3cd [Burak Yavuz] [SPARK-4409] Added JavaAPI Tests, and fixed a couple of bugs d662f9d [Burak Yavuz] [SPARK-4409] Modified according to remote repo 83dfe37 [Burak Yavuz] [SPARK-4409] Scalastyle error fixed a14c0da [Burak Yavuz] [SPARK-4409] Initial commit to add methods
Diffstat (limited to 'mllib/src/test/java')
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java163
1 files changed, 163 insertions, 0 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
new file mode 100644
index 0000000000..704d484d0b
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.linalg;
+
+import static org.junit.Assert.*;
+import org.junit.Test;
+
+import java.io.Serializable;
+import java.util.Random;
+
+public class JavaMatricesSuite implements Serializable {
+
+ @Test
+ public void randMatrixConstruction() {
+ Random rng = new Random(24);
+ Matrix r = Matrices.rand(3, 4, rng);
+ rng.setSeed(24);
+ DenseMatrix dr = DenseMatrix.rand(3, 4, rng);
+ assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
+
+ rng.setSeed(24);
+ Matrix rn = Matrices.randn(3, 4, rng);
+ rng.setSeed(24);
+ DenseMatrix drn = DenseMatrix.randn(3, 4, rng);
+ assertArrayEquals(rn.toArray(), drn.toArray(), 0.0);
+
+ rng.setSeed(24);
+ Matrix s = Matrices.sprand(3, 4, 0.5, rng);
+ rng.setSeed(24);
+ SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, rng);
+ assertArrayEquals(s.toArray(), sr.toArray(), 0.0);
+
+ rng.setSeed(24);
+ Matrix sn = Matrices.sprandn(3, 4, 0.5, rng);
+ rng.setSeed(24);
+ SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, rng);
+ assertArrayEquals(sn.toArray(), srn.toArray(), 0.0);
+ }
+
+ @Test
+ public void identityMatrixConstruction() {
+ Matrix r = Matrices.eye(2);
+ DenseMatrix dr = DenseMatrix.eye(2);
+ SparseMatrix sr = SparseMatrix.speye(2);
+ assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
+ assertArrayEquals(sr.toArray(), dr.toArray(), 0.0);
+ assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0);
+ }
+
+ @Test
+ public void diagonalMatrixConstruction() {
+ Vector v = Vectors.dense(1.0, 0.0, 2.0);
+ Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0});
+
+ Matrix m = Matrices.diag(v);
+ Matrix sm = Matrices.diag(sv);
+ DenseMatrix d = DenseMatrix.diag(v);
+ DenseMatrix sd = DenseMatrix.diag(sv);
+ SparseMatrix s = SparseMatrix.diag(v);
+ SparseMatrix ss = SparseMatrix.diag(sv);
+
+ assertArrayEquals(m.toArray(), sm.toArray(), 0.0);
+ assertArrayEquals(d.toArray(), sm.toArray(), 0.0);
+ assertArrayEquals(d.toArray(), sd.toArray(), 0.0);
+ assertArrayEquals(sd.toArray(), s.toArray(), 0.0);
+ assertArrayEquals(s.toArray(), ss.toArray(), 0.0);
+ assertArrayEquals(s.values(), ss.values(), 0.0);
+ assert(s.values().length == 2);
+ assert(ss.values().length == 2);
+ assert(s.colPtrs().length == 4);
+ assert(ss.colPtrs().length == 4);
+ }
+
+ @Test
+ public void zerosMatrixConstruction() {
+ Matrix z = Matrices.zeros(2, 2);
+ Matrix one = Matrices.ones(2, 2);
+ DenseMatrix dz = DenseMatrix.zeros(2, 2);
+ DenseMatrix done = DenseMatrix.ones(2, 2);
+
+ assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
+ assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
+ assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
+ assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
+ }
+
+ @Test
+ public void sparseDenseConversion() {
+ int m = 3;
+ int n = 2;
+ double[] values = new double[]{1.0, 2.0, 4.0, 5.0};
+ double[] allValues = new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0};
+ int[] colPtrs = new int[]{0, 2, 4};
+ int[] rowIndices = new int[]{0, 1, 1, 2};
+
+ SparseMatrix spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values);
+ DenseMatrix deMat1 = new DenseMatrix(m, n, allValues);
+
+ SparseMatrix spMat2 = deMat1.toSparse();
+ DenseMatrix deMat2 = spMat1.toDense();
+
+ assertArrayEquals(spMat1.toArray(), spMat2.toArray(), 0.0);
+ assertArrayEquals(deMat1.toArray(), deMat2.toArray(), 0.0);
+ }
+
+ @Test
+ public void concatenateMatrices() {
+ int m = 3;
+ int n = 2;
+
+ Random rng = new Random(42);
+ SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, rng);
+ rng.setSeed(42);
+ DenseMatrix deMat1 = DenseMatrix.rand(m, n, rng);
+ Matrix deMat2 = Matrices.eye(3);
+ Matrix spMat2 = Matrices.speye(3);
+ Matrix deMat3 = Matrices.eye(2);
+ Matrix spMat3 = Matrices.speye(2);
+
+ Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2});
+ Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2});
+ Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2});
+ Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2});
+
+ assert(deHorz1.numRows() == 3);
+ assert(deHorz2.numRows() == 3);
+ assert(deHorz3.numRows() == 3);
+ assert(spHorz.numRows() == 3);
+ assert(deHorz1.numCols() == 5);
+ assert(deHorz2.numCols() == 5);
+ assert(deHorz3.numCols() == 5);
+ assert(spHorz.numCols() == 5);
+
+ Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3});
+ Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3});
+ Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3});
+ Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3});
+
+ assert(deVert1.numRows() == 5);
+ assert(deVert2.numRows() == 5);
+ assert(deVert3.numRows() == 5);
+ assert(spVert.numRows() == 5);
+ assert(deVert1.numCols() == 2);
+ assert(deVert2.numCols() == 2);
+ assert(deVert3.numCols() == 2);
+ assert(spVert.numCols() == 2);
+ }
+}