aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tuning.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/tuning.py')
-rw-r--r--python/pyspark/ml/tuning.py13
1 files changed, 9 insertions, 4 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 75789c4d09..4f7a6b0f7b 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -379,7 +379,7 @@ class TrainValidationSplit(Estimator, ValidatorParams):
seed = self.getOrDefault(self.seed)
randCol = self.uid + "_rand"
df = dataset.select("*", rand(seed).alias(randCol))
- metrics = np.zeros(numModels)
+ metrics = [0.0] * numModels
condition = (df[randCol] >= tRatio)
validation = df.filter(condition)
train = df.filter(~condition)
@@ -392,7 +392,7 @@ class TrainValidationSplit(Estimator, ValidatorParams):
else:
bestIndex = np.argmin(metrics)
bestModel = est.fit(dataset, epm[bestIndex])
- return self._copyValues(TrainValidationSplitModel(bestModel))
+ return self._copyValues(TrainValidationSplitModel(bestModel, metrics))
@since("2.0.0")
def copy(self, extra=None):
@@ -424,10 +424,12 @@ class TrainValidationSplitModel(Model, ValidatorParams):
.. versionadded:: 2.0.0
"""
- def __init__(self, bestModel):
+ def __init__(self, bestModel, validationMetrics=[]):
super(TrainValidationSplitModel, self).__init__()
#: best model from cross validation
self.bestModel = bestModel
+ #: evaluated validation metrics
+ self.validationMetrics = validationMetrics
def _transform(self, dataset):
return self.bestModel.transform(dataset)
@@ -439,13 +441,16 @@ class TrainValidationSplitModel(Model, ValidatorParams):
and some extra params. This copies the underlying bestModel,
creates a deep copy of the embedded paramMap, and
copies the embedded and extra parameters over.
+ And, this creates a shallow copy of the validationMetrics.
:param extra: Extra parameters to copy to the new instance
:return: Copy of this instance
"""
if extra is None:
extra = dict()
- return TrainValidationSplitModel(self.bestModel.copy(extra))
+ bestModel = self.bestModel.copy(extra)
+ validationMetrics = list(self.validationMetrics)
+ return TrainValidationSplitModel(bestModel, validationMetrics)
if __name__ == "__main__":