aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
diff options
context:
space:
mode:
Diffstat (limited to 'examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java87
1 files changed, 0 insertions, 87 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
deleted file mode 100644
index 09bbc39c01..0000000000
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples.ml;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.evaluation.RegressionEvaluator;
-import org.apache.spark.ml.param.ParamMap;
-import org.apache.spark.ml.regression.LinearRegression;
-import org.apache.spark.ml.tuning.*;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
-
-/**
- * A simple example demonstrating model selection using TrainValidationSplit.
- *
- * The example is based on {@link org.apache.spark.examples.ml.JavaSimpleParamsExample}
- * using linear regression.
- *
- * Run with
- * {{{
- * bin/run-example ml.JavaTrainValidationSplitExample
- * }}}
- */
-public class JavaTrainValidationSplitExample {
-
- public static void main(String[] args) {
- SparkConf conf = new SparkConf().setAppName("JavaTrainValidationSplitExample");
- JavaSparkContext jsc = new JavaSparkContext(conf);
- SQLContext jsql = new SQLContext(jsc);
-
- Dataset<Row> data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
-
- // Prepare training and test data.
- Dataset<Row>[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345);
- Dataset<Row> training = splits[0];
- Dataset<Row> test = splits[1];
-
- LinearRegression lr = new LinearRegression();
-
- // We use a ParamGridBuilder to construct a grid of parameters to search over.
- // TrainValidationSplit will try all combinations of values and determine best model using
- // the evaluator.
- ParamMap[] paramGrid = new ParamGridBuilder()
- .addGrid(lr.regParam(), new double[] {0.1, 0.01})
- .addGrid(lr.fitIntercept())
- .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0})
- .build();
-
- // In this case the estimator is simply the linear regression.
- // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
- TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
- .setEstimator(lr)
- .setEvaluator(new RegressionEvaluator())
- .setEstimatorParamMaps(paramGrid);
-
- // 80% of the data will be used for training and the remaining 20% for validation.
- trainValidationSplit.setTrainRatio(0.8);
-
- // Run train validation split, and choose the best set of parameters.
- TrainValidationSplitModel model = trainValidationSplit.fit(training);
-
- // Make predictions on test data. model is the model with combination of parameters
- // that performed best.
- model.transform(test)
- .select("features", "label", "prediction")
- .show();
-
- jsc.stop();
- }
-}