diff options
author | Tor Myklebust <tmyklebu@gmail.com> | 2013-12-19 22:42:12 -0500 |
---|---|---|
committer | Tor Myklebust <tmyklebu@gmail.com> | 2013-12-19 22:42:12 -0500 |
commit | ded67ee90c2c0b22d67e623156a3f6cce8573abd (patch) | |
tree | bb6d131c4cfdad3d115c014ca529cbdb1afbe286 /mllib | |
parent | 2a41c9aad3d0a8477a11bf910fa57b49ea4dc6dc (diff) | |
download | spark-ded67ee90c2c0b22d67e623156a3f6cce8573abd.tar.gz spark-ded67ee90c2c0b22d67e623156a3f6cce8573abd.tar.bz2 spark-ded67ee90c2c0b22d67e623156a3f6cce8573abd.zip |
Bindings for linear, Lasso, and ridge regression.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala | 42 |
1 files changed, 37 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
index 3daf5dcb39..c9bd7c6415 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
@@ -1,5 +1,6 @@
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.regression._
+import org.apache.spark.rdd.RDD
 import java.nio.ByteBuffer
 import java.nio.ByteOrder
 import java.nio.DoubleBuffer
@@ -38,14 +39,45 @@ class PythonMLLibAPI extends Serializable {
     return bytes
   }
 
-  def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]]):
-      java.util.List[java.lang.Object] = {
-    val data = dataBytesJRDD.rdd.map(x => deserializeDoubleVector(x))
-        .map(v => LabeledPoint(v(0), v.slice(1, v.length)))
-    val model = LinearRegressionWithSGD.train(data, 222)
+  def trainRegressionModel(trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel,
+      dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]):
+      java.util.LinkedList[java.lang.Object] = {
+    val data = dataBytesJRDD.rdd.map(xBytes => {
+      val x = deserializeDoubleVector(xBytes)
+      LabeledPoint(x(0), x.slice(1, x.length))
+    })
+    val initialWeights = deserializeDoubleVector(initialWeightsBA)
+    val model = trainFunc(data, initialWeights)
     val ret = new java.util.LinkedList[java.lang.Object]()
     ret.add(serializeDoubleVector(model.weights))
     ret.add(model.intercept: java.lang.Double)
     return ret
   }
+
+  def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]],
+      numIterations: Int, stepSize: Double, miniBatchFraction: Double,
+      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+    return trainRegressionModel((data, initialWeights) =>
+        LinearRegressionWithSGD.train(data, numIterations, stepSize,
+            miniBatchFraction, initialWeights),
+        dataBytesJRDD, initialWeightsBA);
+  }
+
+  def trainLassoModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
+      stepSize: Double, regParam: Double, miniBatchFraction: Double,
+      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+    return trainRegressionModel((data, initialWeights) =>
+        LassoWithSGD.train(data, numIterations, stepSize, regParam,
+            miniBatchFraction, initialWeights),
+        dataBytesJRDD, initialWeightsBA);
+  }
+
+  def trainRidgeModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
+      stepSize: Double, regParam: Double, miniBatchFraction: Double,
+      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+    return trainRegressionModel((data, initialWeights) =>
+        RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam,
+            miniBatchFraction, initialWeights),
+        dataBytesJRDD, initialWeightsBA);
+  }
 }