aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkrishnakalyan3 <krishnakalyan3@gmail.com>2016-07-27 15:37:38 +0200
committerNick Pentreath <nickp@za.ibm.com>2016-07-27 15:37:38 +0200
commit7e8279fde176b08687adf2b410693b35cfbd4b46 (patch)
tree9342eb9868f5f5a1ed7358a8883605c31a15d24d
parent045fc3606698b017a4addf5277808883e6fe76b6 (diff)
downloadspark-7e8279fde176b08687adf2b410693b35cfbd4b46.tar.gz
spark-7e8279fde176b08687adf2b410693b35cfbd4b46.tar.bz2
spark-7e8279fde176b08687adf2b410693b35cfbd4b46.zip
[SPARK-15254][DOC] Improve ML pipeline Cross Validation Scaladoc & PyDoc
## What changes were proposed in this pull request? Updated ML pipeline Cross Validation Scaladoc & PyDoc. ## How was this patch tested? Documentation update (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Author: krishnakalyan3 <krishnakalyan3@gmail.com> Closes #13894 from krishnakalyan3/kfold-cv.
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala10
-rw-r--r--python/pyspark/ml/tuning.py13
2 files changed, 19 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 520557849b..6ea52ef7f0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -55,7 +55,11 @@ private[ml] trait CrossValidatorParams extends ValidatorParams {
}
/**
- * K-fold cross validation.
+ * K-fold cross validation performs model selection by splitting the dataset into a set of
+ * non-overlapping randomly partitioned folds which are used as separate training and test datasets
+ * e.g., with k=3 folds, K-fold cross validation will generate 3 (training, test) dataset pairs,
+ * each of which uses 2/3 of the data for training and 1/3 for testing. Each fold is used as the
+ * test set exactly once.
*/
@Since("1.2.0")
class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
@@ -188,7 +192,9 @@ object CrossValidator extends MLReadable[CrossValidator] {
}
/**
- * Model from k-fold cross validation.
+ * CrossValidatorModel contains the model with the highest average cross-validation
+ * metric across folds and uses this model to transform input data. CrossValidatorModel
+ * also tracks the metrics for each param map evaluated.
*
* @param bestModel The best model selected from k-fold cross validation.
* @param avgMetrics Average cross-validation metrics for each paramMap in
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 298314d46c..7f967e5463 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -143,7 +143,13 @@ class ValidatorParams(HasSeed):
class CrossValidator(Estimator, ValidatorParams):
"""
- K-fold cross validation.
+
+ K-fold cross validation performs model selection by splitting the dataset into a set of
+ non-overlapping randomly partitioned folds which are used as separate training and test datasets
+ e.g., with k=3 folds, K-fold cross validation will generate 3 (training, test) dataset pairs,
+ each of which uses 2/3 of the data for training and 1/3 for testing. Each fold is used as the
+ test set exactly once.
+
>>> from pyspark.ml.classification import LogisticRegression
>>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
@@ -260,7 +266,10 @@ class CrossValidator(Estimator, ValidatorParams):
class CrossValidatorModel(Model, ValidatorParams):
"""
- Model from k-fold cross validation.
+
+ CrossValidatorModel contains the model with the highest average cross-validation
+ metric across folds and uses this model to transform input data. CrossValidatorModel
+ also tracks the metrics for each param map evaluated.
.. versionadded:: 1.4.0
"""