authormartinzapletal <zapletal-martin@email.cz>2015-08-28 21:03:48 -0700
committerXiangrui Meng <meng@databricks.com>2015-08-28 21:03:48 -0700
commite8ea5bafee9ca734edf62021145d0c2d5491cba8 (patch)
tree38f870b1a68b35b68357a40ee3dbf72b1f89aab7 /docs/ml-guide.md
parent2a4e00ca4d4e7a148b4ff8ce0ad1c6d517cee55f (diff)
[SPARK-9910] [ML] User guide for train validation split
Author: martinzapletal <zapletal-martin@email.cz>
Closes #8377 from zapletal-martin/SPARK-9910.
Diffstat (limited to 'docs/ml-guide.md')
-rw-r--r--  docs/ml-guide.md  117
1 file changed, 117 insertions, 0 deletions
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index ce53400b6e..a92a285f3a 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -872,3 +872,120 @@ jsc.stop();
</div>
</div>
+
+## Example: Model Selection via Train Validation Split
+In addition to `CrossValidator`, Spark also offers `TrainValidationSplit` for hyper-parameter tuning.
+`TrainValidationSplit` evaluates each combination of parameters only once, as opposed to k times in
+the case of `CrossValidator`. It is therefore less expensive,
+but will not produce as reliable results when the training dataset is not sufficiently large.
+
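+As a rough illustration of the cost difference, take the parameter grid built in the examples below,
+which contains 2 * 2 * 3 = 12 parameter combinations, and assume 3-fold cross-validation for
+`CrossValidator` (the fold count here is purely illustrative):
+
+`\[
+\text{CrossValidator (3-fold): } 12 \times 3 = 36 \text{ model fits} \qquad
+\text{TrainValidationSplit: } 12 \times 1 = 12 \text{ model fits}
+\]`
+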
+`TrainValidationSplit` takes an `Estimator`, a set of `ParamMap`s provided in the `estimatorParamMaps` parameter,
+and an `Evaluator`.
+It begins by splitting the dataset into two parts according to the `trainRatio` parameter:
+a training dataset and a validation dataset. For example, with `$trainRatio=0.75$` (default),
+`TrainValidationSplit` generates a training and validation dataset pair where 75% of the data is used for training and 25% for validation.
+Similar to `CrossValidator`, `TrainValidationSplit` also iterates through the set of `ParamMap`s.
+For each combination of parameters, it trains the given `Estimator` and evaluates it using the given `Evaluator`.
+The `ParamMap` which produces the best evaluation metric is selected as the best option.
+`TrainValidationSplit` finally fits the `Estimator` using the best `ParamMap` and the entire dataset.
+
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+{% highlight scala %}
+import org.apache.spark.ml.evaluation.RegressionEvaluator
+import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
+import org.apache.spark.mllib.util.MLUtils
+
+// Prepare training and test data.
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
+val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345)
+
+val lr = new LinearRegression()
+
+// We use a ParamGridBuilder to construct a grid of parameters to search over.
+// TrainValidationSplit will try all combinations of values and determine the best model using
+// the evaluator.
+val paramGrid = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.1, 0.01))
+ .addGrid(lr.fitIntercept, Array(true, false))
+ .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
+ .build()
+
+// In this case the estimator is simply the linear regression.
+// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
+val trainValidationSplit = new TrainValidationSplit()
+ .setEstimator(lr)
+ .setEvaluator(new RegressionEvaluator)
+ .setEstimatorParamMaps(paramGrid)
+ // 80% of the data will be used for training and the remaining 20% for validation.
+ .setTrainRatio(0.8)
+
+// Run train validation split, and choose the best set of parameters.
+val model = trainValidationSplit.fit(training)
+
+// Make predictions on test data. model is the model with the combination of parameters
+// that performed best.
+model.transform(test)
+ .select("features", "label", "prediction")
+ .show()
+
+{% endhighlight %}
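+
+The fitted `TrainValidationSplitModel` also exposes the model trained with the winning `ParamMap`
+(`bestModel`) and the evaluation metric obtained for each `ParamMap` (`validationMetrics`), which is
+useful for inspecting what was selected. A minimal sketch, continuing from the code above:
+
+{% highlight scala %}
+// bestModel is the model trained with the best-performing ParamMap;
+// validationMetrics holds one evaluator score per ParamMap in the grid.
+val best = model.bestModel
+println("Selected model parameters:\n" + best.extractParamMap())
+println("Validation metrics: " + model.validationMetrics.mkString(", "))
+{% endhighlight %}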
+</div>
+
+<div data-lang="java" markdown="1">
+{% highlight java %}
+import org.apache.spark.ml.evaluation.RegressionEvaluator;
+import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.ml.regression.LinearRegression;
+import org.apache.spark.ml.tuning.*;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.sql.DataFrame;
+
+// Prepare training and test data.
+DataFrame data = jsql.createDataFrame(
+ MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"),
+ LabeledPoint.class);
+DataFrame[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345);
+DataFrame training = splits[0];
+DataFrame test = splits[1];
+
+LinearRegression lr = new LinearRegression();
+
+// We use a ParamGridBuilder to construct a grid of parameters to search over.
+// TrainValidationSplit will try all combinations of values and determine the best model using
+// the evaluator.
+ParamMap[] paramGrid = new ParamGridBuilder()
+ .addGrid(lr.regParam(), new double[] {0.1, 0.01})
+ .addGrid(lr.fitIntercept())
+ .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0})
+ .build();
+
+// In this case the estimator is simply the linear regression.
+// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
+TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
+ .setEstimator(lr)
+ .setEvaluator(new RegressionEvaluator())
+ .setEstimatorParamMaps(paramGrid)
+ // 80% of the data will be used for training and the remaining 20% for validation.
+ .setTrainRatio(0.8);
+
+// Run train validation split, and choose the best set of parameters.
+TrainValidationSplitModel model = trainValidationSplit.fit(training);
+
+// Make predictions on test data. model is the model with the combination of parameters
+// that performed best.
+model.transform(test)
+ .select("features", "label", "prediction")
+ .show();
+
+{% endhighlight %}
+</div>
+
+</div>