about | summary | refs | log | tree | commit | diff
path: root/mllib/src/test
diff options
context:
space:
mode:
author: Xiangrui Meng <meng@databricks.com> 2014-08-11 22:33:45 -0700
committer: Xiangrui Meng <meng@databricks.com> 2014-08-11 22:33:45 -0700
commit: 9038d94e1e50e05de00fd51af4fd7b9280481cdc (patch)
tree: 9ea1513ed5cea9d8f758f40678f4e4c37291a224 /mllib/src/test
parent: 5d54d71ddbac1fbb26925a8c9138bbb8c0e81db8 (diff)
download: spark-9038d94e1e50e05de00fd51af4fd7b9280481cdc.tar.gz
          spark-9038d94e1e50e05de00fd51af4fd7b9280481cdc.tar.bz2
          spark-9038d94e1e50e05de00fd51af4fd7b9280481cdc.zip
[SPARK-2923][MLLIB] Implement some basic BLAS routines
Having some basic BLAS operations implemented in MLlib can help simplify the current implementation and improve performance in some cases. Tested on my local machine:

~~~
bin/spark-submit --class org.apache.spark.examples.mllib.BinaryClassification \
  examples/target/scala-*/spark-examples-*.jar --algorithm LR --regType L2 \
  --regParam 1.0 --numIterations 1000 ~/share/data/rcv1.binary/rcv1_train.binary
~~~

1. before: ~1m
2. after: ~30s

CC: jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #1849 from mengxr/ml-blas and squashes the following commits:

ba583a2 [Xiangrui Meng] exclude Vector.copy
a4d7d2f [Xiangrui Meng] Merge branch 'master' into ml-blas
6edeab9 [Xiangrui Meng] address comments
940bdeb [Xiangrui Meng] rename MLlibBLAS to BLAS
c2a38bc [Xiangrui Meng] enhance dot tests
4cfaac4 [Xiangrui Meng] add apache header
48d01d2 [Xiangrui Meng] add tests for zeros and copy
3b882b1 [Xiangrui Meng] use blas.scal in gradient
735eb23 [Xiangrui Meng] remove d from BLAS routines
d2d7d3c [Xiangrui Meng] update gradient and lbfgs
7f78186 [Xiangrui Meng] add zeros to Vectors; add dscal and dcopy to BLAS
14e6645 [Xiangrui Meng] add ddot
cbb8273 [Xiangrui Meng] add daxpy test
07db0bb [Xiangrui Meng] Merge branch 'master' into ml-blas
e8c326d [Xiangrui Meng] axpy
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala | 129
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala | 30
2 files changed, 159 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
new file mode 100644
index 0000000000..1952e6734e
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.linalg
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.mllib.linalg.BLAS._
+
+class BLASSuite extends FunSuite { // Tests for the copy, scal, axpy and dot routines imported from BLAS.
+
+ test("copy") { // copy(x, y) must overwrite a dense y with x and reject a sparse y.
+ val sx = Vectors.sparse(4, Array(0, 2), Array(1.0, -2.0))
+ val dx = Vectors.dense(1.0, 0.0, -2.0, 0.0) // dense form of sx
+ val sy = Vectors.sparse(4, Array(0, 1, 3), Array(2.0, 1.0, 1.0))
+ val dy = Array(2.0, 1.0, 0.0, 1.0) // dense form of sy; cloned below before each call
+
+ val dy1 = Vectors.dense(dy.clone())
+ copy(sx, dy1) // sparse source into dense destination
+ assert(dy1 ~== dx absTol 1e-15)
+
+ val dy2 = Vectors.dense(dy.clone())
+ copy(dx, dy2) // dense source into dense destination
+ assert(dy2 ~== dx absTol 1e-15)
+
+ intercept[IllegalArgumentException] { // sparse destination of matching size is rejected
+ copy(sx, sy)
+ }
+
+ intercept[IllegalArgumentException] {
+ copy(dx, sy)
+ }
+
+ withClue("vector sizes must match") {
+ intercept[Exception] {
+ copy(sx, Vectors.dense(0.0, 1.0, 2.0)) // size 4 source vs size 3 destination
+ }
+ }
+ }
+
+ test("scal") { // scal(a, x) is asserted to scale x in place, for sparse and dense x.
+ val a = 0.1
+ val sx = Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0))
+ val dx = Vectors.dense(1.0, 0.0, -2.0) // dense form of sx
+
+ scal(a, sx)
+ assert(sx ~== Vectors.sparse(3, Array(0, 2), Array(0.1, -0.2)) absTol 1e-15)
+
+ scal(a, dx)
+ assert(dx ~== Vectors.dense(0.1, 0.0, -0.2) absTol 1e-15)
+ }
+
+ test("axpy") { // axpy(alpha, x, y) must leave alpha * x + y in a dense y and reject a sparse y.
+ val alpha = 0.1
+ val sx = Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0))
+ val dx = Vectors.dense(1.0, 0.0, -2.0) // dense form of sx
+ val dy = Array(2.0, 1.0, 0.0)
+ val expected = Vectors.dense(2.1, 1.0, -0.2) // dy + alpha * x, element by element
+
+ val dy1 = Vectors.dense(dy.clone())
+ axpy(alpha, sx, dy1) // sparse x added into dense y
+ assert(dy1 ~== expected absTol 1e-15)
+
+ val dy2 = Vectors.dense(dy.clone())
+ axpy(alpha, dx, dy2) // dense x added into dense y
+ assert(dy2 ~== expected absTol 1e-15)
+
+ val sy = Vectors.sparse(3, Array(0, 1), Array(2.0, 1.0)) // same size as sx and dx: only sparsity can be rejected
+
+ intercept[IllegalArgumentException] { // sparse destination is rejected for a sparse x
+ axpy(alpha, sx, sy)
+ }
+
+ intercept[IllegalArgumentException] { // sparse destination is rejected for a dense x
+ axpy(alpha, dx, sy)
+ }
+
+ withClue("vector sizes must match") {
+ intercept[Exception] {
+ axpy(alpha, sx, Vectors.dense(1.0, 2.0)) // size 3 x vs size 2 y
+ }
+ }
+ }
+
+ test("dot") { // dot(x, y) must agree across every sparse/dense pairing and argument order.
+ val sx = Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0))
+ val dx = Vectors.dense(1.0, 0.0, -2.0) // dense form of sx
+ val sy = Vectors.sparse(3, Array(0, 1), Array(2.0, 1.0))
+ val dy = Vectors.dense(2.0, 1.0, 0.0) // dense form of sy
+
+ assert(dot(sx, sy) ~== 2.0 absTol 1e-15) // x . y = 1.0 * 2.0; the other products are zero
+ assert(dot(sy, sx) ~== 2.0 absTol 1e-15)
+ assert(dot(sx, dy) ~== 2.0 absTol 1e-15)
+ assert(dot(dy, sx) ~== 2.0 absTol 1e-15)
+ assert(dot(dx, dy) ~== 2.0 absTol 1e-15)
+ assert(dot(dy, dx) ~== 2.0 absTol 1e-15)
+
+ assert(dot(sx, sx) ~== 5.0 absTol 1e-15) // x . x = 1.0 + 4.0
+ assert(dot(dx, dx) ~== 5.0 absTol 1e-15)
+ assert(dot(sx, dx) ~== 5.0 absTol 1e-15)
+ assert(dot(dx, sx) ~== 5.0 absTol 1e-15)
+
+ val sx1 = Vectors.sparse(10, Array(0, 3, 5, 7, 8), Array(1.0, 2.0, 3.0, 4.0, 5.0))
+ val sx2 = Vectors.sparse(10, Array(1, 3, 6, 7, 9), Array(1.0, 2.0, 3.0, 4.0, 5.0))
+ assert(dot(sx1, sx2) ~== 20.0 absTol 1e-15) // shared indices 3 and 7: 2.0 * 2.0 + 4.0 * 4.0
+ assert(dot(sx2, sx1) ~== 20.0 absTol 1e-15)
+
+ withClue("vector sizes must match") {
+ intercept[Exception] {
+ dot(sx, Vectors.dense(2.0, 1.0)) // size 3 vs size 2
+ }
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 7972ceea1f..cd651fe2d2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -125,4 +125,34 @@ class VectorsSuite extends FunSuite {
}
}
}
+
+ test("zeros") { // Vectors.zeros(3) must equal a dense vector holding three 0.0 entries.
+ assert(Vectors.zeros(3) === Vectors.dense(Array.fill(3)(0.0)))
+ }
+
+ test("Vector.copy") { // copy must keep the concrete vector type and must not share backing arrays.
+ val sv = Vectors.sparse(4, Array(0, 2), Array(1.0, 2.0))
+ val svCopy = sv.copy
+ (sv, svCopy) match {
+ case (sv: SparseVector, svCopy: SparseVector) => // typed bindings shadow the outer vals
+ assert(sv.size === svCopy.size)
+ assert(sv.indices === svCopy.indices)
+ assert(sv.values === svCopy.values)
+ assert(!sv.indices.eq(svCopy.indices)) // reference check: the index array is a separate object
+ assert(!sv.values.eq(svCopy.values)) // reference check: the value array is a separate object
+ case _ =>
+ fail(s"copy returned ${svCopy.getClass} on ${sv.getClass}.") // reported as a test failure
+ }
+
+ val dv = Vectors.dense(1.0, 0.0, 2.0)
+ val dvCopy = dv.copy
+ (dv, dvCopy) match {
+ case (dv: DenseVector, dvCopy: DenseVector) => // typed bindings shadow the outer vals
+ assert(dv.size === dvCopy.size)
+ assert(dv.values === dvCopy.values)
+ assert(!dv.values.eq(dvCopy.values)) // reference check: the value array is a separate object
+ case _ =>
+ fail(s"copy returned ${dvCopy.getClass} on ${dv.getClass}.") // reported as a test failure
+ }
+ }
}