diff options
author | Takuya Kuwahara <taakuu19@gmail.com> | 2016-05-18 08:29:47 +0200 |
---|---|---|
committer | Nick Pentreath <nickp@za.ibm.com> | 2016-05-18 08:29:47 +0200 |
commit | 411c04adb596c514f2634efd5f5d126e12b05df7 (patch) | |
tree | 80c10aa79cebc920cb9481f5b4bdd3c866e23313 /python/pyspark/ml/tuning.py | |
parent | 2a5db9c140b9d60a5ec91018be19bec7b80850ee (diff) | |
download | spark-411c04adb596c514f2634efd5f5d126e12b05df7.tar.gz spark-411c04adb596c514f2634efd5f5d126e12b05df7.tar.bz2 spark-411c04adb596c514f2634efd5f5d126e12b05df7.zip |
[SPARK-14978][PYSPARK] PySpark TrainValidationSplitModel should support validationMetrics
## What changes were proposed in this pull request?
This pull request adds support for `validationMetrics` to the Python `TrainValidationSplitModel`, along with a test for it.
## How was this patch tested?
Tested with a unit test in `python/pyspark/ml/tests.py`.
Author: Takuya Kuwahara <taakuu19@gmail.com>
Closes #12767 from taku-k/spark-14978.
Diffstat (limited to 'python/pyspark/ml/tuning.py')
-rw-r--r-- | python/pyspark/ml/tuning.py | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 75789c4d09..4f7a6b0f7b 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -379,7 +379,7 @@ class TrainValidationSplit(Estimator, ValidatorParams): seed = self.getOrDefault(self.seed) randCol = self.uid + "_rand" df = dataset.select("*", rand(seed).alias(randCol)) - metrics = np.zeros(numModels) + metrics = [0.0] * numModels condition = (df[randCol] >= tRatio) validation = df.filter(condition) train = df.filter(~condition) @@ -392,7 +392,7 @@ class TrainValidationSplit(Estimator, ValidatorParams): else: bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) - return self._copyValues(TrainValidationSplitModel(bestModel)) + return self._copyValues(TrainValidationSplitModel(bestModel, metrics)) @since("2.0.0") def copy(self, extra=None): @@ -424,10 +424,12 @@ class TrainValidationSplitModel(Model, ValidatorParams): .. versionadded:: 2.0.0 """ - def __init__(self, bestModel): + def __init__(self, bestModel, validationMetrics=[]): super(TrainValidationSplitModel, self).__init__() #: best model from cross validation self.bestModel = bestModel + #: evaluated validation metrics + self.validationMetrics = validationMetrics def _transform(self, dataset): return self.bestModel.transform(dataset) @@ -439,13 +441,16 @@ class TrainValidationSplitModel(Model, ValidatorParams): and some extra params. This copies the underlying bestModel, creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. + And, this creates a shallow copy of the validationMetrics. 
:param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ if extra is None: extra = dict() - return TrainValidationSplitModel(self.bestModel.copy(extra)) + bestModel = self.bestModel.copy(extra) + validationMetrics = list(self.validationMetrics) + return TrainValidationSplitModel(bestModel, validationMetrics) if __name__ == "__main__": |