aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author: Xiangrui Meng <meng@databricks.com> 2015-05-06 01:28:43 -0700
committer: Xiangrui Meng <meng@databricks.com> 2015-05-06 01:28:43 -0700
commit: 32cdc815c6fc19b5c8c4eca35f88a61302d67cd5 (patch)
tree: 0ed046021a78b1f3682b13ca918d65cd08fed9b9 /mllib
parent: 9f019c7223bb79b8d5cd52980b2723a1601d1134 (diff)
download: spark-32cdc815c6fc19b5c8c4eca35f88a61302d67cd5.tar.gz
spark-32cdc815c6fc19b5c8c4eca35f88a61302d67cd5.tar.bz2
spark-32cdc815c6fc19b5c8c4eca35f88a61302d67cd5.zip
[SPARK-6940] [MLLIB] Add CrossValidator to Python ML pipeline API
Since CrossValidator is a meta algorithm, we copy the implementation in Python. jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #5926 from mengxr/SPARK-6940 and squashes the following commits:

6af181f [Xiangrui Meng] add TODOs
8285134 [Xiangrui Meng] update doc
060f7c3 [Xiangrui Meng] update doctest
acac727 [Xiangrui Meng] add keyword args
cdddecd [Xiangrui Meng] add CrossValidator in Python
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala 7
1 files changed, 5 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index cee2aa6e85..9208127eb1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -52,10 +52,12 @@ private[ml] trait CrossValidatorParams extends Params {
def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)
/**
- * param for the evaluator for selection
+ * param for the evaluator used to select hyper-parameters that maximize the cross-validated
+ * metric
* @group param
*/
- val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection")
+ val evaluator: Param[Evaluator] = new Param(this, "evaluator",
+ "evaluator used to select hyper-parameters that maximize the cross-validated metric")
/** @group getParam */
def getEvaluator: Evaluator = $(evaluator)
@@ -120,6 +122,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
trainingDataset.unpersist()
var i = 0
while (i < numModels) {
+ // TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
metrics(i) += metric