aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/common.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/common.py')
-rw-r--r--python/pyspark/mllib/common.py19
1 files changed, 12 insertions, 7 deletions
diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py
index a539d2f284..ba60589788 100644
--- a/python/pyspark/mllib/common.py
+++ b/python/pyspark/mllib/common.py
@@ -15,6 +15,11 @@
# limitations under the License.
#
+import sys
+if sys.version >= '3':
+ long = int
+ unicode = str
+
import py4j.protocol
from py4j.protocol import Py4JJavaError
from py4j.java_gateway import JavaObject
@@ -36,7 +41,7 @@ _float_str_mapping = {
def _new_smart_decode(obj):
if isinstance(obj, float):
- s = unicode(obj)
+ s = str(obj)
return _float_str_mapping.get(s, s)
return _old_smart_decode(obj)
@@ -74,15 +79,15 @@ def _py2java(sc, obj):
obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client)
elif isinstance(obj, JavaObject):
pass
- elif isinstance(obj, (int, long, float, bool, basestring)):
+ elif isinstance(obj, (int, long, float, bool, bytes, unicode)):
pass
else:
- bytes = bytearray(PickleSerializer().dumps(obj))
- obj = sc._jvm.SerDe.loads(bytes)
+ data = bytearray(PickleSerializer().dumps(obj))
+ obj = sc._jvm.SerDe.loads(data)
return obj
-def _java2py(sc, r):
+def _java2py(sc, r, encoding="bytes"):
if isinstance(r, JavaObject):
clsName = r.getClass().getSimpleName()
# convert RDD into JavaRDD
@@ -102,8 +107,8 @@ def _java2py(sc, r):
except Py4JJavaError:
pass # not pickable
- if isinstance(r, bytearray):
- r = PickleSerializer().loads(str(r))
+ if isinstance(r, (bytearray, bytes)):
+ r = PickleSerializer().loads(bytes(r), encoding=encoding)
return r