aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@gmail.com>2015-01-06 14:00:45 -0800
committerXiangrui Meng <meng@databricks.com>2015-01-06 14:00:45 -0800
commitbb38ebb1abd26b57525d7d29703fd449e40cd6de (patch)
tree643f535bca24bdd5b9da902769380cb10b315da0
parent4108e5f36f8553bd728fd271baa69f7dfcc68d9b (diff)
downloadspark-bb38ebb1abd26b57525d7d29703fd449e40cd6de.tar.gz
spark-bb38ebb1abd26b57525d7d29703fd449e40cd6de.tar.bz2
spark-bb38ebb1abd26b57525d7d29703fd449e40cd6de.zip
[SPARK-5050][Mllib] Add unit test for sqdist
Related to #3643. Follow the previous suggestion to add unit test for `sqdist` in `VectorsSuite`. Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #3869 from viirya/sqdist_test and squashes the following commits: fb743da [Liang-Chi Hsieh] Modified for comment and fix bug. 90a08f3 [Liang-Chi Hsieh] Modified for comment. 39a3ca6 [Liang-Chi Hsieh] Take care of special case. b789f42 [Liang-Chi Hsieh] More proper unit test with random sparsity pattern. c36be68 [Liang-Chi Hsieh] Add unit test for sqdist.
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala5
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala31
2 files changed, 33 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 6a782b079a..d40f13342a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -373,8 +373,9 @@ object Vectors {
var kv2 = 0
val indices = v1.indices
var squaredDistance = 0.0
- var iv1 = indices(kv1)
+ val nnzv1 = indices.size
val nnzv2 = v2.size
+ var iv1 = if (nnzv1 > 0) indices(kv1) else -1
while (kv2 < nnzv2) {
var score = 0.0
@@ -382,7 +383,7 @@ object Vectors {
score = v2(kv2)
} else {
score = v1.values(kv1) - v2(kv2)
- if (kv1 < indices.length - 1) {
+ if (kv1 < nnzv1 - 1) {
kv1 += 1
iv1 = indices(kv1)
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index f99f014509..85ac8ccebf 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -17,7 +17,9 @@
package org.apache.spark.mllib.linalg
-import breeze.linalg.{DenseMatrix => BDM}
+import scala.util.Random
+
+import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance}
import org.scalatest.FunSuite
import org.apache.spark.SparkException
@@ -175,6 +177,33 @@ class VectorsSuite extends FunSuite {
assert(v.size === x.rows)
}
+ test("sqdist") {
+ val random = new Random()
+ for (m <- 1 until 1000 by 100) {
+ val nnz = random.nextInt(m)
+
+ val indices1 = random.shuffle(0 to m - 1).slice(0, nnz).sorted.toArray
+ val values1 = Array.fill(nnz)(random.nextDouble)
+ val sparseVector1 = Vectors.sparse(m, indices1, values1)
+
+ val indices2 = random.shuffle(0 to m - 1).slice(0, nnz).sorted.toArray
+ val values2 = Array.fill(nnz)(random.nextDouble)
+ val sparseVector2 = Vectors.sparse(m, indices2, values2)
+
+ val denseVector1 = Vectors.dense(sparseVector1.toArray)
+ val denseVector2 = Vectors.dense(sparseVector2.toArray)
+
+ val squaredDist = breezeSquaredDistance(sparseVector1.toBreeze, sparseVector2.toBreeze)
+
+ // SparseVector vs. SparseVector
+ assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8)
+ // DenseVector vs. SparseVector
+ assert(Vectors.sqdist(denseVector1, sparseVector2) ~== squaredDist relTol 1E-8)
+ // DenseVector vs. DenseVector
+ assert(Vectors.sqdist(denseVector1, denseVector2) ~== squaredDist relTol 1E-8)
+ }
+ }
+
test("foreachActive") {
val dv = Vectors.dense(0.0, 1.2, 3.1, 0.0)
val sv = Vectors.sparse(4, Seq((1, 1.2), (2, 3.1), (3, 0.0)))