diff options
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 84 |
1 files changed, 45 insertions, 39 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 485abe2723..acdc67ddc6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.api.python import java.io.OutputStream -import java.util.{ArrayList => JArrayList} +import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConverters._ import scala.language.existentials @@ -72,15 +72,11 @@ class PythonMLLibAPI extends Serializable { private def trainRegressionModel( learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel], data: JavaRDD[LabeledPoint], - initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = { - val initialWeights = SerDe.loads(initialWeightsBA).asInstanceOf[Vector] + initialWeights: Vector): JList[Object] = { // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD. learner.disableUncachedWarning() val model = learner.run(data.rdd, initialWeights) - val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(SerDe.dumps(model.weights)) - ret.add(model.intercept: java.lang.Double) - ret + List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava } /** @@ -91,10 +87,10 @@ class PythonMLLibAPI extends Serializable { numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte], + initialWeights: Vector, regParam: Double, regType: String, - intercept: Boolean): java.util.List[java.lang.Object] = { + intercept: Boolean): JList[Object] = { val lrAlg = new LinearRegressionWithSGD() lrAlg.setIntercept(intercept) lrAlg.optimizer @@ -113,7 +109,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( lrAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -125,7 +121,7 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + initialWeights: Vector): JList[Object] = { val lassoAlg = new LassoWithSGD() lassoAlg.optimizer .setNumIterations(numIterations) @@ -135,7 +131,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( lassoAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -147,7 +143,7 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + initialWeights: Vector): JList[Object] = { val ridgeAlg = new RidgeRegressionWithSGD() ridgeAlg.optimizer .setNumIterations(numIterations) @@ -157,7 +153,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( ridgeAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -169,9 +165,9 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte], + initialWeights: Vector, regType: String, - intercept: Boolean): java.util.List[java.lang.Object] = { + intercept: Boolean): JList[Object] = { val SVMAlg = new SVMWithSGD() SVMAlg.setIntercept(intercept) SVMAlg.optimizer @@ -190,7 +186,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( SVMAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -201,10 +197,10 @@ class PythonMLLibAPI extends Serializable { numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte], + initialWeights: Vector, regParam: Double, regType: String, - intercept: Boolean): java.util.List[java.lang.Object] = { + intercept: Boolean): JList[Object] = { val LogRegAlg = new LogisticRegressionWithSGD() LogRegAlg.setIntercept(intercept) LogRegAlg.optimizer @@ -223,7 +219,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( LogRegAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -231,13 +227,10 @@ class PythonMLLibAPI extends Serializable { */ def trainNaiveBayes( data: JavaRDD[LabeledPoint], - lambda: Double): java.util.List[java.lang.Object] = { + lambda: Double): JList[Object] = { val model = NaiveBayes.train(data.rdd, lambda) - val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(Vectors.dense(model.labels)) - ret.add(Vectors.dense(model.pi)) - ret.add(model.theta) - ret + List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta). + map(_.asInstanceOf[Object]).asJava } /** @@ -260,6 +253,21 @@ class PythonMLLibAPI extends Serializable { } /** + * A Wrapper of MatrixFactorizationModel to provide helpfer method for Python + */ + private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel) + extends MatrixFactorizationModel(model.rank, model.userFeatures, model.productFeatures) { + + def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] = + predict(SerDe.asTupleRDD(userAndProducts.rdd)) + + def getUserFeatures = SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]]) + + def getProductFeatures = SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]]) + + } + + /** * Java stub for Python mllib ALS.train(). This stub returns a handle * to the Java object instead of the content of the Java object. Extra care * needs to be taken in the Python code to ensure it gets freed on exit; see @@ -271,7 +279,7 @@ class PythonMLLibAPI extends Serializable { iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = { - ALS.train(ratings.rdd, rank, iterations, lambda, blocks) + new MatrixFactorizationModelWrapper(ALS.train(ratings.rdd, rank, iterations, lambda, blocks)) } /** @@ -287,7 +295,8 @@ class PythonMLLibAPI extends Serializable { lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = { - ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha) + new MatrixFactorizationModelWrapper( + ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha)) } /** @@ -373,19 +382,16 @@ class PythonMLLibAPI extends Serializable { rdd.rdd.map(model.transform) } - def findSynonyms(word: String, num: Int): java.util.List[java.lang.Object] = { + def findSynonyms(word: String, num: Int): JList[Object] = { val vec = transform(word) findSynonyms(vec, num) } - def findSynonyms(vector: Vector, num: Int): java.util.List[java.lang.Object] = { + def findSynonyms(vector: Vector, num: Int): JList[Object] = { val result = model.findSynonyms(vector, num) val similarity = Vectors.dense(result.map(_._2)) val words = result.map(_._1) - val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(words) - ret.add(similarity) - ret + List(words, similarity).map(_.asInstanceOf[Object]).asJava } } @@ -395,13 +401,13 @@ class PythonMLLibAPI extends Serializable { * Extra care needs to be taken in the Python code to ensure it gets freed on exit; * see the Py4J documentation. * @param data Training data - * @param categoricalFeaturesInfoJMap Categorical features info, as Java map + * @param categoricalFeaturesInfo Categorical features info, as Java map */ def trainDecisionTreeModel( data: JavaRDD[LabeledPoint], algoStr: String, numClasses: Int, - categoricalFeaturesInfoJMap: java.util.Map[Int, Int], + categoricalFeaturesInfo: JMap[Int, Int], impurityStr: String, maxDepth: Int, maxBins: Int, @@ -417,7 +423,7 @@ class PythonMLLibAPI extends Serializable { maxDepth = maxDepth, numClassesForClassification = numClasses, maxBins = maxBins, - categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap, + categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap, minInstancesPerNode = minInstancesPerNode, minInfoGain = minInfoGain) @@ -589,7 +595,7 @@ private[spark] object SerDe extends Serializable { if (objects.length == 0 || objects.length > 3) { out.write(Opcodes.MARK) } - objects.foreach(pickler.save(_)) + objects.foreach(pickler.save) val code = objects.length match { case 1 => Opcodes.TUPLE1 case 2 => Opcodes.TUPLE2 @@ -719,7 +725,7 @@ private[spark] object SerDe extends Serializable { } /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */ - def fromTuple2RDD(rdd: RDD[Tuple2[Any, Any]]): RDD[Array[Any]] = { + def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = { rdd.map(x => Array(x._1, x._2)) } |