diff options
author | Tor Myklebust <tmyklebu@gmail.com> | 2013-12-20 00:12:22 -0500 |
---|---|---|
committer | Tor Myklebust <tmyklebu@gmail.com> | 2013-12-20 00:12:22 -0500 |
commit | f99970e8cdc85eae33999b57a4c5c1893fe3727a (patch) | |
tree | 18106473cdfa36ccc058b050cd160446724ab47f | |
parent | 2328bdd00f701ca3b1bc7fdf8b2968fafc58fd11 (diff) | |
download | spark-f99970e8cdc85eae33999b57a4c5c1893fe3727a.tar.gz spark-f99970e8cdc85eae33999b57a4c5c1893fe3727a.tar.bz2 spark-f99970e8cdc85eae33999b57a4c5c1893fe3727a.zip |
Scala classification and clustering stubs; matrix serialization/deserialization.
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala | 82 |
1 file changed, 79 insertions, 3 deletions
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
import org.apache.spark.rdd.RDD
import java.nio.ByteBuffer
import java.nio.ByteOrder

  /**
   * Deserializes a dense double matrix from the Python-side wire format.
   *
   * Expected layout (all fields in native byte order, per the `nativeOrder()`
   * call below): an 8-byte magic number (2), an 8-byte row count, an 8-byte
   * column count, followed by rows*cols 8-byte doubles in row-major order.
   *
   * @param bytes the serialized matrix packet
   * @return a rows x cols array of doubles
   * @throws IllegalArgumentException if the packet is truncated, the magic
   *         number is not 2, or the payload size does not match rows*cols
   */
  def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = {
    val packetLength = bytes.length
    if (packetLength < 24) {
      throw new IllegalArgumentException("Byte array too short.")
    }
    val bb = ByteBuffer.wrap(bytes)
    bb.order(ByteOrder.nativeOrder())
    val magic = bb.getLong()
    if (magic != 2) {
      throw new IllegalArgumentException("Magic " + magic + " is wrong.")
    }
    val rows = bb.getLong()
    val cols = bb.getLong()
    // Validate the payload size before reading; rows/cols are Longs, so the
    // size arithmetic cannot overflow Int here.
    // Fixed: original message was missing the space before "is wrong."
    if (packetLength != 24 + 8 * rows * cols) {
      throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.")
    }
    val db = bb.asDoubleBuffer()
    val ans = new Array[Array[Double]](rows.toInt)
    // Bulk-copy one row at a time out of the underlying DoubleBuffer.
    // (Removed a dead `var i = 0` that was shadowed by the for-generator.)
    for (i <- 0 until rows.toInt) {
      ans(i) = new Array[Double](cols.toInt)
      db.get(ans(i))
    }
    ans
  }

  /**
   * Serializes a dense double matrix into the Python-side wire format
   * described on [[deserializeDoubleMatrix]]: magic (2), rows, cols, then the
   * row-major doubles, all in native byte order.
   *
   * An empty matrix serializes as a 24-byte header with rows = cols = 0.
   * All rows are assumed to have the same length as row 0 — TODO confirm
   * callers guarantee rectangular input.
   *
   * @param doubles the matrix to serialize
   * @return the serialized packet
   */
  def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = {
    val rows = doubles.length
    val cols = if (rows > 0) doubles(0).length else 0
    val bytes = new Array[Byte](24 + 8 * rows * cols)
    val bb = ByteBuffer.wrap(bytes)
    bb.order(ByteOrder.nativeOrder())
    bb.putLong(2)
    bb.putLong(rows)
    bb.putLong(cols)
    val db = bb.asDoubleBuffer()
    // Bulk-copy each row into the buffer.
    // (Removed a dead `var i = 0`; replaced the index loop with a direct
    // iteration over the rows.)
    for (row <- doubles) {
      db.put(row)
    }
    bytes
  }
Array[Double]) => GeneralizedLinearModel, dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = { @@ -60,7 +108,7 @@ class PythonMLLibAPI extends Serializable { return trainRegressionModel((data, initialWeights) => LinearRegressionWithSGD.train(data, numIterations, stepSize, miniBatchFraction, initialWeights), - dataBytesJRDD, initialWeightsBA); + dataBytesJRDD, initialWeightsBA) } def trainLassoModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, @@ -69,7 +117,7 @@ class PythonMLLibAPI extends Serializable { return trainRegressionModel((data, initialWeights) => LassoWithSGD.train(data, numIterations, stepSize, regParam, miniBatchFraction, initialWeights), - dataBytesJRDD, initialWeightsBA); + dataBytesJRDD, initialWeightsBA) } def trainRidgeModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, @@ -78,6 +126,34 @@ class PythonMLLibAPI extends Serializable { return trainRegressionModel((data, initialWeights) => RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam, miniBatchFraction, initialWeights), - dataBytesJRDD, initialWeightsBA); + dataBytesJRDD, initialWeightsBA) + } + + def trainSVMModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, + stepSize: Double, regParam: Double, miniBatchFraction: Double, + initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + return trainRegressionModel((data, initialWeights) => + SVMWithSGD.train(data, numIterations, stepSize, regParam, + miniBatchFraction, initialWeights), + dataBytesJRDD, initialWeightsBA) + } + + def trainLogisticRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]], + numIterations: Int, stepSize: Double, miniBatchFraction: Double, + initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + return trainRegressionModel((data, initialWeights) => + LogisticRegressionWithSGD.train(data, numIterations, stepSize, + miniBatchFraction, initialWeights), + dataBytesJRDD, initialWeightsBA) + } + 
+ def trainKMeansModel(dataBytesJRDD: JavaRDD[Array[Byte]], k: Int, + maxIterations: Int, runs: Int, initializationMode: String): + java.util.List[java.lang.Object] = { + val data = dataBytesJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes)) + val model = KMeans.train(data, k, maxIterations, runs, initializationMode) + val ret = new java.util.LinkedList[java.lang.Object]() + ret.add(serializeDoubleMatrix(model.clusterCenters)) + return ret } } |