aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormartinzapletal <zapletal-martin@email.cz>2015-08-28 21:03:48 -0700
committerXiangrui Meng <meng@databricks.com>2015-08-28 21:03:59 -0700
commit69d856527d50d01624feaf1461af2d7bff03a668 (patch)
treee1500e4a69788e44eb388a888bb46c12e519ac39
parentb7aab1d1838bdffdf29923fc0f18eb04e582957e (diff)
downloadspark-69d856527d50d01624feaf1461af2d7bff03a668.tar.gz
spark-69d856527d50d01624feaf1461af2d7bff03a668.tar.bz2
spark-69d856527d50d01624feaf1461af2d7bff03a668.zip
[SPARK-9910] [ML] User guide for train validation split
Author: martinzapletal <zapletal-martin@email.cz> Closes #8377 from zapletal-martin/SPARK-9910. (cherry picked from commit e8ea5bafee9ca734edf62021145d0c2d5491cba8) Signed-off-by: Xiangrui Meng <meng@databricks.com>
-rw-r--r--docs/ml-guide.md117
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java90
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala80
3 files changed, 287 insertions, 0 deletions
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index ce53400b6e..a92a285f3a 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -872,3 +872,120 @@ jsc.stop();
</div>
</div>
+
+## Example: Model Selection via Train Validation Split
+In addition to `CrossValidator` Spark also offers `TrainValidationSplit` for hyper-parameter tuning.
+`TrainValidationSplit` only evaluates each combination of parameters once as opposed to k times in
+ case of `CrossValidator`. It is therefore less expensive,
+ but will not produce as reliable results when the training dataset is not sufficiently large..
+
+`TrainValidationSplit` takes an `Estimator`, a set of `ParamMap`s provided in the `estimatorParamMaps` parameter,
+and an `Evaluator`.
+It begins by splitting the dataset into two parts using `trainRatio` parameter
+which are used as separate training and test datasets. For example with `$trainRatio=0.75$` (default),
+`TrainValidationSplit` will generate a training and test dataset pair where 75% of the data is used for training and 25% for validation.
+Similar to `CrossValidator`, `TrainValidationSplit` also iterates through the set of `ParamMap`s.
+For each combination of parameters, it trains the given `Estimator` and evaluates it using the given `Evaluator`.
+The `ParamMap` which produces the best evaluation metric is selected as the best option.
+`TrainValidationSplit` finally fits the `Estimator` using the best `ParamMap` and the entire dataset.
+
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+{% highlight scala %}
+import org.apache.spark.ml.evaluation.RegressionEvaluator
+import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
+import org.apache.spark.mllib.util.MLUtils
+
+// Prepare training and test data.
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
+val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345)
+
+val lr = new LinearRegression()
+
+// We use a ParamGridBuilder to construct a grid of parameters to search over.
+// TrainValidationSplit will try all combinations of values and determine best model using
+// the evaluator.
+val paramGrid = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.1, 0.01))
+ .addGrid(lr.fitIntercept, Array(true, false))
+ .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
+ .build()
+
+// In this case the estimator is simply the linear regression.
+// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
+val trainValidationSplit = new TrainValidationSplit()
+ .setEstimator(lr)
+ .setEvaluator(new RegressionEvaluator)
+ .setEstimatorParamMaps(paramGrid)
+
+// 80% of the data will be used for training and the remaining 20% for validation.
+trainValidationSplit.setTrainRatio(0.8)
+
+// Run train validation split, and choose the best set of parameters.
+val model = trainValidationSplit.fit(training)
+
+// Make predictions on test data. model is the model with combination of parameters
+// that performed best.
+model.transform(test)
+ .select("features", "label", "prediction")
+ .show()
+
+{% endhighlight %}
+</div>
+
+<div data-lang="java" markdown="1">
+{% highlight java %}
+import org.apache.spark.ml.evaluation.RegressionEvaluator;
+import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.ml.regression.LinearRegression;
+import org.apache.spark.ml.tuning.*;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.sql.DataFrame;
+
+DataFrame data = jsql.createDataFrame(
+ MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"),
+ LabeledPoint.class);
+
+// Prepare training and test data.
+DataFrame[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345);
+DataFrame training = splits[0];
+DataFrame test = splits[1];
+
+LinearRegression lr = new LinearRegression();
+
+// We use a ParamGridBuilder to construct a grid of parameters to search over.
+// TrainValidationSplit will try all combinations of values and determine best model using
+// the evaluator.
+ParamMap[] paramGrid = new ParamGridBuilder()
+ .addGrid(lr.regParam(), new double[] {0.1, 0.01})
+ .addGrid(lr.fitIntercept())
+ .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0})
+ .build();
+
+// In this case the estimator is simply the linear regression.
+// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
+TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
+ .setEstimator(lr)
+ .setEvaluator(new RegressionEvaluator())
+ .setEstimatorParamMaps(paramGrid);
+
+// 80% of the data will be used for training and the remaining 20% for validation.
+trainValidationSplit.setTrainRatio(0.8);
+
+// Run train validation split, and choose the best set of parameters.
+TrainValidationSplitModel model = trainValidationSplit.fit(training);
+
+// Make predictions on test data. model is the model with combination of parameters
+// that performed best.
+model.transform(test)
+ .select("features", "label", "prediction")
+ .show();
+
+{% endhighlight %}
+</div>
+
+</div>
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
new file mode 100644
index 0000000000..23f834ab43
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.evaluation.RegressionEvaluator;
+import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.ml.regression.LinearRegression;
+import org.apache.spark.ml.tuning.*;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+/**
+ * A simple example demonstrating model selection using TrainValidationSplit.
+ *
+ * The example is based on {@link org.apache.spark.examples.ml.JavaSimpleParamsExample}
+ * using linear regression.
+ *
+ * Run with
+ * {{{
+ * bin/run-example ml.JavaTrainValidationSplitExample
+ * }}}
+ */
+public class JavaTrainValidationSplitExample {
+
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("JavaTrainValidationSplitExample");
+ JavaSparkContext jsc = new JavaSparkContext(conf);
+ SQLContext jsql = new SQLContext(jsc);
+
+ DataFrame data = jsql.createDataFrame(
+ MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"),
+ LabeledPoint.class);
+
+ // Prepare training and test data.
+ DataFrame[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345);
+ DataFrame training = splits[0];
+ DataFrame test = splits[1];
+
+ LinearRegression lr = new LinearRegression();
+
+ // We use a ParamGridBuilder to construct a grid of parameters to search over.
+ // TrainValidationSplit will try all combinations of values and determine best model using
+ // the evaluator.
+ ParamMap[] paramGrid = new ParamGridBuilder()
+ .addGrid(lr.regParam(), new double[] {0.1, 0.01})
+ .addGrid(lr.fitIntercept())
+ .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0})
+ .build();
+
+ // In this case the estimator is simply the linear regression.
+ // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
+ TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
+ .setEstimator(lr)
+ .setEvaluator(new RegressionEvaluator())
+ .setEstimatorParamMaps(paramGrid);
+
+ // 80% of the data will be used for training and the remaining 20% for validation.
+ trainValidationSplit.setTrainRatio(0.8);
+
+ // Run train validation split, and choose the best set of parameters.
+ TrainValidationSplitModel model = trainValidationSplit.fit(training);
+
+ // Make predictions on test data. model is the model with combination of parameters
+ // that performed best.
+ model.transform(test)
+ .select("features", "label", "prediction")
+ .show();
+
+ jsc.stop();
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala
new file mode 100644
index 0000000000..1abdf219b1
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import org.apache.spark.ml.evaluation.RegressionEvaluator
+import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.{SparkConf, SparkContext}
+
+/**
+ * A simple example demonstrating model selection using TrainValidationSplit.
+ *
+ * The example is based on [[SimpleParamsExample]] using linear regression.
+ * Run with
+ * {{{
+ * bin/run-example ml.TrainValidationSplitExample
+ * }}}
+ */
+object TrainValidationSplitExample {
+
+ def main(args: Array[String]): Unit = {
+ val conf = new SparkConf().setAppName("TrainValidationSplitExample")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ // Prepare training and test data.
+ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
+ val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345)
+
+ val lr = new LinearRegression()
+
+ // We use a ParamGridBuilder to construct a grid of parameters to search over.
+ // TrainValidationSplit will try all combinations of values and determine best model using
+ // the evaluator.
+ val paramGrid = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.1, 0.01))
+ .addGrid(lr.fitIntercept, Array(true, false))
+ .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
+ .build()
+
+ // In this case the estimator is simply the linear regression.
+ // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
+ val trainValidationSplit = new TrainValidationSplit()
+ .setEstimator(lr)
+ .setEvaluator(new RegressionEvaluator)
+ .setEstimatorParamMaps(paramGrid)
+
+ // 80% of the data will be used for training and the remaining 20% for validation.
+ trainValidationSplit.setTrainRatio(0.8)
+
+ // Run train validation split, and choose the best set of parameters.
+ val model = trainValidationSplit.fit(training)
+
+ // Make predictions on test data. model is the model with combination of parameters
+ // that performed best.
+ model.transform(test)
+ .select("features", "label", "prediction")
+ .show()
+
+ sc.stop()
+ }
+}