aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author Tor Myklebust <tmyklebu@gmail.com> 2013-12-19 22:42:12 -0500
committer Tor Myklebust <tmyklebu@gmail.com> 2013-12-19 22:42:12 -0500
commit ded67ee90c2c0b22d67e623156a3f6cce8573abd (patch)
tree bb6d131c4cfdad3d115c014ca529cbdb1afbe286 /mllib
parent 2a41c9aad3d0a8477a11bf910fa57b49ea4dc6dc (diff)
download spark-ded67ee90c2c0b22d67e623156a3f6cce8573abd.tar.gz
spark-ded67ee90c2c0b22d67e623156a3f6cce8573abd.tar.bz2
spark-ded67ee90c2c0b22d67e623156a3f6cce8573abd.zip
Bindings for linear, Lasso, and ridge regression.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala 42
1 files changed, 37 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
index 3daf5dcb39..c9bd7c6415 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
@@ -1,5 +1,6 @@
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.regression._
+import org.apache.spark.rdd.RDD
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.DoubleBuffer
@@ -38,14 +39,45 @@ class PythonMLLibAPI extends Serializable {
return bytes
}
- def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]]):
- java.util.List[java.lang.Object] = {
- val data = dataBytesJRDD.rdd.map(x => deserializeDoubleVector(x))
- .map(v => LabeledPoint(v(0), v.slice(1, v.length)))
- val model = LinearRegressionWithSGD.train(data, 222)
+ def trainRegressionModel(trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel,
+ dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]):
+ java.util.LinkedList[java.lang.Object] = {
+ val data = dataBytesJRDD.rdd.map(xBytes => {
+ val x = deserializeDoubleVector(xBytes)
+ LabeledPoint(x(0), x.slice(1, x.length))
+ })
+ val initialWeights = deserializeDoubleVector(initialWeightsBA)
+ val model = trainFunc(data, initialWeights)
val ret = new java.util.LinkedList[java.lang.Object]()
ret.add(serializeDoubleVector(model.weights))
ret.add(model.intercept: java.lang.Double)
return ret
}
+
+ def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]],
+ numIterations: Int, stepSize: Double, miniBatchFraction: Double,
+ initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ return trainRegressionModel((data, initialWeights) =>
+ LinearRegressionWithSGD.train(data, numIterations, stepSize,
+ miniBatchFraction, initialWeights),
+ dataBytesJRDD, initialWeightsBA);
+ }
+
+ def trainLassoModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
+ stepSize: Double, regParam: Double, miniBatchFraction: Double,
+ initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ return trainRegressionModel((data, initialWeights) =>
+ LassoWithSGD.train(data, numIterations, stepSize, regParam,
+ miniBatchFraction, initialWeights),
+ dataBytesJRDD, initialWeightsBA);
+ }
+
+ def trainRidgeModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
+ stepSize: Double, regParam: Double, miniBatchFraction: Double,
+ initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ return trainRegressionModel((data, initialWeights) =>
+ RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam,
+ miniBatchFraction, initialWeights),
+ dataBytesJRDD, initialWeightsBA);
+ }
}