diff options
author | Tor Myklebust <tmyklebu@gmail.com> | 2013-12-21 14:54:13 -0500 |
---|---|---|
committer | Tor Myklebust <tmyklebu@gmail.com> | 2013-12-21 14:54:13 -0500 |
commit | 20f85eca3d924aecd0fcf61cd516d9ac8e369dc1 (patch) | |
tree | 2122eeebff4fb21339ff289547734f2151faf999 | |
parent | 076fc1622190d342e20592c00ca19f8c0a56997f (diff) | |
download | spark-20f85eca3d924aecd0fcf61cd516d9ac8e369dc1.tar.gz spark-20f85eca3d924aecd0fcf61cd516d9ac8e369dc1.tar.bz2 spark-20f85eca3d924aecd0fcf61cd516d9ac8e369dc1.zip |
Java stubs for ALSModel.
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala index 6472bf6367..4620cab175 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala @@ -19,6 +19,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ +import org.apache.spark.mllib.recommendation._ import org.apache.spark.rdd.RDD import java.nio.ByteBuffer import java.nio.ByteOrder @@ -194,4 +195,37 @@ class PythonMLLibAPI extends Serializable { ret.add(serializeDoubleMatrix(model.clusterCenters)) return ret } + + private def unpackRating(ratingBytes: Array[Byte]): Rating = { + val bb = ByteBuffer.wrap(ratingBytes) + bb.order(ByteOrder.nativeOrder()) + val user = bb.getInt() + val product = bb.getInt() + val rating = bb.getDouble() + return new Rating(user, product, rating) + } + + /** + * Java stub for Python mllib ALSModel.train(). This stub returns a handle + * to the Java object instead of the content of the Java object. Extra care + * needs to be taken in the Python code to ensure it gets freed on exit; see + * the Py4J documentation. + */ + def trainALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int, + iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = { + val ratings = ratingsBytesJRDD.rdd.map(unpackRating) + return ALS.train(ratings, rank, iterations, lambda, blocks) + } + + /** + * Java stub for Python mllib ALSModel.trainImplicit(). This stub returns a + * handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on + * exit; see the Py4J documentation. + */ + def trainImplicitALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int, + iterations: Int, lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = { + val ratings = ratingsBytesJRDD.rdd.map(unpackRating) + return ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) + } } |