author     Xiangrui Meng <meng@databricks.com>    2014-08-11 22:33:45 -0700
committer  Xiangrui Meng <meng@databricks.com>    2014-08-11 22:33:45 -0700
commit     9038d94e1e50e05de00fd51af4fd7b9280481cdc (patch)
tree       9ea1513ed5cea9d8f758f40678f4e4c37291a224 /mllib
parent     5d54d71ddbac1fbb26925a8c9138bbb8c0e81db8 (diff)
[SPARK-2923][MLLIB] Implement some basic BLAS routines
Having some basic BLAS operations implemented in MLlib helps simplify the current implementation and improves performance. Tested on my local machine:

~~~
bin/spark-submit --class org.apache.spark.examples.mllib.BinaryClassification \
  examples/target/scala-*/spark-examples-*.jar --algorithm LR --regType L2 \
  --regParam 1.0 --numIterations 1000 ~/share/data/rcv1.binary/rcv1_train.binary
~~~

1. before: ~1m
2. after: ~30s

CC: jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #1849 from mengxr/ml-blas and squashes the following commits:

ba583a2 [Xiangrui Meng] exclude Vector.copy
a4d7d2f [Xiangrui Meng] Merge branch 'master' into ml-blas
6edeab9 [Xiangrui Meng] address comments
940bdeb [Xiangrui Meng] rename MLlibBLAS to BLAS
c2a38bc [Xiangrui Meng] enhance dot tests
4cfaac4 [Xiangrui Meng] add apache header
48d01d2 [Xiangrui Meng] add tests for zeros and copy
3b882b1 [Xiangrui Meng] use blas.scal in gradient
735eb23 [Xiangrui Meng] remove d from BLAS routines
d2d7d3c [Xiangrui Meng] update gradient and lbfgs
7f78186 [Xiangrui Meng] add zeros to Vectors; add dscal and dcopy to BLAS
14e6645 [Xiangrui Meng] add ddot
cbb8273 [Xiangrui Meng] add daxpy test
07db0bb [Xiangrui Meng] Merge branch 'master' into ml-blas
e8c326d [Xiangrui Meng] axpy
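As a rough, hedged sketch of what the new routines enable (not code from this patch; `BLAS` is `private[mllib]`, so calls like these only work inside the package), gradient code can operate directly on MLlib vectors via `dot` and in-place `axpy` instead of round-tripping through Breeze:

~~~
// Sketch only: mirrors the shape of LogisticGradient.compute from this patch.
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.BLAS.{axpy, dot}

val weights = Vectors.dense(0.5, -0.25, 1.0)
val features = Vectors.sparse(3, Array(0, 2), Array(1.0, 2.0))
val label = 1.0

val margin = -1.0 * dot(features, weights)
val multiplier = (1.0 / (1.0 + math.exp(margin))) - label

// Accumulate the gradient in place: cumGradient += multiplier * features.
val cumGradient = Vectors.zeros(3)
axpy(multiplier, features, cumGradient)
~~~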
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala           | 200
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala        |  35
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala |  60
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala    |  39
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala      | 129
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala   |  30
6 files changed, 428 insertions(+), 65 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
new file mode 100644
index 0000000000..70e23033c8
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.linalg
+
+import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS}
+
+/**
+ * BLAS routines for MLlib's vectors and matrices.
+ */
+private[mllib] object BLAS extends Serializable {
+
+ @transient private var _f2jBLAS: NetlibBLAS = _
+
+ // For level-1 routines, we use the Java implementation.
+ private def f2jBLAS: NetlibBLAS = {
+ if (_f2jBLAS == null) {
+ _f2jBLAS = new F2jBLAS
+ }
+ _f2jBLAS
+ }
+
+ /**
+ * y += a * x
+ */
+ def axpy(a: Double, x: Vector, y: Vector): Unit = {
+ require(x.size == y.size)
+ y match {
+ case dy: DenseVector =>
+ x match {
+ case sx: SparseVector =>
+ axpy(a, sx, dy)
+ case dx: DenseVector =>
+ axpy(a, dx, dy)
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"axpy doesn't support x type ${x.getClass}.")
+ }
+ case _ =>
+ throw new IllegalArgumentException(
+ s"axpy only supports adding to a dense vector but got type ${y.getClass}.")
+ }
+ }
+
+ /**
+ * y += a * x
+ */
+ private def axpy(a: Double, x: DenseVector, y: DenseVector): Unit = {
+ val n = x.size
+ f2jBLAS.daxpy(n, a, x.values, 1, y.values, 1)
+ }
+
+ /**
+ * y += a * x
+ */
+ private def axpy(a: Double, x: SparseVector, y: DenseVector): Unit = {
+ val nnz = x.indices.size
+ if (a == 1.0) {
+ var k = 0
+ while (k < nnz) {
+ y.values(x.indices(k)) += x.values(k)
+ k += 1
+ }
+ } else {
+ var k = 0
+ while (k < nnz) {
+ y.values(x.indices(k)) += a * x.values(k)
+ k += 1
+ }
+ }
+ }
+
+ /**
+ * dot(x, y)
+ */
+ def dot(x: Vector, y: Vector): Double = {
+ require(x.size == y.size)
+ (x, y) match {
+ case (dx: DenseVector, dy: DenseVector) =>
+ dot(dx, dy)
+ case (sx: SparseVector, dy: DenseVector) =>
+ dot(sx, dy)
+ case (dx: DenseVector, sy: SparseVector) =>
+ dot(sy, dx)
+ case (sx: SparseVector, sy: SparseVector) =>
+ dot(sx, sy)
+ case _ =>
+ throw new IllegalArgumentException(s"dot doesn't support (${x.getClass}, ${y.getClass}).")
+ }
+ }
+
+ /**
+ * dot(x, y)
+ */
+ private def dot(x: DenseVector, y: DenseVector): Double = {
+ val n = x.size
+ f2jBLAS.ddot(n, x.values, 1, y.values, 1)
+ }
+
+ /**
+ * dot(x, y)
+ */
+ private def dot(x: SparseVector, y: DenseVector): Double = {
+ val nnz = x.indices.size
+ var sum = 0.0
+ var k = 0
+ while (k < nnz) {
+ sum += x.values(k) * y.values(x.indices(k))
+ k += 1
+ }
+ sum
+ }
+
+ /**
+ * dot(x, y)
+ */
+ private def dot(x: SparseVector, y: SparseVector): Double = {
+ var kx = 0
+ val nnzx = x.indices.size
+ var ky = 0
+ val nnzy = y.indices.size
+ var sum = 0.0
+ // advance ky until y.indices(ky) catches up with x.indices(kx)
+ while (kx < nnzx && ky < nnzy) {
+ val ix = x.indices(kx)
+ while (ky < nnzy && y.indices(ky) < ix) {
+ ky += 1
+ }
+ if (ky < nnzy && y.indices(ky) == ix) {
+ sum += x.values(kx) * y.values(ky)
+ ky += 1
+ }
+ kx += 1
+ }
+ sum
+ }
+
+ /**
+ * y = x
+ */
+ def copy(x: Vector, y: Vector): Unit = {
+ val n = y.size
+ require(x.size == n)
+ y match {
+ case dy: DenseVector =>
+ x match {
+ case sx: SparseVector =>
+ var i = 0
+ var k = 0
+ val nnz = sx.indices.size
+ while (k < nnz) {
+ val j = sx.indices(k)
+ while (i < j) {
+ dy.values(i) = 0.0
+ i += 1
+ }
+ dy.values(i) = sx.values(k)
+ i += 1
+ k += 1
+ }
+ while (i < n) {
+ dy.values(i) = 0.0
+ i += 1
+ }
+ case dx: DenseVector =>
+ Array.copy(dx.values, 0, dy.values, 0, n)
+ }
+ case _ =>
+ throw new IllegalArgumentException(s"y must be dense in copy but got ${y.getClass}")
+ }
+ }
+
+ /**
+ * x = a * x
+ */
+ def scal(a: Double, x: Vector): Unit = {
+ x match {
+ case sx: SparseVector =>
+ f2jBLAS.dscal(sx.values.size, a, sx.values, 1)
+ case dx: DenseVector =>
+ f2jBLAS.dscal(dx.values.size, a, dx.values, 1)
+ case _ =>
+ throw new IllegalArgumentException(s"scal doesn't support vector type ${x.getClass}.")
+ }
+ }
+}
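The sparse-sparse `dot` above merges the two sorted index arrays with a two-pointer walk, so it costs O(nnz(x) + nnz(y)) and allocates nothing. A minimal standalone sketch of the same idea on raw arrays (hypothetical helper, not part of the patch):

~~~
// Two-pointer merge over sorted index arrays, as in the sparse-sparse dot.
def sparseDot(xIdx: Array[Int], xVal: Array[Double],
              yIdx: Array[Int], yVal: Array[Double]): Double = {
  var kx = 0
  var ky = 0
  var sum = 0.0
  while (kx < xIdx.length && ky < yIdx.length) {
    if (xIdx(kx) < yIdx(ky)) kx += 1
    else if (xIdx(kx) > yIdx(ky)) ky += 1
    else {
      sum += xVal(kx) * yVal(ky)  // indices match: accumulate the product
      kx += 1
      ky += 1
    }
  }
  sum
}
~~~

For example, with indices {0, 3, 5, 7, 8} and {1, 3, 6, 7, 9} only positions 3 and 7 overlap, which is exactly what the `dot(sx1, sx2)` test in `BLASSuite` below exercises.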
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 77b3e8c714..a45781d12e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -18,7 +18,7 @@
package org.apache.spark.mllib.linalg
import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable}
-import java.util.Arrays
+import java.util
import scala.annotation.varargs
import scala.collection.JavaConverters._
@@ -30,6 +30,8 @@ import org.apache.spark.SparkException
/**
* Represents a numeric vector, whose index type is Int and value type is Double.
+ *
+ * Note: Users should not implement this interface.
*/
trait Vector extends Serializable {
@@ -46,12 +48,12 @@ trait Vector extends Serializable {
override def equals(other: Any): Boolean = {
other match {
case v: Vector =>
- Arrays.equals(this.toArray, v.toArray)
+ util.Arrays.equals(this.toArray, v.toArray)
case _ => false
}
}
- override def hashCode(): Int = Arrays.hashCode(this.toArray)
+ override def hashCode(): Int = util.Arrays.hashCode(this.toArray)
/**
* Converts the instance to a breeze vector.
@@ -63,6 +65,13 @@ trait Vector extends Serializable {
* @param i index
*/
def apply(i: Int): Double = toBreeze(i)
+
+ /**
+ * Makes a deep copy of this vector.
+ */
+ def copy: Vector = {
+ throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.")
+ }
}
/**
@@ -128,6 +137,16 @@ object Vectors {
}
/**
+ * Creates a dense vector of all zeros.
+ *
+ * @param size vector size
+ * @return a zero vector
+ */
+ def zeros(size: Int): Vector = {
+ new DenseVector(new Array[Double](size))
+ }
+
+ /**
* Parses a string resulting from `Vector#toString` into
* an [[org.apache.spark.mllib.linalg.Vector]].
*/
@@ -142,7 +161,7 @@ object Vectors {
case Seq(size: Double, indices: Array[Double], values: Array[Double]) =>
Vectors.sparse(size.toInt, indices.map(_.toInt), values)
case other =>
- throw new SparkException(s"Cannot parse $other.")
+ throw new SparkException(s"Cannot parse $other.")
}
}
@@ -183,6 +202,10 @@ class DenseVector(val values: Array[Double]) extends Vector {
private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values)
override def apply(i: Int) = values(i)
+
+ override def copy: DenseVector = {
+ new DenseVector(values.clone())
+ }
}
/**
@@ -213,5 +236,9 @@ class SparseVector(
data
}
+ override def copy: SparseVector = {
+ new SparseVector(size, indices.clone(), values.clone())
+ }
+
private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)
}
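For reference, a small hedged sketch of how the new `Vectors.zeros` and deep `copy` behave (it mirrors the tests added in `VectorsSuite` below):

~~~
import org.apache.spark.mllib.linalg.{SparseVector, Vectors}

val z = Vectors.zeros(3)            // dense vector of three 0.0s
assert(z == Vectors.dense(0.0, 0.0, 0.0))

// copy clones both the indices and values arrays.
val sv = new SparseVector(4, Array(0, 2), Array(1.0, 2.0))
val svCopy = sv.copy
svCopy.values(0) = 9.0
assert(sv.values(0) == 1.0)         // the original is untouched
~~~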
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
index 9d82f011e6..fdd6716011 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
@@ -17,10 +17,9 @@
package org.apache.spark.mllib.optimization
-import breeze.linalg.{axpy => brzAxpy}
-
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.mllib.linalg.{Vectors, Vector}
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
/**
* :: DeveloperApi ::
@@ -61,11 +60,10 @@ abstract class Gradient extends Serializable {
@DeveloperApi
class LogisticGradient extends Gradient {
override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
- val brzData = data.toBreeze
- val brzWeights = weights.toBreeze
- val margin: Double = -1.0 * brzWeights.dot(brzData)
+ val margin = -1.0 * dot(data, weights)
val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
- val gradient = brzData * gradientMultiplier
+ val gradient = data.copy
+ scal(gradientMultiplier, gradient)
val loss =
if (label > 0) {
math.log1p(math.exp(margin)) // log1p is log(1+p) but more accurate for small p
@@ -73,7 +71,7 @@ class LogisticGradient extends Gradient {
math.log1p(math.exp(margin)) - margin
}
- (Vectors.fromBreeze(gradient), loss)
+ (gradient, loss)
}
override def compute(
@@ -81,13 +79,9 @@ class LogisticGradient extends Gradient {
label: Double,
weights: Vector,
cumGradient: Vector): Double = {
- val brzData = data.toBreeze
- val brzWeights = weights.toBreeze
- val margin: Double = -1.0 * brzWeights.dot(brzData)
+ val margin = -1.0 * dot(data, weights)
val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
-
- brzAxpy(gradientMultiplier, brzData, cumGradient.toBreeze)
-
+ axpy(gradientMultiplier, data, cumGradient)
if (label > 0) {
math.log1p(math.exp(margin))
} else {
@@ -106,13 +100,11 @@ class LogisticGradient extends Gradient {
@DeveloperApi
class LeastSquaresGradient extends Gradient {
override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
- val brzData = data.toBreeze
- val brzWeights = weights.toBreeze
- val diff = brzWeights.dot(brzData) - label
+ val diff = dot(data, weights) - label
val loss = diff * diff
- val gradient = brzData * (2.0 * diff)
-
- (Vectors.fromBreeze(gradient), loss)
+ val gradient = data.copy
+ scal(2.0 * diff, gradient)
+ (gradient, loss)
}
override def compute(
@@ -120,12 +112,8 @@ class LeastSquaresGradient extends Gradient {
label: Double,
weights: Vector,
cumGradient: Vector): Double = {
- val brzData = data.toBreeze
- val brzWeights = weights.toBreeze
- val diff = brzWeights.dot(brzData) - label
-
- brzAxpy(2.0 * diff, brzData, cumGradient.toBreeze)
-
+ val diff = dot(data, weights) - label
+ axpy(2.0 * diff, data, cumGradient)
diff * diff
}
}
@@ -139,18 +127,16 @@ class LeastSquaresGradient extends Gradient {
@DeveloperApi
class HingeGradient extends Gradient {
override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
- val brzData = data.toBreeze
- val brzWeights = weights.toBreeze
- val dotProduct = brzWeights.dot(brzData)
-
+ val dotProduct = dot(data, weights)
// Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x)))
// Therefore the gradient is -(2y - 1)*x
val labelScaled = 2 * label - 1.0
-
if (1.0 > labelScaled * dotProduct) {
- (Vectors.fromBreeze(brzData * (-labelScaled)), 1.0 - labelScaled * dotProduct)
+ val gradient = data.copy
+ scal(-labelScaled, gradient)
+ (gradient, 1.0 - labelScaled * dotProduct)
} else {
- (Vectors.dense(new Array[Double](weights.size)), 0.0)
+ (Vectors.sparse(weights.size, Array.empty, Array.empty), 0.0)
}
}
@@ -159,16 +145,12 @@ class HingeGradient extends Gradient {
label: Double,
weights: Vector,
cumGradient: Vector): Double = {
- val brzData = data.toBreeze
- val brzWeights = weights.toBreeze
- val dotProduct = brzWeights.dot(brzData)
-
+ val dotProduct = dot(data, weights)
// Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x)))
// Therefore the gradient is -(2y - 1)*x
val labelScaled = 2 * label - 1.0
-
if (1.0 > labelScaled * dotProduct) {
- brzAxpy(-labelScaled, brzData, cumGradient.toBreeze)
+ axpy(-labelScaled, data, cumGradient)
1.0 - labelScaled * dotProduct
} else {
0.0
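To make the hinge branch above concrete, a worked example with made-up numbers (a sketch, not code from the patch):

~~~
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.BLAS.{axpy, dot}

val weights = Vectors.dense(0.2, -0.1)
val features = Vectors.dense(1.0, 2.0)
val label = 1.0                             // labelScaled = 2 * 1.0 - 1.0 = 1.0

val dotProduct = dot(features, weights)     // 0.2 - 0.2 = 0.0
val labelScaled = 2 * label - 1.0
if (1.0 > labelScaled * dotProduct) {       // 1.0 > 0.0: inside the margin
  val cumGradient = Vectors.zeros(2)
  axpy(-labelScaled, features, cumGradient) // gradient = -x = (-1.0, -2.0)
  // loss contribution: 1.0 - labelScaled * dotProduct = 1.0
}
~~~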
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
index 26a2b62e76..033fe44f34 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -19,14 +19,15 @@ package org.apache.spark.mllib.optimization
import scala.collection.mutable.ArrayBuffer
-import breeze.linalg.{DenseVector => BDV, axpy}
+import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS}
-import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.Logging
-import org.apache.spark.rdd.RDD
-import org.apache.spark.mllib.linalg.{Vectors, Vector}
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.BLAS.axpy
import org.apache.spark.mllib.rdd.RDDFunctions._
+import org.apache.spark.rdd.RDD
/**
* :: DeveloperApi ::
@@ -192,31 +193,29 @@ object LBFGS extends Logging {
regParam: Double,
numExamples: Long) extends DiffFunction[BDV[Double]] {
- private var i = 0
-
- override def calculate(weights: BDV[Double]) = {
+ override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
// Take a local copy to avoid serializing the CostFun object, which is not serializable.
+ val w = Vectors.fromBreeze(weights)
+ val n = w.size
+ val bcW = data.context.broadcast(w)
val localGradient = gradient
- val n = weights.length
- val bcWeights = data.context.broadcast(weights)
- val (gradientSum, lossSum) = data.treeAggregate((BDV.zeros[Double](n), 0.0))(
+ val (gradientSum, lossSum) = data.treeAggregate((Vectors.zeros(n), 0.0))(
seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
val l = localGradient.compute(
- features, label, Vectors.fromBreeze(bcWeights.value), Vectors.fromBreeze(grad))
+ features, label, bcW.value, grad)
(grad, loss + l)
},
combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
- (grad1 += grad2, loss1 + loss2)
+ axpy(1.0, grad2, grad1)
+ (grad1, loss1 + loss2)
})
/**
* regVal is sum of weight squares if it's L2 updater;
* for other updater, the same logic is followed.
*/
- val regVal = updater.compute(
- Vectors.fromBreeze(weights),
- Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2
+ val regVal = updater.compute(w, Vectors.zeros(n), 0, 1, regParam)._2
val loss = lossSum / numExamples + regVal
/**
@@ -236,17 +235,13 @@ object LBFGS extends Logging {
*/
// The following gradientTotal is actually the regularization part of gradient.
// Will add the gradientSum computed from the data with weights in the next step.
- val gradientTotal = weights - updater.compute(
- Vectors.fromBreeze(weights),
- Vectors.dense(new Array[Double](weights.size)), 1, 1, regParam)._1.toBreeze
+ val gradientTotal = w.copy
+ axpy(-1.0, updater.compute(w, Vectors.zeros(n), 1, 1, regParam)._1, gradientTotal)
// gradientTotal = gradientSum / numExamples + gradientTotal
axpy(1.0 / numExamples, gradientSum, gradientTotal)
- i += 1
-
- (loss, gradientTotal)
+ (loss, gradientTotal.toBreeze.asInstanceOf[BDV[Double]])
}
}
-
}
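The `treeAggregate` in `CostFun` above is the reusable pattern here: per-record gradients are accumulated in place by `seqOp`, and partial sums are merged with an in-place `axpy` in `combOp`, so no intermediate vectors are allocated. A hedged standalone sketch (hypothetical helper; `computeLoss` stands in for `Gradient.compute`):

~~~
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS.axpy
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.rdd.RDD

// computeLoss adds one example's gradient into `grad` in place and
// returns that example's loss, like Gradient.compute.
def sumGradients(
    data: RDD[(Double, Vector)],
    n: Int,
    computeLoss: (Vector, Double, Vector) => Double): (Vector, Double) = {
  data.treeAggregate((Vectors.zeros(n), 0.0))(
    seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
      (grad, loss + computeLoss(features, label, grad))
    },
    combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
      axpy(1.0, grad2, grad1)  // grad1 += grad2, no new allocation
      (grad1, loss1 + loss2)
    })
}
~~~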
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
new file mode 100644
index 0000000000..1952e6734e
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.linalg
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.mllib.linalg.BLAS._
+
+class BLASSuite extends FunSuite {
+
+ test("copy") {
+ val sx = Vectors.sparse(4, Array(0, 2), Array(1.0, -2.0))
+ val dx = Vectors.dense(1.0, 0.0, -2.0, 0.0)
+ val sy = Vectors.sparse(4, Array(0, 1, 3), Array(2.0, 1.0, 1.0))
+ val dy = Array(2.0, 1.0, 0.0, 1.0)
+
+ val dy1 = Vectors.dense(dy.clone())
+ copy(sx, dy1)
+ assert(dy1 ~== dx absTol 1e-15)
+
+ val dy2 = Vectors.dense(dy.clone())
+ copy(dx, dy2)
+ assert(dy2 ~== dx absTol 1e-15)
+
+ intercept[IllegalArgumentException] {
+ copy(sx, sy)
+ }
+
+ intercept[IllegalArgumentException] {
+ copy(dx, sy)
+ }
+
+ withClue("vector sizes must match") {
+ intercept[Exception] {
+ copy(sx, Vectors.dense(0.0, 1.0, 2.0))
+ }
+ }
+ }
+
+ test("scal") {
+ val a = 0.1
+ val sx = Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0))
+ val dx = Vectors.dense(1.0, 0.0, -2.0)
+
+ scal(a, sx)
+ assert(sx ~== Vectors.sparse(3, Array(0, 2), Array(0.1, -0.2)) absTol 1e-15)
+
+ scal(a, dx)
+ assert(dx ~== Vectors.dense(0.1, 0.0, -0.2) absTol 1e-15)
+ }
+
+ test("axpy") {
+ val alpha = 0.1
+ val sx = Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0))
+ val dx = Vectors.dense(1.0, 0.0, -2.0)
+ val dy = Array(2.0, 1.0, 0.0)
+ val expected = Vectors.dense(2.1, 1.0, -0.2)
+
+ val dy1 = Vectors.dense(dy.clone())
+ axpy(alpha, sx, dy1)
+ assert(dy1 ~== expected absTol 1e-15)
+
+ val dy2 = Vectors.dense(dy.clone())
+ axpy(alpha, dx, dy2)
+ assert(dy2 ~== expected absTol 1e-15)
+
+ val sy = Vectors.sparse(4, Array(0, 1), Array(2.0, 1.0))
+
+ intercept[IllegalArgumentException] {
+ axpy(alpha, sx, sy)
+ }
+
+ intercept[IllegalArgumentException] {
+ axpy(alpha, dx, sy)
+ }
+
+ withClue("vector sizes must match") {
+ intercept[Exception] {
+ axpy(alpha, sx, Vectors.dense(1.0, 2.0))
+ }
+ }
+ }
+
+ test("dot") {
+ val sx = Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0))
+ val dx = Vectors.dense(1.0, 0.0, -2.0)
+ val sy = Vectors.sparse(3, Array(0, 1), Array(2.0, 1.0))
+ val dy = Vectors.dense(2.0, 1.0, 0.0)
+
+ assert(dot(sx, sy) ~== 2.0 absTol 1e-15)
+ assert(dot(sy, sx) ~== 2.0 absTol 1e-15)
+ assert(dot(sx, dy) ~== 2.0 absTol 1e-15)
+ assert(dot(dy, sx) ~== 2.0 absTol 1e-15)
+ assert(dot(dx, dy) ~== 2.0 absTol 1e-15)
+ assert(dot(dy, dx) ~== 2.0 absTol 1e-15)
+
+ assert(dot(sx, sx) ~== 5.0 absTol 1e-15)
+ assert(dot(dx, dx) ~== 5.0 absTol 1e-15)
+ assert(dot(sx, dx) ~== 5.0 absTol 1e-15)
+ assert(dot(dx, sx) ~== 5.0 absTol 1e-15)
+
+ val sx1 = Vectors.sparse(10, Array(0, 3, 5, 7, 8), Array(1.0, 2.0, 3.0, 4.0, 5.0))
+ val sx2 = Vectors.sparse(10, Array(1, 3, 6, 7, 9), Array(1.0, 2.0, 3.0, 4.0, 5.0))
+ assert(dot(sx1, sx2) ~== 20.0 absTol 1e-15)
+ assert(dot(sx2, sx1) ~== 20.0 absTol 1e-15)
+
+ withClue("vector sizes must match") {
+ intercept[Exception] {
+ dot(sx, Vectors.dense(2.0, 1.0))
+ }
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 7972ceea1f..cd651fe2d2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -125,4 +125,34 @@ class VectorsSuite extends FunSuite {
}
}
}
+
+ test("zeros") {
+ assert(Vectors.zeros(3) === Vectors.dense(0.0, 0.0, 0.0))
+ }
+
+ test("Vector.copy") {
+ val sv = Vectors.sparse(4, Array(0, 2), Array(1.0, 2.0))
+ val svCopy = sv.copy
+ (sv, svCopy) match {
+ case (sv: SparseVector, svCopy: SparseVector) =>
+ assert(sv.size === svCopy.size)
+ assert(sv.indices === svCopy.indices)
+ assert(sv.values === svCopy.values)
+ assert(!sv.indices.eq(svCopy.indices))
+ assert(!sv.values.eq(svCopy.values))
+ case _ =>
+ throw new RuntimeException(s"copy returned ${svCopy.getClass} on ${sv.getClass}.")
+ }
+
+ val dv = Vectors.dense(1.0, 0.0, 2.0)
+ val dvCopy = dv.copy
+ (dv, dvCopy) match {
+ case (dv: DenseVector, dvCopy: DenseVector) =>
+ assert(dv.size === dvCopy.size)
+ assert(dv.values === dvCopy.values)
+ assert(!dv.values.eq(dvCopy.values))
+ case _ =>
+ throw new RuntimeException(s"copy returned ${dvCopy.getClass} on ${dv.getClass}.")
+ }
+ }
}