diff options
Diffstat (limited to 'python/pyspark/ml/tuning.py')
-rw-r--r-- | python/pyspark/ml/tuning.py | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 75789c4d09..4f7a6b0f7b 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -379,7 +379,7 @@ class TrainValidationSplit(Estimator, ValidatorParams): seed = self.getOrDefault(self.seed) randCol = self.uid + "_rand" df = dataset.select("*", rand(seed).alias(randCol)) - metrics = np.zeros(numModels) + metrics = [0.0] * numModels condition = (df[randCol] >= tRatio) validation = df.filter(condition) train = df.filter(~condition) @@ -392,7 +392,7 @@ class TrainValidationSplit(Estimator, ValidatorParams): else: bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) - return self._copyValues(TrainValidationSplitModel(bestModel)) + return self._copyValues(TrainValidationSplitModel(bestModel, metrics)) @since("2.0.0") def copy(self, extra=None): @@ -424,10 +424,12 @@ class TrainValidationSplitModel(Model, ValidatorParams): .. versionadded:: 2.0.0 """ - def __init__(self, bestModel): + def __init__(self, bestModel, validationMetrics=[]): super(TrainValidationSplitModel, self).__init__() #: best model from cross validation self.bestModel = bestModel + #: evaluated validation metrics + self.validationMetrics = validationMetrics def _transform(self, dataset): return self.bestModel.transform(dataset) @@ -439,13 +441,16 @@ class TrainValidationSplitModel(Model, ValidatorParams): and some extra params. This copies the underlying bestModel, creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. + And, this creates a shallow copy of the validationMetrics. :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ if extra is None: extra = dict() - return TrainValidationSplitModel(self.bestModel.copy(extra)) + bestModel = self.bestModel.copy(extra) + validationMetrics = list(self.validationMetrics) + return TrainValidationSplitModel(bestModel, validationMetrics) if __name__ == "__main__": |