diff options
Diffstat (limited to 'python/pyspark/ml/param/__init__.py')
-rw-r--r-- | python/pyspark/ml/param/__init__.py | 22 |
1 files changed, 19 insertions, 3 deletions
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index a1265294a1..9f0b063aac 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -26,6 +26,8 @@ import copy import numpy as np import warnings +from py4j.java_gateway import JavaObject + from pyspark import since from pyspark.ml.util import Identifiable from pyspark.mllib.linalg import DenseVector, Vector @@ -389,8 +391,8 @@ class Params(Identifiable): if extra is None: extra = dict() that = copy.copy(self) - that._paramMap = self.extractParamMap(extra) - return that + that._paramMap = {} + return self._copyValues(that, extra) def _shouldOwn(self, param): """ @@ -439,12 +441,26 @@ class Params(Identifiable): self._paramMap[p] = value return self + def _clear(self, param): + """ + Clears a param from the param map if it has been explicitly set. + """ + if self.isSet(param): + del self._paramMap[param] + def _setDefault(self, **kwargs): """ Sets default params. """ for param, value in kwargs.items(): - self._defaultParamMap[getattr(self, param)] = value + p = getattr(self, param) + if value is not None and not isinstance(value, JavaObject): + try: + value = p.typeConverter(value) + except TypeError as e: + raise TypeError('Invalid default param value given for param "%s". %s' + % (p.name, e)) + self._defaultParamMap[p] = value return self def _copyValues(self, to, extra=None): |