diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-05-06 01:28:43 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-05-06 01:28:43 -0700 |
commit | 32cdc815c6fc19b5c8c4eca35f88a61302d67cd5 (patch) | |
tree | 0ed046021a78b1f3682b13ca918d65cd08fed9b9 /mllib | |
parent | 9f019c7223bb79b8d5cd52980b2723a1601d1134 (diff) | |
download | spark-32cdc815c6fc19b5c8c4eca35f88a61302d67cd5.tar.gz spark-32cdc815c6fc19b5c8c4eca35f88a61302d67cd5.tar.bz2 spark-32cdc815c6fc19b5c8c4eca35f88a61302d67cd5.zip |
[SPARK-6940] [MLLIB] Add CrossValidator to Python ML pipeline API
Since CrossValidator is a meta algorithm, we copy the implementation in Python. jkbradley
Author: Xiangrui Meng <meng@databricks.com>
Closes #5926 from mengxr/SPARK-6940 and squashes the following commits:
6af181f [Xiangrui Meng] add TODOs
8285134 [Xiangrui Meng] update doc
060f7c3 [Xiangrui Meng] update doctest
acac727 [Xiangrui Meng] add keyword args
cdddecd [Xiangrui Meng] add CrossValidator in Python
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index cee2aa6e85..9208127eb1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -52,10 +52,12 @@ private[ml] trait CrossValidatorParams extends Params { def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps) /** - * param for the evaluator for selection + * param for the evaluator used to select hyper-parameters that maximize the cross-validated + * metric * @group param */ - val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection") + val evaluator: Param[Evaluator] = new Param(this, "evaluator", + "evaluator used to select hyper-parameters that maximize the cross-validated metric") /** @group getParam */ def getEvaluator: Evaluator = $(evaluator) @@ -120,6 +122,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP trainingDataset.unpersist() var i = 0 while (i < numModels) { + // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(models(i).transform(validationDataset, epm(i))) logDebug(s"Got metric $metric for model trained with ${epm(i)}.") metrics(i) += metric |