about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- python/pyspark/mllib/linalg.py 20
1 files changed, 8 insertions, 12 deletions
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index 9959a01cce..12d8dbbb92 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -590,18 +590,14 @@ class SparseVector(Vector):
return np.dot(other.array[self.indices], self.values)
elif isinstance(other, SparseVector):
- result = 0.0
- i, j = 0, 0
- while i < len(self.indices) and j < len(other.indices):
- if self.indices[i] == other.indices[j]:
- result += self.values[i] * other.values[j]
- i += 1
- j += 1
- elif self.indices[i] < other.indices[j]:
- i += 1
- else:
- j += 1
- return result
+ # Find out common indices.
+ self_cmind = np.in1d(self.indices, other.indices, assume_unique=True)
+ self_values = self.values[self_cmind]
+ if self_values.size == 0:
+ return 0.0
+ else:
+ other_cmind = np.in1d(other.indices, self.indices, assume_unique=True)
+ return np.dot(self_values, other.values[other_cmind])
else:
return self.dot(_convert_to_vector(other))