diff options
author | Hossein Falaki <falaki@gmail.com> | 2014-01-04 16:23:17 -0800 |
---|---|---|
committer | Hossein Falaki <falaki@gmail.com> | 2014-01-04 16:23:17 -0800 |
commit | 8d0c2f7399ebf7a38346a60cf84d7020c0b1dba1 (patch) | |
tree | 2afba993eb1cd78796f4fbce944733ec26c205e6 /mllib | |
parent | dfe57fa84cea9d8bbca9a89a293efcaa95eae9e7 (diff) | |
download | spark-8d0c2f7399ebf7a38346a60cf84d7020c0b1dba1.tar.gz spark-8d0c2f7399ebf7a38346a60cf84d7020c0b1dba1.tar.bz2 spark-8d0c2f7399ebf7a38346a60cf84d7020c0b1dba1.zip |
Added python binding for bulk recommendation
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 18 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala | 10 |
2 files changed, 27 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 8247c1ebc5..be2628fac5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -206,6 +206,24 @@ class PythonMLLibAPI extends Serializable {
     return new Rating(user, product, rating)
   }
 
+  private[spark] def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = {
+    val bb = ByteBuffer.wrap(tupleBytes)
+    bb.order(ByteOrder.nativeOrder())
+    val v1 = bb.getInt()
+    val v2 = bb.getInt()
+    (v1, v2)
+  }
+
+  private[spark] def serializeRating(rate: Rating): Array[Byte] = {
+    val bytes = new Array[Byte](24)
+    val bb = ByteBuffer.wrap(bytes)
+    bb.order(ByteOrder.nativeOrder())
+    bb.putDouble(rate.user.toDouble)
+    bb.putDouble(rate.product.toDouble)
+    bb.putDouble(rate.rating)
+    bytes
+  }
+
   /**
    * Java stub for Python mllib ALS.train(). This stub returns a handle
    * to the Java object instead of the content of the Java object. Extra care
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 8caecf0fa1..2c3e828300 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -19,9 +19,11 @@ package org.apache.spark.mllib.recommendation
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.api.python.PythonMLLibAPI
 
 import org.jblas._
-import java.nio.{ByteOrder, ByteBuffer}
+import org.apache.spark.api.java.JavaRDD
+
 
 /**
  * Model representing the result of matrix factorization.
@@ -65,6 +67,12 @@ class MatrixFactorizationModel(
     }
   }
 
+  def predictJavaRDD(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
+    val pythonAPI = new PythonMLLibAPI()
+    val usersProducts = usersProductsJRDD.rdd.map(xBytes => pythonAPI.unpackTuple(xBytes))
+    predict(usersProducts).map(rate => pythonAPI.serializeRating(rate))
+  }
+
   // TODO: Figure out what other good bulk prediction methods would look like.
   // Probably want a way to get the top users for a product or vice-versa.
 }