aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/_common.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index a411a5d591..e609b60a0f 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -454,7 +454,7 @@ def _squared_distance(v1, v2):
v2 = _convert_vector(v2)
if type(v1) == ndarray and type(v2) == ndarray:
diff = v1 - v2
- return diff.dot(diff)
+ return numpy.dot(diff, diff)
elif type(v1) == ndarray:
return v2.squared_distance(v1)
else:
@@ -469,10 +469,12 @@ def _dot(vec, target):
calling numpy.dot of the two vectors, but for SciPy ones, we
have to transpose them because they're column vectors.
"""
- if type(vec) == ndarray or type(vec) == SparseVector:
+ if type(vec) == ndarray:
+ return numpy.dot(vec, target)
+ elif type(vec) == SparseVector:
return vec.dot(target)
elif type(vec) == list:
- return _convert_vector(vec).dot(target)
+ return numpy.dot(_convert_vector(vec), target)
else:
return vec.transpose().dot(target)[0]