author    Tor Myklebust <tmyklebu@gmail.com>    2013-12-20 00:12:22 -0500
committer Tor Myklebust <tmyklebu@gmail.com>    2013-12-20 00:12:22 -0500
commit    f99970e8cdc85eae33999b57a4c5c1893fe3727a (patch)
tree      18106473cdfa36ccc058b050cd160446724ab47f /mllib
parent    2328bdd00f701ca3b1bc7fdf8b2968fafc58fd11 (diff)
Scala classification and clustering stubs; matrix serialization/deserialization.
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala | 82
1 file changed, 79 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
index c9bd7c6415..bcf2f07517 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
@@ -1,5 +1,7 @@
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.classification._
+import org.apache.spark.mllib.clustering._
import org.apache.spark.rdd.RDD
import java.nio.ByteBuffer
import java.nio.ByteOrder
@@ -39,6 +41,52 @@ class PythonMLLibAPI extends Serializable {
    return bytes
  }
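+  // Matrix packets: a magic number (2 for dense matrices), then the row and column
+  // counts, all as native-order int64s, followed by the entries as native-order doubles.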
+  def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = {
+    val packetLength = bytes.length
+    if (packetLength < 24) {
+      throw new IllegalArgumentException("Byte array too short.")
+    }
+    val bb = ByteBuffer.wrap(bytes)
+    bb.order(ByteOrder.nativeOrder())
+    val magic = bb.getLong()
+    if (magic != 2) {
+      throw new IllegalArgumentException("Magic " + magic + " is wrong.")
+    }
+    val rows = bb.getLong()
+    val cols = bb.getLong()
+    if (packetLength != 24 + 8 * rows * cols) {
+      throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.")
+    }
+    val db = bb.asDoubleBuffer()
+    val ans = new Array[Array[Double]](rows.toInt)
+    for (i <- 0 until rows.toInt) {
+      ans(i) = new Array[Double](cols.toInt)
+      db.get(ans(i))
+    }
+    return ans
+  }
+
+  def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = {
+    val rows = doubles.length
+    var cols = 0
+    if (rows > 0) {
+      cols = doubles(0).length
+    }
+    val bytes = new Array[Byte](24 + 8 * rows * cols)
+    val bb = ByteBuffer.wrap(bytes)
+    bb.order(ByteOrder.nativeOrder())
+    bb.putLong(2)
+    bb.putLong(rows)
+    bb.putLong(cols)
+    val db = bb.asDoubleBuffer()
+    for (i <- 0 until rows) {
+      db.put(doubles(i))
+    }
+    return bytes
+  }
+
  def trainRegressionModel(trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel,
      dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]):
      java.util.LinkedList[java.lang.Object] = {
@@ -60,7 +108,7 @@ class PythonMLLibAPI extends Serializable {
    return trainRegressionModel((data, initialWeights) =>
        LinearRegressionWithSGD.train(data, numIterations, stepSize,
            miniBatchFraction, initialWeights),
-        dataBytesJRDD, initialWeightsBA);
+        dataBytesJRDD, initialWeightsBA)
  }
  def trainLassoModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
@@ -69,7 +117,7 @@ class PythonMLLibAPI extends Serializable {
    return trainRegressionModel((data, initialWeights) =>
        LassoWithSGD.train(data, numIterations, stepSize, regParam,
            miniBatchFraction, initialWeights),
-        dataBytesJRDD, initialWeightsBA);
+        dataBytesJRDD, initialWeightsBA)
  }
  def trainRidgeModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
@@ -78,6 +126,34 @@ class PythonMLLibAPI extends Serializable {
    return trainRegressionModel((data, initialWeights) =>
        RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam,
            miniBatchFraction, initialWeights),
-        dataBytesJRDD, initialWeightsBA);
+        dataBytesJRDD, initialWeightsBA)
+  }
+
+  def trainSVMModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
+      stepSize: Double, regParam: Double, miniBatchFraction: Double,
+      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+    return trainRegressionModel((data, initialWeights) =>
+        SVMWithSGD.train(data, numIterations, stepSize, regParam,
+            miniBatchFraction, initialWeights),
+        dataBytesJRDD, initialWeightsBA)
+  }
+
+  def trainLogisticRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]],
+      numIterations: Int, stepSize: Double, miniBatchFraction: Double,
+      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+    return trainRegressionModel((data, initialWeights) =>
+        LogisticRegressionWithSGD.train(data, numIterations, stepSize,
+            miniBatchFraction, initialWeights),
+        dataBytesJRDD, initialWeightsBA)
+  }
+
+  def trainKMeansModel(dataBytesJRDD: JavaRDD[Array[Byte]], k: Int,
+      maxIterations: Int, runs: Int, initializationMode: String):
+      java.util.List[java.lang.Object] = {
+    val data = dataBytesJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes))
+    val model = KMeans.train(data, k, maxIterations, runs, initializationMode)
+    val ret = new java.util.LinkedList[java.lang.Object]()
+    ret.add(serializeDoubleMatrix(model.clusterCenters))
+    return ret
  }
}
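
The two matrix helpers in this patch fix the wire format the Python side must emit: a 24-byte header holding a magic number (2), the row count, and the column count as native-order int64s, followed by the entries as native-order doubles, one row at a time. A minimal round-trip sketch against the methods added here (the demo object and sample matrix are illustrative, not part of the commit, and assume PythonMLLibAPI is visible in scope):

import java.nio.{ByteBuffer, ByteOrder}

object MatrixPacketDemo {
  def main(args: Array[String]): Unit = {
    val api = new PythonMLLibAPI()
    val m = Array(Array(1.0, 2.0), Array(3.0, 4.0))

    // Serialize, then inspect the 24-byte header: magic = 2, rows = 2, cols = 2.
    val bytes = api.serializeDoubleMatrix(m)
    val bb = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder())
    assert(bb.getLong() == 2L && bb.getLong() == 2L && bb.getLong() == 2L)
    assert(bytes.length == 24 + 8 * 2 * 2)

    // Deserializing recovers the original row-major entries.
    val back = api.deserializeDoubleMatrix(bytes)
    assert(back.length == m.length)
    assert(back.zip(m).forall { case (a, b) => a.sameElements(b) })
  }
}

Using native byte order on both ends avoids byte swapping, at the cost of assuming the Python workers and the JVM run on machines with the same endianness.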