From 32cdc815c6fc19b5c8c4eca35f88a61302d67cd5 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 6 May 2015 01:28:43 -0700 Subject: [SPARK-6940] [MLLIB] Add CrossValidator to Python ML pipeline API Since CrossValidator is a meta algorithm, we copy the implementation in Python. jkbradley Author: Xiangrui Meng Closes #5926 from mengxr/SPARK-6940 and squashes the following commits: 6af181f [Xiangrui Meng] add TODOs 8285134 [Xiangrui Meng] update doc 060f7c3 [Xiangrui Meng] update doctest acac727 [Xiangrui Meng] add keyword args cdddecd [Xiangrui Meng] add CrossValidator in Python --- .../src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index cee2aa6e85..9208127eb1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -52,10 +52,12 @@ private[ml] trait CrossValidatorParams extends Params { def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps) /** - * param for the evaluator for selection + * param for the evaluator used to select hyper-parameters that maximize the cross-validated + * metric * @group param */ - val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection") + val evaluator: Param[Evaluator] = new Param(this, "evaluator", + "evaluator used to select hyper-parameters that maximize the cross-validated metric") /** @group getParam */ def getEvaluator: Evaluator = $(evaluator) @@ -120,6 +122,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP trainingDataset.unpersist() var i = 0 while (i < numModels) { + // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(models(i).transform(validationDataset, epm(i))) logDebug(s"Got metric $metric for model trained with ${epm(i)}.") metrics(i) += metric -- cgit v1.2.3