aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tuning.py
diff options
context:
space:
mode:
authorTakuya Kuwahara <taakuu19@gmail.com>2016-05-18 08:29:47 +0200
committerNick Pentreath <nickp@za.ibm.com>2016-05-18 08:29:47 +0200
commit411c04adb596c514f2634efd5f5d126e12b05df7 (patch)
tree80c10aa79cebc920cb9481f5b4bdd3c866e23313 /python/pyspark/ml/tuning.py
parent2a5db9c140b9d60a5ec91018be19bec7b80850ee (diff)
downloadspark-411c04adb596c514f2634efd5f5d126e12b05df7.tar.gz
spark-411c04adb596c514f2634efd5f5d126e12b05df7.tar.bz2
spark-411c04adb596c514f2634efd5f5d126e12b05df7.zip
[SPARK-14978][PYSPARK] PySpark TrainValidationSplitModel should support validationMetrics
## What changes were proposed in this pull request? This pull request includes supporting validationMetrics for TrainValidationSplitModel with Python and test for it. ## How was this patch tested? test in `python/pyspark/ml/tests.py` Author: Takuya Kuwahara <taakuu19@gmail.com> Closes #12767 from taku-k/spark-14978.
Diffstat (limited to 'python/pyspark/ml/tuning.py')
-rw-r--r--python/pyspark/ml/tuning.py13
1 files changed, 9 insertions, 4 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 75789c4d09..4f7a6b0f7b 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -379,7 +379,7 @@ class TrainValidationSplit(Estimator, ValidatorParams):
seed = self.getOrDefault(self.seed)
randCol = self.uid + "_rand"
df = dataset.select("*", rand(seed).alias(randCol))
- metrics = np.zeros(numModels)
+ metrics = [0.0] * numModels
condition = (df[randCol] >= tRatio)
validation = df.filter(condition)
train = df.filter(~condition)
@@ -392,7 +392,7 @@ class TrainValidationSplit(Estimator, ValidatorParams):
else:
bestIndex = np.argmin(metrics)
bestModel = est.fit(dataset, epm[bestIndex])
- return self._copyValues(TrainValidationSplitModel(bestModel))
+ return self._copyValues(TrainValidationSplitModel(bestModel, metrics))
@since("2.0.0")
def copy(self, extra=None):
@@ -424,10 +424,12 @@ class TrainValidationSplitModel(Model, ValidatorParams):
.. versionadded:: 2.0.0
"""
- def __init__(self, bestModel):
+ def __init__(self, bestModel, validationMetrics=[]):
super(TrainValidationSplitModel, self).__init__()
#: best model from cross validation
self.bestModel = bestModel
+ #: evaluated validation metrics
+ self.validationMetrics = validationMetrics
def _transform(self, dataset):
return self.bestModel.transform(dataset)
@@ -439,13 +441,16 @@ class TrainValidationSplitModel(Model, ValidatorParams):
and some extra params. This copies the underlying bestModel,
creates a deep copy of the embedded paramMap, and
copies the embedded and extra parameters over.
+ And, this creates a shallow copy of the validationMetrics.
:param extra: Extra parameters to copy to the new instance
:return: Copy of this instance
"""
if extra is None:
extra = dict()
- return TrainValidationSplitModel(self.bestModel.copy(extra))
+ bestModel = self.bestModel.copy(extra)
+ validationMetrics = list(self.validationMetrics)
+ return TrainValidationSplitModel(bestModel, validationMetrics)
if __name__ == "__main__":