aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala11
1 files changed, 6 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 7ee0224ad4..b3022add38 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -333,7 +333,7 @@ object Vectors {
math.pow(sum, 1.0 / p)
}
}
-
+
/**
* Returns the squared distance between two Vectors.
* @param v1 first Vector.
@@ -341,8 +341,9 @@ object Vectors {
* @return squared distance between two Vectors.
*/
def sqdist(v1: Vector, v2: Vector): Double = {
+ require(v1.size == v2.size, "vector dimension mismatch")
var squaredDistance = 0.0
- (v1, v2) match {
+ (v1, v2) match {
case (v1: SparseVector, v2: SparseVector) =>
val v1Values = v1.values
val v1Indices = v1.indices
@@ -350,12 +351,12 @@ object Vectors {
val v2Indices = v2.indices
val nnzv1 = v1Indices.size
val nnzv2 = v2Indices.size
-
+
var kv1 = 0
var kv2 = 0
while (kv1 < nnzv1 || kv2 < nnzv2) {
var score = 0.0
-
+
if (kv2 >= nnzv2 || (kv1 < nnzv1 && v1Indices(kv1) < v2Indices(kv2))) {
score = v1Values(kv1)
kv1 += 1
@@ -397,7 +398,7 @@ object Vectors {
val nnzv1 = indices.size
val nnzv2 = v2.size
var iv1 = if (nnzv1 > 0) indices(kv1) else -1
-
+
while (kv2 < nnzv2) {
var score = 0.0
if (kv2 != iv1) {