aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/classification.py
diff options
context:
space:
mode:
authorXusen Yin <yinxusen@gmail.com>2014-04-22 11:06:18 -0700
committerPatrick Wendell <pwendell@gmail.com>2014-04-22 11:06:18 -0700
commitc919798f0912dc03c8365b9a384d9ee6d5b25c51 (patch)
tree386eac712d26333b20a3950a551b6906a972824a /python/pyspark/mllib/classification.py
parent0f87e6ad4366a8c453a7415bc89399030003c264 (diff)
downloadspark-c919798f0912dc03c8365b9a384d9ee6d5b25c51.tar.gz
spark-c919798f0912dc03c8365b9a384d9ee6d5b25c51.tar.bz2
spark-c919798f0912dc03c8365b9a384d9ee6d5b25c51.zip
fix bugs of dot in python
If there are no `transpose()` in `self.theta`, a *ValueError: matrices are not aligned* is occurring. The former test case just ignore this situation. Author: Xusen Yin <yinxusen@gmail.com> Closes #463 from yinxusen/python-naive-bayes and squashes the following commits: fcbe3bc [Xusen Yin] fix bugs of dot in python
Diffstat (limited to 'python/pyspark/mllib/classification.py')
-rw-r--r--python/pyspark/mllib/classification.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index 3a23e0801f..c5844597c9 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -154,7 +154,7 @@ class NaiveBayesModel(object):
def predict(self, x):
"""Return the most likely class for a data vector x"""
- return self.labels[numpy.argmax(self.pi + _dot(x, self.theta))]
+ return self.labels[numpy.argmax(self.pi + _dot(x, self.theta.transpose()))]
class NaiveBayes(object):
@classmethod