diff options
author | Davies Liu <davies@databricks.com> | 2014-10-30 22:25:18 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-10-30 22:25:18 -0700 |
commit | 872fc669b497fb255db3212568f2a14c2ba0d5db (patch) | |
tree | 6dcaa7e0b251fa5f233171e2878a4dc428db2348 /mllib/src/main | |
parent | 0734d09320fe37edd3a02718511cda0bda852478 (diff) | |
download | spark-872fc669b497fb255db3212568f2a14c2ba0d5db.tar.gz spark-872fc669b497fb255db3212568f2a14c2ba0d5db.tar.bz2 spark-872fc669b497fb255db3212568f2a14c2ba0d5db.zip |
[SPARK-4124] [MLlib] [PySpark] simplify serialization in MLlib Python API
Create several helper functions to call MLlib Java API, convert the arguments to Java type and convert return value to Python object automatically, this simplify serialization in MLlib Python API very much.
After this, the MLlib Python API does not need to deal with serialization details anymore, it's easier to add new API.
cc mengxr
Author: Davies Liu <davies@databricks.com>
Closes #2995 from davies/cleanup and squashes the following commits:
8fa6ec6 [Davies Liu] address comments
16b85a0 [Davies Liu] Merge branch 'master' of github.com:apache/spark into cleanup
43743e5 [Davies Liu] bugfix
731331f [Davies Liu] simplify serialization in MLlib Python API
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 84 |
1 files changed, 45 insertions, 39 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 485abe2723..acdc67ddc6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.api.python import java.io.OutputStream -import java.util.{ArrayList => JArrayList} +import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConverters._ import scala.language.existentials @@ -72,15 +72,11 @@ class PythonMLLibAPI extends Serializable { private def trainRegressionModel( learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel], data: JavaRDD[LabeledPoint], - initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = { - val initialWeights = SerDe.loads(initialWeightsBA).asInstanceOf[Vector] + initialWeights: Vector): JList[Object] = { // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD. learner.disableUncachedWarning() val model = learner.run(data.rdd, initialWeights) - val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(SerDe.dumps(model.weights)) - ret.add(model.intercept: java.lang.Double) - ret + List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava } /** @@ -91,10 +87,10 @@ class PythonMLLibAPI extends Serializable { numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte], + initialWeights: Vector, regParam: Double, regType: String, - intercept: Boolean): java.util.List[java.lang.Object] = { + intercept: Boolean): JList[Object] = { val lrAlg = new LinearRegressionWithSGD() lrAlg.setIntercept(intercept) lrAlg.optimizer @@ -113,7 +109,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( lrAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -125,7 +121,7 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + initialWeights: Vector): JList[Object] = { val lassoAlg = new LassoWithSGD() lassoAlg.optimizer .setNumIterations(numIterations) @@ -135,7 +131,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( lassoAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -147,7 +143,7 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + initialWeights: Vector): JList[Object] = { val ridgeAlg = new RidgeRegressionWithSGD() ridgeAlg.optimizer .setNumIterations(numIterations) @@ -157,7 +153,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( ridgeAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -169,9 +165,9 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte], + initialWeights: Vector, regType: String, - intercept: Boolean): java.util.List[java.lang.Object] = { + intercept: Boolean): JList[Object] = { val SVMAlg = new SVMWithSGD() SVMAlg.setIntercept(intercept) SVMAlg.optimizer @@ -190,7 +186,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( SVMAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -201,10 +197,10 @@ class PythonMLLibAPI extends Serializable { numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte], + initialWeights: Vector, regParam: Double, regType: String, - intercept: Boolean): java.util.List[java.lang.Object] = { + intercept: Boolean): JList[Object] = { val LogRegAlg = new LogisticRegressionWithSGD() LogRegAlg.setIntercept(intercept) LogRegAlg.optimizer @@ -223,7 +219,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( LogRegAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -231,13 +227,10 @@ class PythonMLLibAPI extends Serializable { */ def trainNaiveBayes( data: JavaRDD[LabeledPoint], - lambda: Double): java.util.List[java.lang.Object] = { + lambda: Double): JList[Object] = { val model = NaiveBayes.train(data.rdd, lambda) - val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(Vectors.dense(model.labels)) - ret.add(Vectors.dense(model.pi)) - ret.add(model.theta) - ret + List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta). + map(_.asInstanceOf[Object]).asJava } /** @@ -260,6 +253,21 @@ class PythonMLLibAPI extends Serializable { } /** + * A Wrapper of MatrixFactorizationModel to provide helpfer method for Python + */ + private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel) + extends MatrixFactorizationModel(model.rank, model.userFeatures, model.productFeatures) { + + def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] = + predict(SerDe.asTupleRDD(userAndProducts.rdd)) + + def getUserFeatures = SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]]) + + def getProductFeatures = SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]]) + + } + + /** * Java stub for Python mllib ALS.train(). This stub returns a handle * to the Java object instead of the content of the Java object. Extra care * needs to be taken in the Python code to ensure it gets freed on exit; see @@ -271,7 +279,7 @@ class PythonMLLibAPI extends Serializable { iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = { - ALS.train(ratings.rdd, rank, iterations, lambda, blocks) + new MatrixFactorizationModelWrapper(ALS.train(ratings.rdd, rank, iterations, lambda, blocks)) } /** @@ -287,7 +295,8 @@ class PythonMLLibAPI extends Serializable { lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = { - ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha) + new MatrixFactorizationModelWrapper( + ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha)) } /** @@ -373,19 +382,16 @@ class PythonMLLibAPI extends Serializable { rdd.rdd.map(model.transform) } - def findSynonyms(word: String, num: Int): java.util.List[java.lang.Object] = { + def findSynonyms(word: String, num: Int): JList[Object] = { val vec = transform(word) findSynonyms(vec, num) } - def findSynonyms(vector: Vector, num: Int): java.util.List[java.lang.Object] = { + def findSynonyms(vector: Vector, num: Int): JList[Object] = { val result = model.findSynonyms(vector, num) val similarity = Vectors.dense(result.map(_._2)) val words = result.map(_._1) - val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(words) - ret.add(similarity) - ret + List(words, similarity).map(_.asInstanceOf[Object]).asJava } } @@ -395,13 +401,13 @@ class PythonMLLibAPI extends Serializable { * Extra care needs to be taken in the Python code to ensure it gets freed on exit; * see the Py4J documentation. * @param data Training data - * @param categoricalFeaturesInfoJMap Categorical features info, as Java map + * @param categoricalFeaturesInfo Categorical features info, as Java map */ def trainDecisionTreeModel( data: JavaRDD[LabeledPoint], algoStr: String, numClasses: Int, - categoricalFeaturesInfoJMap: java.util.Map[Int, Int], + categoricalFeaturesInfo: JMap[Int, Int], impurityStr: String, maxDepth: Int, maxBins: Int, @@ -417,7 +423,7 @@ class PythonMLLibAPI extends Serializable { maxDepth = maxDepth, numClassesForClassification = numClasses, maxBins = maxBins, - categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap, + categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap, minInstancesPerNode = minInstancesPerNode, minInfoGain = minInfoGain) @@ -589,7 +595,7 @@ private[spark] object SerDe extends Serializable { if (objects.length == 0 || objects.length > 3) { out.write(Opcodes.MARK) } - objects.foreach(pickler.save(_)) + objects.foreach(pickler.save) val code = objects.length match { case 1 => Opcodes.TUPLE1 case 2 => Opcodes.TUPLE2 @@ -719,7 +725,7 @@ private[spark] object SerDe extends Serializable { } /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */ - def fromTuple2RDD(rdd: RDD[Tuple2[Any, Any]]): RDD[Array[Any]] = { + def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = { rdd.map(x => Array(x._1, x._2)) } |