diff options
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala | 42 |
1 files changed, 37 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala index 3daf5dcb39..c9bd7c6415 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala @@ -1,5 +1,6 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression._ +import org.apache.spark.rdd.RDD import java.nio.ByteBuffer import java.nio.ByteOrder import java.nio.DoubleBuffer @@ -38,14 +39,45 @@ class PythonMLLibAPI extends Serializable { return bytes } - def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]]): - java.util.List[java.lang.Object] = { - val data = dataBytesJRDD.rdd.map(x => deserializeDoubleVector(x)) - .map(v => LabeledPoint(v(0), v.slice(1, v.length))) - val model = LinearRegressionWithSGD.train(data, 222) + def trainRegressionModel(trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel, + dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]): + java.util.LinkedList[java.lang.Object] = { + val data = dataBytesJRDD.rdd.map(xBytes => { + val x = deserializeDoubleVector(xBytes) + LabeledPoint(x(0), x.slice(1, x.length)) + }) + val initialWeights = deserializeDoubleVector(initialWeightsBA) + val model = trainFunc(data, initialWeights) val ret = new java.util.LinkedList[java.lang.Object]() ret.add(serializeDoubleVector(model.weights)) ret.add(model.intercept: java.lang.Double) return ret } + + def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]], + numIterations: Int, stepSize: Double, miniBatchFraction: Double, + initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + return trainRegressionModel((data, initialWeights) => + LinearRegressionWithSGD.train(data, numIterations, stepSize, + miniBatchFraction, initialWeights), + dataBytesJRDD, initialWeightsBA); + } + + def trainLassoModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, + stepSize: Double, regParam: Double, miniBatchFraction: Double, + initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + return trainRegressionModel((data, initialWeights) => + LassoWithSGD.train(data, numIterations, stepSize, regParam, + miniBatchFraction, initialWeights), + dataBytesJRDD, initialWeightsBA); + } + + def trainRidgeModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, + stepSize: Double, regParam: Double, miniBatchFraction: Double, + initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + return trainRegressionModel((data, initialWeights) => + RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam, + miniBatchFraction, initialWeights), + dataBytesJRDD, initialWeightsBA); + } } |