aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tuning.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/tuning.py')
-rw-r--r--python/pyspark/ml/tuning.py43
1 files changed, 31 insertions, 12 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 86f4dc7368..497841b6c8 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -155,7 +155,7 @@ class CrossValidator(Estimator):
"""
Sets the value of :py:attr:`estimator`.
"""
- self.paramMap[self.estimator] = value
+ self._paramMap[self.estimator] = value
return self
def getEstimator(self):
@@ -168,7 +168,7 @@ class CrossValidator(Estimator):
"""
Sets the value of :py:attr:`estimatorParamMaps`.
"""
- self.paramMap[self.estimatorParamMaps] = value
+ self._paramMap[self.estimatorParamMaps] = value
return self
def getEstimatorParamMaps(self):
@@ -181,7 +181,7 @@ class CrossValidator(Estimator):
"""
Sets the value of :py:attr:`evaluator`.
"""
- self.paramMap[self.evaluator] = value
+ self._paramMap[self.evaluator] = value
return self
def getEvaluator(self):
@@ -194,7 +194,7 @@ class CrossValidator(Estimator):
"""
Sets the value of :py:attr:`numFolds`.
"""
- self.paramMap[self.numFolds] = value
+ self._paramMap[self.numFolds] = value
return self
def getNumFolds(self):
@@ -203,13 +203,12 @@ class CrossValidator(Estimator):
"""
return self.getOrDefault(self.numFolds)
- def fit(self, dataset, params={}):
- paramMap = self.extractParamMap(params)
- est = paramMap[self.estimator]
- epm = paramMap[self.estimatorParamMaps]
+ def _fit(self, dataset):
+ est = self.getOrDefault(self.estimator)
+ epm = self.getOrDefault(self.estimatorParamMaps)
numModels = len(epm)
- eva = paramMap[self.evaluator]
- nFolds = paramMap[self.numFolds]
+ eva = self.getOrDefault(self.evaluator)
+ nFolds = self.getOrDefault(self.numFolds)
h = 1.0 / nFolds
randCol = self.uid + "_rand"
df = dataset.select("*", rand(0).alias(randCol))
@@ -229,6 +228,15 @@ class CrossValidator(Estimator):
bestModel = est.fit(dataset, epm[bestIndex])
return CrossValidatorModel(bestModel)
+ def copy(self, extra={}):
+ newCV = Params.copy(self, extra)
+ if self.isSet(self.estimator):
+ newCV.setEstimator(self.getEstimator().copy(extra))
+ # estimatorParamMaps remain the same
+ if self.isSet(self.evaluator):
+ newCV.setEvaluator(self.getEvaluator().copy(extra))
+ return newCV
+
class CrossValidatorModel(Model):
"""
@@ -240,8 +248,19 @@ class CrossValidatorModel(Model):
#: best model from cross validation
self.bestModel = bestModel
- def transform(self, dataset, params={}):
- return self.bestModel.transform(dataset, params)
+ def _transform(self, dataset):
+ return self.bestModel.transform(dataset)
+
+ def copy(self, extra={}):
+ """
+ Creates a copy of this instance with a randomly generated uid
+ and some extra params. This copies the underlying bestModel,
+ creates a deep copy of the embedded paramMap, and
+ copies the embedded and extra parameters over.
+ :param extra: Extra parameters to copy to the new instance
+ :return: Copy of this instance
+ """
+ return CrossValidatorModel(self.bestModel.copy(extra))
if __name__ == "__main__":