Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala  232
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala           13
2 files changed, 238 insertions(+), 7 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
new file mode 100644
index 0000000000..8247c1ebc5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.api.python
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.classification._
+import org.apache.spark.mllib.clustering._
+import org.apache.spark.mllib.recommendation._
+import org.apache.spark.rdd.RDD
+import java.nio.ByteBuffer
+import java.nio.ByteOrder
+import java.nio.DoubleBuffer
+
+/**
+ * The Java stubs necessary for the Python mllib bindings.
+ */
+class PythonMLLibAPI extends Serializable {
+  private def deserializeDoubleVector(bytes: Array[Byte]): Array[Double] = {
+    val packetLength = bytes.length
+    if (packetLength < 16) {
+      throw new IllegalArgumentException("Byte array too short.")
+    }
+    val bb = ByteBuffer.wrap(bytes)
+    bb.order(ByteOrder.nativeOrder())
+    val magic = bb.getLong()
+    if (magic != 1) {
+      throw new IllegalArgumentException("Magic " + magic + " is wrong.")
+    }
+    val length = bb.getLong()
+    if (packetLength != 16 + 8 * length) {
+      throw new IllegalArgumentException("Length " + length + " is wrong.")
+    }
+    val db = bb.asDoubleBuffer()
+    val ans = new Array[Double](length.toInt)
+    db.get(ans)
+    return ans
+  }
+
+  private def serializeDoubleVector(doubles: Array[Double]): Array[Byte] = {
+    val len = doubles.length
+    val bytes = new Array[Byte](16 + 8 * len)
+    val bb = ByteBuffer.wrap(bytes)
+    bb.order(ByteOrder.nativeOrder())
+    bb.putLong(1)
+    bb.putLong(len)
+    val db = bb.asDoubleBuffer()
+    db.put(doubles)
+    return bytes
+  }
+
+  private def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = {
+    val packetLength = bytes.length
+    if (packetLength < 24) {
+      throw new IllegalArgumentException("Byte array too short.")
+    }
+    val bb = ByteBuffer.wrap(bytes)
+    bb.order(ByteOrder.nativeOrder())
+    val magic = bb.getLong()
+    if (magic != 2) {
+      throw new IllegalArgumentException("Magic " + magic + " is wrong.")
+    }
+    val rows = bb.getLong()
+    val cols = bb.getLong()
+    if (packetLength != 24 + 8 * rows * cols) {
+      throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.")
+    }
+    val db = bb.asDoubleBuffer()
+    val ans = new Array[Array[Double]](rows.toInt)
+    var i = 0
+    for (i <- 0 until rows.toInt) {
+      ans(i) = new Array[Double](cols.toInt)
+      db.get(ans(i))
+    }
+    return ans
+  }
+
+  private def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = {
+    val rows = doubles.length
+    var cols = 0
+    if (rows > 0) {
+      cols = doubles(0).length
+    }
+    val bytes = new Array[Byte](24 + 8 * rows * cols)
+    val bb = ByteBuffer.wrap(bytes)
+    bb.order(ByteOrder.nativeOrder())
+    bb.putLong(2)
+    bb.putLong(rows)
+    bb.putLong(cols)
+    val db = bb.asDoubleBuffer()
+    var i = 0
+    for (i <- 0 until rows) {
+      db.put(doubles(i))
+    }
+    return bytes
+  }
+
+  private def trainRegressionModel(trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel,
+      dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]):
+      java.util.LinkedList[java.lang.Object] = {
+    val data = dataBytesJRDD.rdd.map(xBytes => {
+      val x = deserializeDoubleVector(xBytes)
+      LabeledPoint(x(0), x.slice(1, x.length))
+    })
+    val initialWeights = deserializeDoubleVector(initialWeightsBA)
+    val model = trainFunc(data, initialWeights)
+    val ret = new java.util.LinkedList[java.lang.Object]()
+    ret.add(serializeDoubleVector(model.weights))
+    ret.add(model.intercept: java.lang.Double)
+    return ret
+  }
+
+  /**
+   * Java stub for Python mllib LinearRegressionWithSGD.train()
+   */
+  def trainLinearRegressionModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]],
+      numIterations: Int, stepSize: Double, miniBatchFraction: Double,
+      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+    return trainRegressionModel((data, initialWeights) =>
+        LinearRegressionWithSGD.train(data, numIterations, stepSize,
+            miniBatchFraction, initialWeights),
+        dataBytesJRDD, initialWeightsBA)
+  }
+
+  /**
+   * Java stub for Python mllib LassoWithSGD.train()
+   */
+  def trainLassoModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
+      stepSize: Double, regParam: Double, miniBatchFraction: Double,
+      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+    return trainRegressionModel((data, initialWeights) =>
+        LassoWithSGD.train(data, numIterations, stepSize, regParam,
+            miniBatchFraction, initialWeights),
+        dataBytesJRDD, initialWeightsBA)
+  }
+
+  /**
+   * Java stub for Python mllib RidgeRegressionWithSGD.train()
+   */
+  def trainRidgeModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
+      stepSize: Double, regParam: Double, miniBatchFraction: Double,
+      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+    return trainRegressionModel((data, initialWeights) =>
+        RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam,
+            miniBatchFraction, initialWeights),
+        dataBytesJRDD, initialWeightsBA)
+  }
+
+  /**
+   * Java stub for Python mllib SVMWithSGD.train()
+   */
+  def trainSVMModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
+      stepSize: Double, regParam: Double, miniBatchFraction: Double,
+      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+    return trainRegressionModel((data, initialWeights) =>
+        SVMWithSGD.train(data, numIterations, stepSize, regParam,
+            miniBatchFraction, initialWeights),
+        dataBytesJRDD, initialWeightsBA)
+  }
+
+  /**
+   * Java stub for Python mllib LogisticRegressionWithSGD.train()
+   */
+  def trainLogisticRegressionModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]],
+      numIterations: Int, stepSize: Double, miniBatchFraction: Double,
+      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+    return trainRegressionModel((data, initialWeights) =>
+        LogisticRegressionWithSGD.train(data, numIterations, stepSize,
+            miniBatchFraction, initialWeights),
+        dataBytesJRDD, initialWeightsBA)
+  }
+
+  /**
+   * Java stub for Python mllib KMeans.train()
+   */
+  def trainKMeansModel(dataBytesJRDD: JavaRDD[Array[Byte]], k: Int,
+      maxIterations: Int, runs: Int, initializationMode: String):
+      java.util.List[java.lang.Object] = {
+    val data = dataBytesJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes))
+    val model = KMeans.train(data, k, maxIterations, runs, initializationMode)
+    val ret = new java.util.LinkedList[java.lang.Object]()
+    ret.add(serializeDoubleMatrix(model.clusterCenters))
+    return ret
+  }
+
+  private def unpackRating(ratingBytes: Array[Byte]): Rating = {
+    val bb = ByteBuffer.wrap(ratingBytes)
+    bb.order(ByteOrder.nativeOrder())
+    val user = bb.getInt()
+    val product = bb.getInt()
+    val rating = bb.getDouble()
+    return new Rating(user, product, rating)
+  }
+
+  /**
+   * Java stub for Python mllib ALS.train(). This stub returns a handle
+   * to the Java object instead of the content of the Java object. Extra care
+   * needs to be taken in the Python code to ensure it gets freed on exit; see
+   * the Py4J documentation.
+   */
+  def trainALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int,
+      iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = {
+    val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
+    return ALS.train(ratings, rank, iterations, lambda, blocks)
+  }
+
+  /**
+   * Java stub for Python mllib ALS.trainImplicit(). This stub returns a
+   * handle to the Java object instead of the content of the Java object.
+   * Extra care needs to be taken in the Python code to ensure it gets freed on
+   * exit; see the Py4J documentation.
+   */
+  def trainImplicitALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int,
+      iterations: Int, lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = {
+    val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
+    return ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha)
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 36853acab5..8b27ecf82c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -578,14 +578,13 @@ object ALS {
     val implicitPrefs = if (args.length >= 7) args(6).toBoolean else false
     val alpha = if (args.length >= 8) args(7).toDouble else 1
     val blocks = if (args.length == 9) args(8).toInt else -1
-
-    System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-    System.setProperty("spark.kryo.registrator", classOf[ALSRegistrator].getName)
-    System.setProperty("spark.kryo.referenceTracking", "false")
-    System.setProperty("spark.kryoserializer.buffer.mb", "8")
-    System.setProperty("spark.locality.wait", "10000")
-
     val sc = new SparkContext(master, "ALS")
+    sc.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+    sc.conf.set("spark.kryo.registrator", classOf[ALSRegistrator].getName)
+    sc.conf.set("spark.kryo.referenceTracking", "false")
+    sc.conf.set("spark.kryoserializer.buffer.mb", "8")
+    sc.conf.set("spark.locality.wait", "10000")
+
     val ratings = sc.textFile(ratingsFile).map { line =>
       val fields = line.split(',')
       Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble)
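
A note on the wire format used by PythonMLLibAPI above: rather than Java serialization, the stubs exchange packets of native-byte-order longs and doubles. A dense double vector is the magic number 1, the element count, then the raw doubles (see deserializeDoubleVector/serializeDoubleVector). The Python side of the bindings is not part of this diff; the following is only a sketch of a matching encoder/decoder, assuming nothing beyond the standard struct module ('=' selects native byte order with no padding, mirroring ByteOrder.nativeOrder() on the JVM side):

    import struct

    def serialize_double_vector(v):
        # Packet layout: long magic (1), long length, then `length` doubles.
        return struct.pack('=qq%dd' % len(v), 1, len(v), *v)

    def deserialize_double_vector(ba):
        magic, length = struct.unpack_from('=qq', ba)
        if magic != 1 or len(ba) != 16 + 8 * length:
            raise ValueError('malformed double-vector packet')
        return list(struct.unpack_from('=%dd' % length, ba, 16))

Because both ends use native byte order, this only round-trips between same-endian processes, which holds for a Python driver and a JVM on the same machine.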
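
The regression stubs expect each training row to be one such vector with the label in position 0 and the features after it (see LabeledPoint(x(0), x.slice(1, x.length)) in trainRegressionModel). A producer sketch, reusing the hypothetical serialize_double_vector above:

    def serialize_labeled_point(label, features):
        # trainRegressionModel splits [label, f1, ..., fn] back apart on the JVM side.
        return serialize_double_vector([label] + list(features))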
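
trainKMeansModel sends its cluster centers back as a packed matrix: magic number 2, the row and column counts, then the rows' doubles back to back (serializeDoubleMatrix). A matching decoder sketch, under the same struct-module assumption:

    def deserialize_double_matrix(ba):
        magic, rows, cols = struct.unpack_from('=qqq', ba)
        if magic != 2 or len(ba) != 24 + 8 * rows * cols:
            raise ValueError('malformed double-matrix packet')
        flat = struct.unpack_from('=%dd' % (rows * cols), ba, 24)
        return [list(flat[r * cols:(r + 1) * cols]) for r in range(rows)]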
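
Ratings travel the other way as fixed 16-byte records: unpackRating reads two 4-byte ints (user, product) followed by an 8-byte double (rating). The Python-side packer would be a single struct call; again a sketch, not code from this change:

    def pack_rating(user, product, rating):
        # '=iid': int user, int product, double rating; native order, no padding.
        return struct.pack('=iid', user, product, rating)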
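
Finally, per the doc comments on trainALSModel and trainImplicitALSModel, the ALS stubs hand back a live JVM object rather than serialized content. A hypothetical call path from the PySpark driver over Py4J (none of this Python code is in the diff; _jvm, _jrdd, and _bypass_serializer are early-PySpark internals, and the numeric arguments are made-up examples):

    # `sc` is a pyspark.SparkContext; `packed` is an RDD of byte arrays,
    # one pack_rating() result per element.
    packed._bypass_serializer = True  # ship the raw bytes to the JVM as-is
    api = sc._jvm.org.apache.spark.mllib.api.python.PythonMLLibAPI()
    model = api.trainALSModel(packed._jrdd, 10, 5, 0.01, -1)  # rank, iters, lambda, blocks
    # `model` is a Py4J handle to a MatrixFactorizationModel living on the JVM;
    # as the doc comment above warns, it must be freed on exit.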