From 411c04adb596c514f2634efd5f5d126e12b05df7 Mon Sep 17 00:00:00 2001 From: Takuya Kuwahara Date: Wed, 18 May 2016 08:29:47 +0200 Subject: [SPARK-14978][PYSPARK] PySpark TrainValidationSplitModel should support validationMetrics ## What changes were proposed in this pull request? This pull request includes supporting validationMetrics for TrainValidationSplitModel with Python and test for it. ## How was this patch tested? test in `python/pyspark/ml/tests.py` Author: Takuya Kuwahara Closes #12767 from taku-k/spark-14978. --- python/pyspark/ml/tuning.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) (limited to 'python/pyspark/ml/tuning.py') diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 75789c4d09..4f7a6b0f7b 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -379,7 +379,7 @@ class TrainValidationSplit(Estimator, ValidatorParams): seed = self.getOrDefault(self.seed) randCol = self.uid + "_rand" df = dataset.select("*", rand(seed).alias(randCol)) - metrics = np.zeros(numModels) + metrics = [0.0] * numModels condition = (df[randCol] >= tRatio) validation = df.filter(condition) train = df.filter(~condition) @@ -392,7 +392,7 @@ class TrainValidationSplit(Estimator, ValidatorParams): else: bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) - return self._copyValues(TrainValidationSplitModel(bestModel)) + return self._copyValues(TrainValidationSplitModel(bestModel, metrics)) @since("2.0.0") def copy(self, extra=None): @@ -424,10 +424,12 @@ class TrainValidationSplitModel(Model, ValidatorParams): .. versionadded:: 2.0.0 """ - def __init__(self, bestModel): + def __init__(self, bestModel, validationMetrics=[]): super(TrainValidationSplitModel, self).__init__() #: best model from cross validation self.bestModel = bestModel + #: evaluated validation metrics + self.validationMetrics = validationMetrics def _transform(self, dataset): return self.bestModel.transform(dataset) @@ -439,13 +441,16 @@ class TrainValidationSplitModel(Model, ValidatorParams): and some extra params. This copies the underlying bestModel, creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. + And, this creates a shallow copy of the validationMetrics. :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ if extra is None: extra = dict() - return TrainValidationSplitModel(self.bestModel.copy(extra)) + bestModel = self.bestModel.copy(extra) + validationMetrics = list(self.validationMetrics) + return TrainValidationSplitModel(bestModel, validationMetrics) if __name__ == "__main__": -- cgit v1.2.3