aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorTor Myklebust <tmyklebu@gmail.com>2013-12-19 01:22:18 -0500
committerTor Myklebust <tmyklebu@gmail.com>2013-12-19 01:29:09 -0500
commit95915f8b3b6d07a9dddb09a637aa23c8622bff9b (patch)
treebc22a7cb8758079a0c8896d022dca0b418e66ec8 /mllib
parentd3b1af4b6c7766bbf7a09ee6d5c1b13eda6b098f (diff)
downloadspark-95915f8b3b6d07a9dddb09a637aa23c8622bff9b.tar.gz
spark-95915f8b3b6d07a9dddb09a637aa23c8622bff9b.tar.bz2
spark-95915f8b3b6d07a9dddb09a637aa23c8622bff9b.zip
First cut at python mllib bindings. Only LinearRegression is supported.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala51
1 files changed, 51 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
new file mode 100644
index 0000000000..19d2e9a773
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
@@ -0,0 +1,51 @@
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.regression._
+import java.nio.ByteBuffer
+import java.nio.ByteOrder
+import java.nio.DoubleBuffer
+
+class PythonMLLibAPI extends Serializable {
+ def deserializeDoubleVector(bytes: Array[Byte]): Array[Double] = {
+ val packetLength = bytes.length;
+ if (packetLength < 16) {
+ throw new IllegalArgumentException("Byte array too short.");
+ }
+ val bb = ByteBuffer.wrap(bytes);
+ bb.order(ByteOrder.nativeOrder());
+ val magic = bb.getLong();
+ if (magic != 1) {
+ throw new IllegalArgumentException("Magic " + magic + " is wrong.");
+ }
+ val length = bb.getLong();
+ if (packetLength != 16 + 8 * length) {
+ throw new IllegalArgumentException("Length " + length + "is wrong.");
+ }
+ val db = bb.asDoubleBuffer();
+ val ans = new Array[Double](length.toInt);
+ db.get(ans);
+ return ans;
+ }
+
+ def serializeDoubleVector(doubles: Array[Double]): Array[Byte] = {
+ val len = doubles.length;
+ val bytes = new Array[Byte](16 + 8 * len);
+ val bb = ByteBuffer.wrap(bytes);
+ bb.order(ByteOrder.nativeOrder());
+ bb.putLong(1);
+ bb.putLong(len);
+ val db = bb.asDoubleBuffer();
+ db.put(doubles);
+ return bytes;
+ }
+
+ def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]]):
+ java.util.List[java.lang.Object] = {
+ val data = dataBytesJRDD.rdd.map(x => deserializeDoubleVector(x))
+ .map(v => LabeledPoint(v(0), v.slice(1, v.length)));
+ val model = LinearRegressionWithSGD.train(data, 222);
+ val ret = new java.util.LinkedList[java.lang.Object]();
+ ret.add(serializeDoubleVector(model.weights));
+ ret.add(model.intercept: java.lang.Double);
+ return ret;
+ }
+}