aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/param/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/param/__init__.py')
-rw-r--r--python/pyspark/ml/param/__init__.py22
1 files changed, 19 insertions, 3 deletions
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index a1265294a1..9f0b063aac 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -26,6 +26,8 @@ import copy
import numpy as np
import warnings
+from py4j.java_gateway import JavaObject
+
from pyspark import since
from pyspark.ml.util import Identifiable
from pyspark.mllib.linalg import DenseVector, Vector
@@ -389,8 +391,8 @@ class Params(Identifiable):
if extra is None:
extra = dict()
that = copy.copy(self)
- that._paramMap = self.extractParamMap(extra)
- return that
+ that._paramMap = {}
+ return self._copyValues(that, extra)
def _shouldOwn(self, param):
"""
@@ -439,12 +441,26 @@ class Params(Identifiable):
self._paramMap[p] = value
return self
+ def _clear(self, param):
+ """
+ Clears a param from the param map if it has been explicitly set.
+ """
+ if self.isSet(param):
+ del self._paramMap[param]
+
def _setDefault(self, **kwargs):
"""
Sets default params.
"""
for param, value in kwargs.items():
- self._defaultParamMap[getattr(self, param)] = value
+ p = getattr(self, param)
+ if value is not None and not isinstance(value, JavaObject):
+ try:
+ value = p.typeConverter(value)
+ except TypeError as e:
+ raise TypeError('Invalid default param value given for param "%s". %s'
+ % (p.name, e))
+ self._defaultParamMap[p] = value
return self
def _copyValues(self, to, extra=None):