aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/pyspark/ml/param/__init__.py3
-rw-r--r--python/pyspark/ml/tests.py2
-rw-r--r--python/pyspark/ml/util.py5
-rw-r--r--python/pyspark/ml/wrapper.py2
4 files changed, 8 insertions, 4 deletions
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index 9f0b063aac..40d8300625 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -485,10 +485,11 @@ class Params(Identifiable):
Changes the uid of this instance. This updates both
the stored uid and the parent uid of params and param maps.
This is used by persistence (loading).
- :param newUid: new uid to use
+ :param newUid: new uid to use, which is converted to unicode
:return: same instance, but with the uid and Param.parent values
updated, including within param maps
"""
+ newUid = unicode(newUid)
self.uid = newUid
newDefaultParamMap = dict()
newParamMap = dict()
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index d595eff5b4..a7a9868bac 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -621,6 +621,8 @@ class PersistenceTest(PySparkTestCase):
lr_path = path + "/lr"
lr.save(lr_path)
lr2 = LinearRegression.load(lr_path)
+ self.assertEqual(lr.uid, lr2.uid)
+ self.assertEqual(type(lr.uid), type(lr2.uid))
self.assertEqual(lr2.uid, lr2.maxIter.parent,
"Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)"
% (lr2.uid, lr2.maxIter.parent))
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 9dfcef0e40..841bfb47e1 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -21,6 +21,7 @@ from functools import wraps
if sys.version > '3':
basestring = str
+ unicode = str
from pyspark import SparkContext, since
from pyspark.mllib.common import inherit_doc
@@ -67,10 +68,10 @@ class Identifiable(object):
@classmethod
def _randomUID(cls):
"""
- Generate a unique id for the object. The default implementation
+ Generate a unique unicode id for the object. The default implementation
concatenates the class name, "_", and 12 random hex chars.
"""
- return cls.__name__ + "_" + uuid.uuid4().hex[12:]
+ return unicode(cls.__name__ + "_" + uuid.uuid4().hex[12:])
@inherit_doc
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 055a2816f8..fef626c7fa 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -254,7 +254,7 @@ class JavaModel(JavaTransformer, Model):
"""
super(JavaModel, self).__init__(java_model)
if java_model is not None:
- self.uid = java_model.uid()
+ self._resetUid(java_model.uid())
def copy(self, extra=None):
"""