aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorTor Myklebust <tmyklebu@gmail.com>2013-12-21 14:54:13 -0500
committerTor Myklebust <tmyklebu@gmail.com>2013-12-21 14:54:13 -0500
commit20f85eca3d924aecd0fcf61cd516d9ac8e369dc1 (patch)
tree2122eeebff4fb21339ff289547734f2151faf999 /mllib
parent076fc1622190d342e20592c00ca19f8c0a56997f (diff)
downloadspark-20f85eca3d924aecd0fcf61cd516d9ac8e369dc1.tar.gz
spark-20f85eca3d924aecd0fcf61cd516d9ac8e369dc1.tar.bz2
spark-20f85eca3d924aecd0fcf61cd516d9ac8e369dc1.zip
Java stubs for ALSModel.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala34
1 files changed, 34 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
index 6472bf6367..4620cab175 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
@@ -19,6 +19,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
+import org.apache.spark.mllib.recommendation._
import org.apache.spark.rdd.RDD
import java.nio.ByteBuffer
import java.nio.ByteOrder
@@ -194,4 +195,37 @@ class PythonMLLibAPI extends Serializable {
ret.add(serializeDoubleMatrix(model.clusterCenters))
return ret
}
+
+ private def unpackRating(ratingBytes: Array[Byte]): Rating = {
+ val bb = ByteBuffer.wrap(ratingBytes)
+ bb.order(ByteOrder.nativeOrder())
+ val user = bb.getInt()
+ val product = bb.getInt()
+ val rating = bb.getDouble()
+ return new Rating(user, product, rating)
+ }
+
+ /**
+ * Java stub for Python mllib ALSModel.train(). This stub returns a handle
+ * to the Java object instead of the content of the Java object. Extra care
+ * needs to be taken in the Python code to ensure it gets freed on exit; see
+ * the Py4J documentation.
+ */
+ def trainALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int,
+ iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = {
+ val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
+ return ALS.train(ratings, rank, iterations, lambda, blocks)
+ }
+
+ /**
+ * Java stub for Python mllib ALSModel.trainImplicit(). This stub returns a
+ * handle to the Java object instead of the content of the Java object.
+ * Extra care needs to be taken in the Python code to ensure it gets freed on
+ * exit; see the Py4J documentation.
+ */
+ def trainImplicitALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int,
+ iterations: Int, lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = {
+ val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
+ return ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha)
+ }
}