author     Davies Liu <davies.liu@gmail.com>     2014-10-16 14:56:50 -0700
committer  Xiangrui Meng <meng@databricks.com>   2014-10-16 14:56:50 -0700
commit     091d32c52e9d73da95896016c1d920e89858abfa (patch)
tree       904edd29e64b57fa1ab72d3ca37ed2996aa9d1e4 /mllib/src
parent     4c589cac4496c6a4bb8485a340bd0641dca13847 (diff)
[SPARK-3971] [MLLib] [PySpark] hotfix: Customized pickler should work in cluster mode
Customized picklers should be registered before unpickling, but on an executor there is no way to register them before the tasks run. So we register the picklers in the tasks themselves: duplicate javaToPython() and pythonToJava() in MLlib, and call SerDe.initialize() before pickling or unpickling.

Author: Davies Liu <davies.liu@gmail.com>

Closes #2830 from davies/fix_pickle and squashes the following commits:

0c85fb9 [Davies Liu] revert the privacy change
6b94e15 [Davies Liu] use JavaConverters instead of JavaConversions
0f02050 [Davies Liu] hotfix: Customized pickler does not work in cluster
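A minimal sketch of the pattern this hotfix applies (the object and helper names below are illustrative, not the actual MLlib API): registration is made idempotent, and triggering it from inside a task places it in the task closure, so it runs on every executor JVM rather than only on the driver.

import scala.reflect.ClassTag
import org.apache.spark.rdd.RDD

object PicklerRegistry {
  @volatile private var initialized = false

  // Idempotent: only the first call actually registers the picklers.
  def initialize(): Unit = synchronized {
    if (!initialized) {
      // register the custom picklers here, e.g. new DenseVectorPickler().register()
      initialized = true
    }
  }
}

// Calling initialize() inside mapPartitions places it in the task closure,
// so it runs on the executor before any pickling or unpickling starts.
def withPicklers[T: ClassTag](rdd: RDD[T]): RDD[T] =
  rdd.mapPartitions { iter =>
    PicklerRegistry.initialize()
    iter
  }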
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala  52
1 file changed, 47 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index f7251e65e0..9a100170b7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -18,6 +18,7 @@
package org.apache.spark.mllib.api.python
import java.io.OutputStream
+import java.util.{ArrayList => JArrayList}
import scala.collection.JavaConverters._
import scala.language.existentials
@@ -27,6 +28,7 @@ import net.razorvine.pickle._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
+import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
import org.apache.spark.mllib.feature.Word2Vec
@@ -639,13 +641,24 @@ private[spark] object SerDe extends Serializable {
}
}
+ var initialized = false
+ // This should be called before trying to serialize any of the above classes
+ // In cluster mode, this call should be placed inside the task closure
def initialize(): Unit = {
- new DenseVectorPickler().register()
- new DenseMatrixPickler().register()
- new SparseVectorPickler().register()
- new LabeledPointPickler().register()
- new RatingPickler().register()
+ SerDeUtil.initialize()
+ synchronized {
+ if (!initialized) {
+ new DenseVectorPickler().register()
+ new DenseMatrixPickler().register()
+ new SparseVectorPickler().register()
+ new LabeledPointPickler().register()
+ new RatingPickler().register()
+ initialized = true
+ }
+ }
}
+ // will not be called on the executor automatically
+ initialize()
def dumps(obj: AnyRef): Array[Byte] = {
new Pickler().dumps(obj)
@@ -659,4 +672,33 @@ private[spark] object SerDe extends Serializable {
def asTupleRDD(rdd: RDD[Array[Any]]): RDD[(Int, Int)] = {
rdd.map(x => (x(0).asInstanceOf[Int], x(1).asInstanceOf[Int]))
}
+
+ /**
+ * Convert an RDD of Java objects to an RDD of serialized Python objects that are usable
+ * by PySpark.
+ */
+ def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
+ jRDD.rdd.mapPartitions { iter =>
+ initialize() // make sure it is called on the executor
+ new PythonRDD.AutoBatchedPickler(iter)
+ }
+ }
+
+ /**
+ * Convert an RDD of serialized Python objects to an RDD of Java objects that are usable by PySpark.
+ */
+ def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
+ pyRDD.rdd.mapPartitions { iter =>
+ initialize() // make sure it is called on the executor
+ val unpickle = new Unpickler
+ iter.flatMap { row =>
+ val obj = unpickle.loads(row)
+ if (batched) {
+ obj.asInstanceOf[JArrayList[_]].asScala
+ } else {
+ Seq(obj)
+ }
+ }
+ }.toJavaRDD()
+ }
}
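As a hedged illustration (not part of this patch), the two new helpers compose into a round trip. An active SparkContext `sc` is assumed, and SerDe is only reachable from within the org.apache.spark package, so treat this as a sketch rather than end-user code:

// Illustrative only. Pickling runs on the executors; initialize() executes
// inside each task, so the custom picklers are registered before any
// serialization happens there.
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vectors

val vectors: JavaRDD[Any] =
  sc.parallelize(Seq[Any](Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0))).toJavaRDD()

val pickled: JavaRDD[Array[Byte]] = SerDe.javaToPython(vectors)

// batched = true because AutoBatchedPickler emits lists of objects.
val restored: JavaRDD[Any] = SerDe.pythonToJava(pickled, batched = true)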