aboutsummaryrefslogtreecommitdiff
path: root/docs/ml-guide.md
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-08-27 21:44:06 -0700
committerXiangrui Meng <meng@databricks.com>2015-08-27 21:44:06 -0700
commit30734d45fbbb269437c062241a9161e198805a76 (patch)
treef6ae8a50ab46b77e1a2bbde3af0989c1550a7737 /docs/ml-guide.md
parent1f90c5e2198bcf49e115d97ec300c17c1be4dcb4 (diff)
downloadspark-30734d45fbbb269437c062241a9161e198805a76.tar.gz
spark-30734d45fbbb269437c062241a9161e198805a76.tar.bz2
spark-30734d45fbbb269437c062241a9161e198805a76.zip
[SPARK-9911] [DOC] [ML] Update Userguide for Evaluator
I added a small note about the different types of evaluator and the metrics used. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #8304 from MechCoder/multiclass_evaluator.
Diffstat (limited to 'docs/ml-guide.md')
-rw-r--r--docs/ml-guide.md13
1 files changed, 13 insertions, 0 deletions
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index de8fead352..01bf5ee18e 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -643,6 +643,13 @@ An important task in ML is *model selection*, or using data to find the best mod
Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.Evaluator).
`CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets; e.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing.
`CrossValidator` iterates through the set of `ParamMap`s. For each `ParamMap`, it trains the given `Estimator` and evaluates it using the given `Evaluator`.
+
+The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.RegressionEvaluator)
+for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.BinaryClassificationEvaluator)
+for binary data or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.MultiClassClassificationEvaluator)
+for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the setMetric
+method in each of these evaluators.
+
The `ParamMap` which produces the best evaluation metric (averaged over the `$k$` folds) is selected as the best model.
`CrossValidator` finally fits the `Estimator` using the best `ParamMap` and the entire dataset.
@@ -708,9 +715,12 @@ val pipeline = new Pipeline()
// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
// This will allow us to jointly choose parameters for all Pipeline stages.
// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
+// Note that the evaluator here is a BinaryClassificationEvaluator and the default metric
+// used is areaUnderROC.
val crossval = new CrossValidator()
.setEstimator(pipeline)
.setEvaluator(new BinaryClassificationEvaluator)
+
// We use a ParamGridBuilder to construct a grid of parameters to search over.
// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
@@ -831,9 +841,12 @@ Pipeline pipeline = new Pipeline()
// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
// This will allow us to jointly choose parameters for all Pipeline stages.
// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
+// Note that the evaluator here is a BinaryClassificationEvaluator and the default metric
+// used is areaUnderROC.
CrossValidator crossval = new CrossValidator()
.setEstimator(pipeline)
.setEvaluator(new BinaryClassificationEvaluator());
+
// We use a ParamGridBuilder to construct a grid of parameters to search over.
// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.