aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/pyspark/ml/tuning.py16
1 files changed, 15 insertions, 1 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index b21cf92559..0920ae6ea1 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -33,6 +33,8 @@ __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainVa
class ParamGridBuilder(object):
r"""
+ .. note:: Experimental
+
Builder for a param grid used in grid search-based model selection.
>>> from pyspark.ml.classification import LogisticRegression
@@ -143,6 +145,8 @@ class ValidatorParams(HasSeed):
class CrossValidator(Estimator, ValidatorParams):
"""
+ .. note:: Experimental
+
K-fold cross validation.
>>> from pyspark.ml.classification import LogisticRegression
@@ -260,6 +264,8 @@ class CrossValidator(Estimator, ValidatorParams):
class CrossValidatorModel(Model, ValidatorParams):
"""
+ .. note:: Experimental
+
Model from k-fold cross validation.
.. versionadded:: 1.4.0
@@ -269,6 +275,8 @@ class CrossValidatorModel(Model, ValidatorParams):
super(CrossValidatorModel, self).__init__()
#: best model from cross validation
self.bestModel = bestModel
+ #: Average cross-validation metrics for each paramMap in
+ #: CrossValidator.estimatorParamMaps, in the corresponding order.
self.avgMetrics = avgMetrics
def _transform(self, dataset):
@@ -294,7 +302,11 @@ class CrossValidatorModel(Model, ValidatorParams):
class TrainValidationSplit(Estimator, ValidatorParams):
"""
- Train-Validation-Split.
+ .. note:: Experimental
+
+ Validation for hyper-parameter tuning. Randomly splits the input dataset into train and
+ validation sets, and uses evaluation metric on the validation set to select the best model.
+ Similar to :class:`CrossValidator`, but only splits the set once.
>>> from pyspark.ml.classification import LogisticRegression
>>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
@@ -405,6 +417,8 @@ class TrainValidationSplit(Estimator, ValidatorParams):
class TrainValidationSplitModel(Model, ValidatorParams):
"""
+ .. note:: Experimental
+
Model from train validation split.
.. versionadded:: 2.0.0