diff options
author | Tor Myklebust <tmyklebu@gmail.com> | 2013-12-19 01:22:18 -0500 |
---|---|---|
committer | Tor Myklebust <tmyklebu@gmail.com> | 2013-12-19 01:29:09 -0500 |
commit | 95915f8b3b6d07a9dddb09a637aa23c8622bff9b (patch) | |
tree | bc22a7cb8758079a0c8896d022dca0b418e66ec8 /mllib | |
parent | d3b1af4b6c7766bbf7a09ee6d5c1b13eda6b098f (diff) | |
download | spark-95915f8b3b6d07a9dddb09a637aa23c8622bff9b.tar.gz spark-95915f8b3b6d07a9dddb09a637aa23c8622bff9b.tar.bz2 spark-95915f8b3b6d07a9dddb09a637aa23c8622bff9b.zip |
First cut at python mllib bindings. Only LinearRegression is supported.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala | 51 |
1 files changed, 51 insertions, 0 deletions
// File: mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala (new file in this commit)

import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.regression._
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.DoubleBuffer

/**
 * Entry points used to exchange vectors of doubles as raw byte arrays and to
 * train a linear regression model on data supplied in that byte format.
 *
 * Wire format of a double vector (all fields in the platform's native byte order):
 *   - 8 bytes: magic number, must equal 1
 *   - 8 bytes: number of doubles that follow
 *   - 8 * length bytes: the doubles themselves
 */
class PythonMLLibAPI extends Serializable {
  /** Size in bytes of the (magic, length) header preceding the payload. */
  private final val HeaderBytes = 16

  /** Size in bytes of one serialized double. */
  private final val BytesPerDouble = 8

  /** Magic number expected at the start of every serialized vector. */
  private final val VectorMagic = 1L

  /** Second argument handed to `LinearRegressionWithSGD.train` (its iteration count). */
  private final val NumIterations = 222

  /**
   * Decodes a byte array in the double-vector wire format.
   *
   * @param bytes serialized vector: 16-byte header followed by the doubles
   * @return the decoded doubles
   * @throws IllegalArgumentException if the array is shorter than the header, the magic
   *                                  number is not 1, or the declared length does not
   *                                  match the number of payload bytes
   */
  def deserializeDoubleVector(bytes: Array[Byte]): Array[Double] = {
    val packetLength = bytes.length
    if (packetLength < HeaderBytes) {
      throw new IllegalArgumentException("Byte array too short.")
    }
    val bb = ByteBuffer.wrap(bytes)
    bb.order(ByteOrder.nativeOrder())
    val magic = bb.getLong()
    if (magic != VectorMagic) {
      throw new IllegalArgumentException(s"Magic $magic is wrong.")
    }
    val length = bb.getLong()
    // Validate by dividing the payload size rather than multiplying the declared
    // length: `16 + 8 * length` can wrap around for huge declared lengths and
    // accidentally equal the packet size.
    val payloadBytes = packetLength - HeaderBytes
    if (payloadBytes % BytesPerDouble != 0 || length != payloadBytes / BytesPerDouble) {
      throw new IllegalArgumentException(s"Length $length is wrong.")
    }
    // Safe narrowing: length equals payloadBytes / 8, which fits in an Int.
    val ans = new Array[Double](length.toInt)
    bb.asDoubleBuffer().get(ans)
    ans
  }

  /**
   * Encodes an array of doubles in the double-vector wire format.
   *
   * @param doubles values to serialize
   * @return a new byte array holding the header followed by the doubles
   * @throws IllegalArgumentException if the encoded form would exceed the maximum
   *                                  array size
   */
  def serializeDoubleVector(doubles: Array[Double]): Array[Byte] = {
    val len = doubles.length
    // Compute the size in Long so that very large inputs are rejected instead of
    // silently overflowing Int.
    val totalBytes = HeaderBytes.toLong + BytesPerDouble.toLong * len
    require(totalBytes <= Int.MaxValue, s"Vector of $len doubles is too large to serialize.")
    val bytes = new Array[Byte](totalBytes.toInt)
    val bb = ByteBuffer.wrap(bytes)
    bb.order(ByteOrder.nativeOrder())
    bb.putLong(VectorMagic)
    bb.putLong(len.toLong)
    bb.asDoubleBuffer().put(doubles)
    bytes
  }

  /**
   * Trains a linear regression model with `LinearRegressionWithSGD`.
   *
   * Each element of the input is a serialized double vector whose first entry is
   * the label and whose remaining entries are the features.
   *
   * @param dataBytesJRDD RDD of serialized double vectors (label followed by features)
   * @return a two-element list: the serialized weight vector, then the intercept
   *         as a `java.lang.Double`
   */
  def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]]):
      java.util.List[java.lang.Object] = {
    val data = dataBytesJRDD.rdd
      .map(x => deserializeDoubleVector(x))
      .map { v =>
        // An empty vector has no label; fail with a clear message rather than an
        // index-out-of-bounds error.
        require(v.nonEmpty, "Training vector must contain at least a label.")
        LabeledPoint(v(0), v.slice(1, v.length))
      }
    val model = LinearRegressionWithSGD.train(data, NumIterations)
    val ret = new java.util.LinkedList[java.lang.Object]()
    ret.add(serializeDoubleVector(model.weights))
    ret.add(model.intercept: java.lang.Double)
    ret
  }
}