diff options
author | Davies Liu <davies.liu@gmail.com> | 2014-10-16 14:56:50 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-10-16 14:56:50 -0700 |
commit | 091d32c52e9d73da95896016c1d920e89858abfa (patch) | |
tree | 904edd29e64b57fa1ab72d3ca37ed2996aa9d1e4 /python/pyspark/mllib/clustering.py | |
parent | 4c589cac4496c6a4bb8485a340bd0641dca13847 (diff) | |
download | spark-091d32c52e9d73da95896016c1d920e89858abfa.tar.gz spark-091d32c52e9d73da95896016c1d920e89858abfa.tar.bz2 spark-091d32c52e9d73da95896016c1d920e89858abfa.zip |
[SPARK-3971] [MLLib] [PySpark] hotfix: Customized pickler should work in cluster mode
Customized pickler should be registered before unpickling, but in executor, there is no way to register the picklers before run the tasks.
So, we need to register the picklers in the tasks itself, duplicate the javaToPython() and pythonToJava() in MLlib, call SerDe.initialize() before pickling or unpickling.
Author: Davies Liu <davies.liu@gmail.com>
Closes #2830 from davies/fix_pickle and squashes the following commits:
0c85fb9 [Davies Liu] revert the privacy change
6b94e15 [Davies Liu] use JavaConverters instead of JavaConversions
0f02050 [Davies Liu] hotfix: Customized pickler does not work in cluster
Diffstat (limited to 'python/pyspark/mllib/clustering.py')
-rw-r--r-- | python/pyspark/mllib/clustering.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 12c5602271..5ee7997104 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -17,7 +17,7 @@ from pyspark import SparkContext from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd __all__ = ['KMeansModel', 'KMeans'] @@ -85,7 +85,7 @@ class KMeans(object): # cache serialized data to avoid objects over head in JVM cached = rdd.map(_convert_to_vector)._reserialize(AutoBatchedSerializer(ser)).cache() model = sc._jvm.PythonMLLibAPI().trainKMeansModel( - cached._to_java_object_rdd(), k, maxIterations, runs, initializationMode) + _to_java_object_rdd(cached), k, maxIterations, runs, initializationMode) bytes = sc._jvm.SerDe.dumps(model.clusterCenters()) centers = ser.loads(str(bytes)) return KMeansModel([c.toArray() for c in centers]) |