author     Hossein Falaki <falaki@gmail.com>  2014-01-04 16:23:17 -0800
committer  Hossein Falaki <falaki@gmail.com>  2014-01-04 16:23:17 -0800
commit     8d0c2f7399ebf7a38346a60cf84d7020c0b1dba1 (patch)
tree       2afba993eb1cd78796f4fbce944733ec26c205e6
parent     dfe57fa84cea9d8bbca9a89a293efcaa95eae9e7 (diff)
Added python binding for bulk recommendation
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala                18
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala  10
-rw-r--r--  python/pyspark/mllib/_common.py                                                             10
-rw-r--r--  python/pyspark/mllib/recommendation.py                                                      10
4 files changed, 46 insertions, 2 deletions
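
The diff below adds a bulk prediction path (predictAll) to the PySpark ALS binding. As a rough usage sketch of that API, and not part of this patch, the Python snippet below shows how the new call might be driven; the sample ratings and variable names are made up for illustration.

# Illustrative usage sketch (not part of this patch) of the predictAll
# binding added in the diff below. Sample data and names are invented.
from pyspark import SparkContext
from pyspark.mllib.recommendation import ALS

sc = SparkContext("local", "bulk-recommendation-demo")

# (user, product, rating) triples used to train the factorization model.
ratings = sc.parallelize([(1, 1, 5.0), (1, 2, 1.0), (2, 1, 1.0), (2, 2, 5.0)])
model = ALS.train(sc, ratings, rank=10, iterations=5)

# Bulk prediction: score every (user, product) pair in one call instead of
# looping over model.predict(user, product).
users_products = sc.parallelize([(1, 2), (2, 1)])
predictions = model.predictAll(users_products).collect()
print(predictions)
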
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 8247c1ebc5..be2628fac5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -206,6 +206,24 @@ class PythonMLLibAPI extends Serializable {
return new Rating(user, product, rating)
}
+ private[spark] def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = {
+ val bb = ByteBuffer.wrap(tupleBytes)
+ bb.order(ByteOrder.nativeOrder())
+ val v1 = bb.getInt()
+ val v2 = bb.getInt()
+ (v1, v2)
+ }
+
+ private[spark] def serializeRating(rate: Rating): Array[Byte] = {
+ val bytes = new Array[Byte](24)
+ val bb = ByteBuffer.wrap(bytes)
+ bb.order(ByteOrder.nativeOrder())
+ bb.putDouble(rate.user.toDouble)
+ bb.putDouble(rate.product.toDouble)
+ bb.putDouble(rate.rating)
+ bytes
+ }
+
/**
* Java stub for Python mllib ALS.train(). This stub returns a handle
* to the Java object instead of the content of the Java object. Extra care
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 8caecf0fa1..2c3e828300 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -19,9 +19,11 @@ package org.apache.spark.mllib.recommendation
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.api.python.PythonMLLibAPI
import org.jblas._
-import java.nio.{ByteOrder, ByteBuffer}
+import org.apache.spark.api.java.JavaRDD
+
/**
* Model representing the result of matrix factorization.
@@ -65,6 +67,12 @@ class MatrixFactorizationModel(
}
}
+ def predictJavaRDD(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
+ val pythonAPI = new PythonMLLibAPI()
+ val usersProducts = usersProductsJRDD.rdd.map(xBytes => pythonAPI.unpackTuple(xBytes))
+ predict(usersProducts).map(rate => pythonAPI.serializeRating(rate))
+ }
+
// TODO: Figure out what other good bulk prediction methods would look like.
// Probably want a way to get the top users for a product or vice-versa.
}
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index e74ba0fabc..c818fc4d97 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -213,6 +213,16 @@ def _serialize_rating(r):
intpart[0], intpart[1], doublepart[0] = r
return ba
+def _deserialize_rating(ba):
+ ar = ndarray(shape=(3, ), buffer=ba, dtype="float64", order='C')
+ return ar.copy()
+
+def _serialize_tuple(t):
+ ba = bytearray(8)
+ intpart = ndarray(shape=[2], buffer=ba, dtype=int32)
+ intpart[0], intpart[1] = t
+ return ba
+
def _test():
import doctest
globs = globals().copy()
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 14d06cba21..c81b482a87 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -20,7 +20,10 @@ from pyspark.mllib._common import \
_get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
_serialize_double_matrix, _deserialize_double_matrix, \
_serialize_double_vector, _deserialize_double_vector, \
- _get_initial_weights, _serialize_rating, _regression_train_wrapper
+ _get_initial_weights, _serialize_rating, _regression_train_wrapper, \
+ _serialize_tuple, _deserialize_rating
+from pyspark.serializers import BatchedSerializer
+from pyspark.rdd import RDD
class MatrixFactorizationModel(object):
"""A matrix factorisation model trained by regularized alternating
@@ -45,6 +48,11 @@ class MatrixFactorizationModel(object):
def predict(self, user, product):
return self._java_model.predict(user, product)
+ def predictAll(self, usersProducts):
+ usersProductsJRDD = _get_unmangled_rdd(usersProducts, _serialize_tuple)
+ return RDD(self._java_model.predictJavaRDD(usersProductsJRDD._jrdd),
+ self._context, BatchedSerializer(_deserialize_rating, self._context._batchSize))
+
class ALS(object):
@classmethod
def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
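
For reference, the sketch below mirrors the byte layout that unpackTuple and serializeRating above rely on: an input tuple is two native-order 32-bit ints, and each returned Rating is three native-order doubles in a 24-byte buffer. The struct-based helpers are hypothetical names used only for illustration, and they assume the Python side's native byte order matches ByteOrder.nativeOrder() on the JVM.

# Minimal sketch (not part of this patch) of the wire format shared by
# _serialize_tuple/_deserialize_rating on the Python side and
# unpackTuple/serializeRating on the Scala side. Helper names are
# hypothetical; native byte order is assumed on both ends.
import struct

def pack_user_product(user, product):
    # Two native-order 32-bit ints -> 8 bytes, as read by unpackTuple.
    return struct.pack("=ii", user, product)

def unpack_rating(rating_bytes):
    # Three native-order 64-bit doubles (user, product, rating) -> 24 bytes,
    # as written by serializeRating.
    return struct.unpack("=ddd", rating_bytes)

assert len(pack_user_product(1, 2)) == 8
assert unpack_rating(struct.pack("=ddd", 1.0, 2.0, 4.5)) == (1.0, 2.0, 4.5)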