diff options
author | Liang-Chi Hsieh <viirya@gmail.com> | 2015-01-06 14:00:45 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-01-06 14:00:45 -0800 |
commit | bb38ebb1abd26b57525d7d29703fd449e40cd6de (patch) | |
tree | 643f535bca24bdd5b9da902769380cb10b315da0 /mllib | |
parent | 4108e5f36f8553bd728fd271baa69f7dfcc68d9b (diff) | |
download | spark-bb38ebb1abd26b57525d7d29703fd449e40cd6de.tar.gz spark-bb38ebb1abd26b57525d7d29703fd449e40cd6de.tar.bz2 spark-bb38ebb1abd26b57525d7d29703fd449e40cd6de.zip |
[SPARK-5050][Mllib] Add unit test for sqdist
Related to #3643. Following the earlier review suggestion, this adds a unit test for `sqdist` in `VectorsSuite`.
Author: Liang-Chi Hsieh <viirya@gmail.com>
Closes #3869 from viirya/sqdist_test and squashes the following commits:
fb743da [Liang-Chi Hsieh] Modified for comment and fix bug.
90a08f3 [Liang-Chi Hsieh] Modified for comment.
39a3ca6 [Liang-Chi Hsieh] Take care of special case.
b789f42 [Liang-Chi Hsieh] More proper unit test with random sparsity pattern.
c36be68 [Liang-Chi Hsieh] Add unit test for sqdist.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 5 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala | 31 |
2 files changed, 33 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 6a782b079a..d40f13342a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -373,8 +373,9 @@ object Vectors {
     var kv2 = 0
     val indices = v1.indices
     var squaredDistance = 0.0
-    var iv1 = indices(kv1)
+    val nnzv1 = indices.size
     val nnzv2 = v2.size
+    var iv1 = if (nnzv1 > 0) indices(kv1) else -1
 
     while (kv2 < nnzv2) {
       var score = 0.0
@@ -382,7 +383,7 @@ object Vectors {
         score = v2(kv2)
       } else {
         score = v1.values(kv1) - v2(kv2)
-        if (kv1 < indices.length - 1) {
+        if (kv1 < nnzv1 - 1) {
           kv1 += 1
           iv1 = indices(kv1)
         }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index f99f014509..85ac8ccebf 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.mllib.linalg
 
-import breeze.linalg.{DenseMatrix => BDM}
+import scala.util.Random
+
+import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance}
 import org.scalatest.FunSuite
 
 import org.apache.spark.SparkException
@@ -175,6 +177,33 @@ class VectorsSuite extends FunSuite {
     assert(v.size === x.rows)
   }
 
+  test("sqdist") {
+    val random = new Random()
+    for (m <- 1 until 1000 by 100) {
+      val nnz = random.nextInt(m)
+
+      val indices1 = random.shuffle(0 to m - 1).slice(0, nnz).sorted.toArray
+      val values1 = Array.fill(nnz)(random.nextDouble)
+      val sparseVector1 = Vectors.sparse(m, indices1, values1)
+
+      val indices2 = random.shuffle(0 to m - 1).slice(0, nnz).sorted.toArray
+      val values2 = Array.fill(nnz)(random.nextDouble)
+      val sparseVector2 = Vectors.sparse(m, indices2, values2)
+
+      val denseVector1 = Vectors.dense(sparseVector1.toArray)
+      val denseVector2 = Vectors.dense(sparseVector2.toArray)
+
+      val squaredDist = breezeSquaredDistance(sparseVector1.toBreeze, sparseVector2.toBreeze)
+
+      // SparseVector vs. SparseVector
+      assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8)
+      // DenseVector vs. SparseVector
+      assert(Vectors.sqdist(denseVector1, sparseVector2) ~== squaredDist relTol 1E-8)
+      // DenseVector vs. DenseVector
+      assert(Vectors.sqdist(denseVector1, denseVector2) ~== squaredDist relTol 1E-8)
+    }
+  }
+
   test("foreachActive") {
     val dv = Vectors.dense(0.0, 1.2, 3.1, 0.0)
     val sv = Vectors.sparse(4, Seq((1, 1.2), (2, 3.1), (3, 0.0)))