aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
author martinzapletal <zapletal-martin@email.cz> 2015-08-28 21:03:48 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-08-28 21:03:48 -0700
commit e8ea5bafee9ca734edf62021145d0c2d5491cba8 (patch)
tree 38f870b1a68b35b68357a40ee3dbf72b1f89aab7 /examples
parent 2a4e00ca4d4e7a148b4ff8ce0ad1c6d517cee55f (diff)
download spark-e8ea5bafee9ca734edf62021145d0c2d5491cba8.tar.gz
spark-e8ea5bafee9ca734edf62021145d0c2d5491cba8.tar.bz2
spark-e8ea5bafee9ca734edf62021145d0c2d5491cba8.zip
[SPARK-9910] [ML] User guide for train validation split
Author: martinzapletal <zapletal-martin@email.cz> Closes #8377 from zapletal-martin/SPARK-9910.
Diffstat (limited to 'examples')
-rw-r--r-- examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java | 90
-rw-r--r-- examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala | 80
2 files changed, 170 insertions, 0 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
new file mode 100644
index 0000000000..23f834ab43
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.evaluation.RegressionEvaluator;
+import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.ml.regression.LinearRegression;
+import org.apache.spark.ml.tuning.*;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+/**
+ * A simple example demonstrating model selection using TrainValidationSplit.
+ *
+ * The example is based on {@link org.apache.spark.examples.ml.JavaSimpleParamsExample}
+ * using linear regression.
+ *
+ * Run with
+ * <pre>
+ * bin/run-example ml.JavaTrainValidationSplitExample
+ * </pre>
+ */
+public class JavaTrainValidationSplitExample {
+
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("JavaTrainValidationSplitExample");
+ JavaSparkContext jsc = new JavaSparkContext(conf);
+ SQLContext jsql = new SQLContext(jsc);
+
+ DataFrame data = jsql.createDataFrame(
+ MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"),
+ LabeledPoint.class);
+
+ // Prepare training and test data: a 90%/10% random split with a fixed seed (12345).
+ DataFrame[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345);
+ DataFrame training = splits[0];
+ DataFrame test = splits[1];
+
+ LinearRegression lr = new LinearRegression();
+
+ // We use a ParamGridBuilder to construct a grid of parameters to search over.
+ // TrainValidationSplit will try all combinations of values and determine the best model
+ // using the evaluator.
+ ParamMap[] paramGrid = new ParamGridBuilder()
+ .addGrid(lr.regParam(), new double[] {0.1, 0.01})
+ .addGrid(lr.fitIntercept())
+ .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0})
+ .build();
+
+ // In this case the estimator is simply the linear regression.
+ // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
+ TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
+ .setEstimator(lr)
+ .setEvaluator(new RegressionEvaluator())
+ .setEstimatorParamMaps(paramGrid);
+
+ // 80% of the data will be used for training and the remaining 20% for validation.
+ trainValidationSplit.setTrainRatio(0.8);
+
+ // Run the train-validation split on the training data and choose the best set of parameters.
+ TrainValidationSplitModel model = trainValidationSplit.fit(training);
+
+ // Make predictions on the test data. model is the model with the combination of parameters
+ // that performed best.
+ model.transform(test)
+ .select("features", "label", "prediction")
+ .show();
+
+ jsc.stop();
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala
new file mode 100644
index 0000000000..1abdf219b1
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import org.apache.spark.ml.evaluation.RegressionEvaluator
+import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.{SparkConf, SparkContext}
+
+/**
+ * A simple example demonstrating model selection using TrainValidationSplit.
+ * The example is based on [[SimpleParamsExample]] using linear regression.
+ *
+ * Run with
+ * {{{
+ * bin/run-example ml.TrainValidationSplitExample
+ * }}}
+ */
+object TrainValidationSplitExample {
+
+ def main(args: Array[String]): Unit = {
+ val conf = new SparkConf().setAppName("TrainValidationSplitExample")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ // Prepare training and test data: a 90%/10% random split with a fixed seed (12345).
+ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
+ val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345)
+
+ val lr = new LinearRegression()
+
+ // We use a ParamGridBuilder to construct a grid of parameters to search over.
+ // TrainValidationSplit will try all combinations of values and determine the best model
+ // using the evaluator.
+ val paramGrid = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.1, 0.01))
+ .addGrid(lr.fitIntercept, Array(true, false))
+ .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
+ .build()
+
+ // In this case the estimator is simply the linear regression.
+ // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
+ val trainValidationSplit = new TrainValidationSplit()
+ .setEstimator(lr)
+ .setEvaluator(new RegressionEvaluator)
+ .setEstimatorParamMaps(paramGrid)
+
+ // 80% of the data will be used for training and the remaining 20% for validation.
+ trainValidationSplit.setTrainRatio(0.8)
+
+ // Run the train-validation split on the training data and choose the best set of parameters.
+ val model = trainValidationSplit.fit(training)
+
+ // Make predictions on the test data. model is the model with the combination of parameters
+ // that performed best.
+ model.transform(test)
+ .select("features", "label", "prediction")
+ .show()
+
+ sc.stop()
+ }
+}