about summary refs log tree commit diff
path: root/python/pyspark/mllib/recommendation.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/recommendation.py')
-rw-r--r-- python/pyspark/mllib/recommendation.py 7
1 files changed, 4 insertions, 3 deletions
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index c5c4c13dae..80e0a356bb 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
+import array
from collections import namedtuple
from pyspark import SparkContext
@@ -104,14 +105,14 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
assert isinstance(user_product, RDD), "user_product should be RDD of (user, product)"
first = user_product.first()
assert len(first) == 2, "user_product should be RDD of (user, product)"
- user_product = user_product.map(lambda (u, p): (int(u), int(p)))
+ user_product = user_product.map(lambda u_p: (int(u_p[0]), int(u_p[1])))
return self.call("predict", user_product)
def userFeatures(self):
- return self.call("getUserFeatures")
+ return self.call("getUserFeatures").mapValues(lambda v: array.array('d', v))
def productFeatures(self):
- return self.call("getProductFeatures")
+ return self.call("getProductFeatures").mapValues(lambda v: array.array('d', v))
@classmethod
def load(cls, sc, path):