diff options
author | Liang-Chi Hsieh <viirya@gmail.com> | 2015-01-27 01:29:14 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-01-27 01:29:14 -0800 |
commit | 7b0ed797958a91cda73baa7aa49ce66bfcb6b64b (patch) | |
tree | 988570edd6b652ea6977af96135b4dd5f9c6d43f | |
parent | d6894b1c5314c751cfdaf78005b99b2104e6e4d1 (diff) | |
download | spark-7b0ed797958a91cda73baa7aa49ce66bfcb6b64b.tar.gz spark-7b0ed797958a91cda73baa7aa49ce66bfcb6b64b.tar.bz2 spark-7b0ed797958a91cda73baa7aa49ce66bfcb6b64b.zip |
[SPARK-5419][Mllib] Fix the logic in Vectors.sqdist
The current implementation of Vectors.sqdist is inefficient because it allocates temporary arrays. There is also a bug in the guard condition `v1.indices.length / v1.size < 0.5`. This PR fixes the bug and refactors sqdist so that it no longer allocates new arrays.
Author: Liang-Chi Hsieh <viirya@gmail.com>
Closes #4217 from viirya/fix_sqdist and squashes the following commits:
e8b0b3d [Liang-Chi Hsieh] For review comments.
314c424 [Liang-Chi Hsieh] Fix sqdist bug.
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 19 |
1 files changed, 12 insertions, 7 deletions
```diff
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index b3022add38..2834ea75ce 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -371,18 +371,23 @@ object Vectors {
           squaredDistance += score * score
         }
 
-      case (v1: SparseVector, v2: DenseVector) if v1.indices.length / v1.size < 0.5 =>
+      case (v1: SparseVector, v2: DenseVector) =>
         squaredDistance = sqdist(v1, v2)
 
-      case (v1: DenseVector, v2: SparseVector) if v2.indices.length / v2.size < 0.5 =>
+      case (v1: DenseVector, v2: SparseVector) =>
         squaredDistance = sqdist(v2, v1)
 
-      // When a SparseVector is approximately dense, we treat it as a DenseVector
-      case (v1, v2) =>
-        squaredDistance = v1.toArray.zip(v2.toArray).foldLeft(0.0){ (distance, elems) =>
-          val score = elems._1 - elems._2
-          distance + score * score
+      case (DenseVector(vv1), DenseVector(vv2)) =>
+        var kv = 0
+        val sz = vv1.size
+        while (kv < sz) {
+          val score = vv1(kv) - vv2(kv)
+          squaredDistance += score * score
+          kv += 1
         }
+      case _ =>
+        throw new IllegalArgumentException("Do not support vector type " + v1.getClass +
+          " and " + v2.getClass)
     }
     squaredDistance
   }
```