aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tuning.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/tuning.py')
-rw-r--r--python/pyspark/ml/tuning.py13
1 files changed, 11 insertions, 2 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 298314d46c..7f967e5463 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -143,7 +143,13 @@ class ValidatorParams(HasSeed):
class CrossValidator(Estimator, ValidatorParams):
"""
- K-fold cross validation.
+
K-fold cross validation performs model selection by splitting the dataset into a set of
non-overlapping, randomly partitioned folds that are used as separate training and test
datasets. For example, with k=3 folds, K-fold cross validation generates 3 (training, test)
dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. Each fold
is used as the test set exactly once.
+
>>> from pyspark.ml.classification import LogisticRegression
>>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
@@ -260,7 +266,10 @@ class CrossValidator(Estimator, ValidatorParams):
class CrossValidatorModel(Model, ValidatorParams):
"""
- Model from k-fold cross validation.
+
+ CrossValidatorModel contains the model with the highest average cross-validation
+ metric across folds and uses this model to transform input data. CrossValidatorModel
+ also tracks the metrics for each param map evaluated.
.. versionadded:: 1.4.0
"""