aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author: Liang-Chi Hsieh <viirya@gmail.com> 2015-01-27 01:29:14 -0800
committer: Xiangrui Meng <meng@databricks.com> 2015-01-27 01:29:14 -0800
commit: 7b0ed797958a91cda73baa7aa49ce66bfcb6b64b (patch)
tree: 988570edd6b652ea6977af96135b4dd5f9c6d43f /mllib
parent: d6894b1c5314c751cfdaf78005b99b2104e6e4d1 (diff)
download: spark-7b0ed797958a91cda73baa7aa49ce66bfcb6b64b.tar.gz
spark-7b0ed797958a91cda73baa7aa49ce66bfcb6b64b.tar.bz2
spark-7b0ed797958a91cda73baa7aa49ce66bfcb6b64b.zip
[SPARK-5419][Mllib] Fix the logic in Vectors.sqdist
The current implementation of Vectors.sqdist is inefficient because it allocates temporary arrays. There is also a bug in the guard `v1.indices.length / v1.size < 0.5`: both operands are integers, so the division truncates (to 0 whenever the vector stores fewer entries than its size) and the comparison does not test the intended density ratio. This PR fixes the bug and refactors sqdist so that it no longer allocates new arrays. Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #4217 from viirya/fix_sqdist and squashes the following commits: e8b0b3d [Liang-Chi Hsieh] For review comments. 314c424 [Liang-Chi Hsieh] Fix sqdist bug.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala 19
1 files changed, 12 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index b3022add38..2834ea75ce 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -371,18 +371,23 @@ object Vectors {
squaredDistance += score * score
}
- case (v1: SparseVector, v2: DenseVector) if v1.indices.length / v1.size < 0.5 =>
+ case (v1: SparseVector, v2: DenseVector) =>
squaredDistance = sqdist(v1, v2)
- case (v1: DenseVector, v2: SparseVector) if v2.indices.length / v2.size < 0.5 =>
+ case (v1: DenseVector, v2: SparseVector) =>
squaredDistance = sqdist(v2, v1)
- // When a SparseVector is approximately dense, we treat it as a DenseVector
- case (v1, v2) =>
- squaredDistance = v1.toArray.zip(v2.toArray).foldLeft(0.0){ (distance, elems) =>
- val score = elems._1 - elems._2
- distance + score * score
+ case (DenseVector(vv1), DenseVector(vv2)) =>
+ var kv = 0
+ val sz = vv1.size
+ while (kv < sz) {
+ val score = vv1(kv) - vv2(kv)
+ squaredDistance += score * score
+ kv += 1
}
+ case _ =>
+ throw new IllegalArgumentException("Do not support vector type " + v1.getClass +
+ " and " + v2.getClass)
}
squaredDistance
}