author      Davies Liu <davies.liu@gmail.com>      2014-10-16 14:56:50 -0700
committer   Xiangrui Meng <meng@databricks.com>    2014-10-16 14:56:50 -0700
commit      091d32c52e9d73da95896016c1d920e89858abfa (patch)
tree        904edd29e64b57fa1ab72d3ca37ed2996aa9d1e4 /python/pyspark/mllib/tree.py
parent      4c589cac4496c6a4bb8485a340bd0641dca13847 (diff)
[SPARK-3971] [MLLib] [PySpark] hotfix: Customized pickler should work in cluster mode
Customized picklers should be registered before unpickling, but on an executor there is no way to register them before the tasks run. So we register the picklers inside the tasks themselves: duplicate javaToPython() and pythonToJava() in MLlib and call SerDe.initialize() before pickling or unpickling.

Author: Davies Liu <davies.liu@gmail.com>

Closes #2830 from davies/fix_pickle and squashes the following commits:

0c85fb9 [Davies Liu] revert the privacy change
6b94e15 [Davies Liu] use JavaConverters instead of JavaConversions
0f02050 [Davies Liu] hotfix: Customized pickler does not work in cluster
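To illustrate the pattern this hotfix adopts, a minimal Python sketch of the bulk-prediction path after the change is shown here; `_to_java_object_rdd` and `SerDe.javaToPython` are taken from the diff below, while the helper name `_predict_rdd` and its surrounding wiring are assumptions for illustration only.

from pyspark import RDD
from pyspark.serializers import BatchedSerializer, PickleSerializer
from pyspark.mllib.linalg import _to_java_object_rdd


def _predict_rdd(sc, java_model, x):
    # Illustrative only: convert the Python RDD into an RDD of Java objects.
    # The MLlib helpers register the custom picklers inside the task, so the
    # conversion also works in cluster mode where executors start "cold".
    jpred = java_model.predict(_to_java_object_rdd(x)).toJavaRDD()
    # Use MLlib's SerDe.javaToPython (rather than PythonRDD.javaToPython),
    # which per this commit calls SerDe.initialize() before pickling results.
    jpyrdd = sc._jvm.SerDe.javaToPython(jpred)
    return RDD(jpyrdd, sc, BatchedSerializer(PickleSerializer(), 1024))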
Diffstat (limited to 'python/pyspark/mllib/tree.py')
-rw-r--r--    python/pyspark/mllib/tree.py    8
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index 5d7abfb96b..0938eebd3a 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -19,7 +19,7 @@ from py4j.java_collections import MapConverter
 
 from pyspark import SparkContext, RDD
 from pyspark.serializers import BatchedSerializer, PickleSerializer
-from pyspark.mllib.linalg import Vector, _convert_to_vector
+from pyspark.mllib.linalg import Vector, _convert_to_vector, _to_java_object_rdd
 from pyspark.mllib.regression import LabeledPoint
 
 __all__ = ['DecisionTreeModel', 'DecisionTree']
@@ -61,8 +61,8 @@ class DecisionTreeModel(object):
                 return self._sc.parallelize([])
             if not isinstance(first[0], Vector):
                 x = x.map(_convert_to_vector)
-            jPred = self._java_model.predict(x._to_java_object_rdd()).toJavaRDD()
-            jpyrdd = self._sc._jvm.PythonRDD.javaToPython(jPred)
+            jPred = self._java_model.predict(_to_java_object_rdd(x)).toJavaRDD()
+            jpyrdd = self._sc._jvm.SerDe.javaToPython(jPred)
             return RDD(jpyrdd, self._sc, BatchedSerializer(ser, 1024))
 
         else:
@@ -104,7 +104,7 @@ class DecisionTree(object):
         first = data.first()
         assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint"
         sc = data.context
-        jrdd = data._to_java_object_rdd()
+        jrdd = _to_java_object_rdd(data)
         cfiMap = MapConverter().convert(categoricalFeaturesInfo,
                                        sc._gateway._gateway_client)
         model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(