about | summary | refs | log | tree | commit | diff
path: root/mllib
diff options
context:
space:
mode:
author Davies Liu <davies@databricks.com> 2014-10-30 22:25:18 -0700
committer Xiangrui Meng <meng@databricks.com> 2014-10-30 22:25:18 -0700
commit 872fc669b497fb255db3212568f2a14c2ba0d5db (patch)
tree 6dcaa7e0b251fa5f233171e2878a4dc428db2348 /mllib
parent 0734d09320fe37edd3a02718511cda0bda852478 (diff)
download spark-872fc669b497fb255db3212568f2a14c2ba0d5db.tar.gz
spark-872fc669b497fb255db3212568f2a14c2ba0d5db.tar.bz2
spark-872fc669b497fb255db3212568f2a14c2ba0d5db.zip
[SPARK-4124] [MLlib] [PySpark] simplify serialization in MLlib Python API
Create several helper functions to call the MLlib Java API, converting the arguments to Java types and the return value to a Python object automatically; this greatly simplifies serialization in the MLlib Python API. After this, the MLlib Python API no longer needs to deal with serialization details, and it is easier to add new APIs. cc mengxr Author: Davies Liu <davies@databricks.com> Closes #2995 from davies/cleanup and squashes the following commits: 8fa6ec6 [Davies Liu] address comments 16b85a0 [Davies Liu] Merge branch 'master' of github.com:apache/spark into cleanup 43743e5 [Davies Liu] bugfix 731331f [Davies Liu] simplify serialization in MLlib Python API
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala 84
1 files changed, 45 insertions, 39 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 485abe2723..acdc67ddc6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -18,7 +18,7 @@
package org.apache.spark.mllib.api.python
import java.io.OutputStream
-import java.util.{ArrayList => JArrayList}
+import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
import scala.collection.JavaConverters._
import scala.language.existentials
@@ -72,15 +72,11 @@ class PythonMLLibAPI extends Serializable {
private def trainRegressionModel(
learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel],
data: JavaRDD[LabeledPoint],
- initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = {
- val initialWeights = SerDe.loads(initialWeightsBA).asInstanceOf[Vector]
+ initialWeights: Vector): JList[Object] = {
// Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
learner.disableUncachedWarning()
val model = learner.run(data.rdd, initialWeights)
- val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(SerDe.dumps(model.weights))
- ret.add(model.intercept: java.lang.Double)
- ret
+ List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
}
/**
@@ -91,10 +87,10 @@ class PythonMLLibAPI extends Serializable {
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte],
+ initialWeights: Vector,
regParam: Double,
regType: String,
- intercept: Boolean): java.util.List[java.lang.Object] = {
+ intercept: Boolean): JList[Object] = {
val lrAlg = new LinearRegressionWithSGD()
lrAlg.setIntercept(intercept)
lrAlg.optimizer
@@ -113,7 +109,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
lrAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -125,7 +121,7 @@ class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ initialWeights: Vector): JList[Object] = {
val lassoAlg = new LassoWithSGD()
lassoAlg.optimizer
.setNumIterations(numIterations)
@@ -135,7 +131,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
lassoAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -147,7 +143,7 @@ class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ initialWeights: Vector): JList[Object] = {
val ridgeAlg = new RidgeRegressionWithSGD()
ridgeAlg.optimizer
.setNumIterations(numIterations)
@@ -157,7 +153,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
ridgeAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -169,9 +165,9 @@ class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte],
+ initialWeights: Vector,
regType: String,
- intercept: Boolean): java.util.List[java.lang.Object] = {
+ intercept: Boolean): JList[Object] = {
val SVMAlg = new SVMWithSGD()
SVMAlg.setIntercept(intercept)
SVMAlg.optimizer
@@ -190,7 +186,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
SVMAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -201,10 +197,10 @@ class PythonMLLibAPI extends Serializable {
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte],
+ initialWeights: Vector,
regParam: Double,
regType: String,
- intercept: Boolean): java.util.List[java.lang.Object] = {
+ intercept: Boolean): JList[Object] = {
val LogRegAlg = new LogisticRegressionWithSGD()
LogRegAlg.setIntercept(intercept)
LogRegAlg.optimizer
@@ -223,7 +219,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
LogRegAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -231,13 +227,10 @@ class PythonMLLibAPI extends Serializable {
*/
def trainNaiveBayes(
data: JavaRDD[LabeledPoint],
- lambda: Double): java.util.List[java.lang.Object] = {
+ lambda: Double): JList[Object] = {
val model = NaiveBayes.train(data.rdd, lambda)
- val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(Vectors.dense(model.labels))
- ret.add(Vectors.dense(model.pi))
- ret.add(model.theta)
- ret
+ List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta).
+ map(_.asInstanceOf[Object]).asJava
}
/**
@@ -260,6 +253,21 @@ class PythonMLLibAPI extends Serializable {
}
/**
+ * A wrapper of MatrixFactorizationModel that provides helper methods for Python
+ */
+ private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel)
+ extends MatrixFactorizationModel(model.rank, model.userFeatures, model.productFeatures) {
+
+ def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] =
+ predict(SerDe.asTupleRDD(userAndProducts.rdd))
+
+ def getUserFeatures = SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]])
+
+ def getProductFeatures = SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]])
+
+ }
+
+ /**
* Java stub for Python mllib ALS.train(). This stub returns a handle
* to the Java object instead of the content of the Java object. Extra care
* needs to be taken in the Python code to ensure it gets freed on exit; see
@@ -271,7 +279,7 @@ class PythonMLLibAPI extends Serializable {
iterations: Int,
lambda: Double,
blocks: Int): MatrixFactorizationModel = {
- ALS.train(ratings.rdd, rank, iterations, lambda, blocks)
+ new MatrixFactorizationModelWrapper(ALS.train(ratings.rdd, rank, iterations, lambda, blocks))
}
/**
@@ -287,7 +295,8 @@ class PythonMLLibAPI extends Serializable {
lambda: Double,
blocks: Int,
alpha: Double): MatrixFactorizationModel = {
- ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha)
+ new MatrixFactorizationModelWrapper(
+ ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha))
}
/**
@@ -373,19 +382,16 @@ class PythonMLLibAPI extends Serializable {
rdd.rdd.map(model.transform)
}
- def findSynonyms(word: String, num: Int): java.util.List[java.lang.Object] = {
+ def findSynonyms(word: String, num: Int): JList[Object] = {
val vec = transform(word)
findSynonyms(vec, num)
}
- def findSynonyms(vector: Vector, num: Int): java.util.List[java.lang.Object] = {
+ def findSynonyms(vector: Vector, num: Int): JList[Object] = {
val result = model.findSynonyms(vector, num)
val similarity = Vectors.dense(result.map(_._2))
val words = result.map(_._1)
- val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(words)
- ret.add(similarity)
- ret
+ List(words, similarity).map(_.asInstanceOf[Object]).asJava
}
}
@@ -395,13 +401,13 @@ class PythonMLLibAPI extends Serializable {
* Extra care needs to be taken in the Python code to ensure it gets freed on exit;
* see the Py4J documentation.
* @param data Training data
- * @param categoricalFeaturesInfoJMap Categorical features info, as Java map
+ * @param categoricalFeaturesInfo Categorical features info, as Java map
*/
def trainDecisionTreeModel(
data: JavaRDD[LabeledPoint],
algoStr: String,
numClasses: Int,
- categoricalFeaturesInfoJMap: java.util.Map[Int, Int],
+ categoricalFeaturesInfo: JMap[Int, Int],
impurityStr: String,
maxDepth: Int,
maxBins: Int,
@@ -417,7 +423,7 @@ class PythonMLLibAPI extends Serializable {
maxDepth = maxDepth,
numClassesForClassification = numClasses,
maxBins = maxBins,
- categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap,
+ categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap,
minInstancesPerNode = minInstancesPerNode,
minInfoGain = minInfoGain)
@@ -589,7 +595,7 @@ private[spark] object SerDe extends Serializable {
if (objects.length == 0 || objects.length > 3) {
out.write(Opcodes.MARK)
}
- objects.foreach(pickler.save(_))
+ objects.foreach(pickler.save)
val code = objects.length match {
case 1 => Opcodes.TUPLE1
case 2 => Opcodes.TUPLE2
@@ -719,7 +725,7 @@ private[spark] object SerDe extends Serializable {
}
/* convert RDD[Tuple2[,]] to RDD[Array[Any]] */
- def fromTuple2RDD(rdd: RDD[Tuple2[Any, Any]]): RDD[Array[Any]] = {
+ def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = {
rdd.map(x => Array(x._1, x._2))
}