aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/java
diff options
context:
space:
mode:
authorZheng RuiFeng <ruifengz@foxmail.com>2016-05-11 09:53:36 +0200
committerNick Pentreath <nickp@za.ibm.com>2016-05-11 09:53:36 +0200
commitad1a8466e9c10fbe8b455dba17b16973f92ebc15 (patch)
tree7da9c4df44d1774c2834b9d11cbc4f55aa3c8309 /examples/src/main/java
parent875ef764280428acd095aec1834fee0ddad08611 (diff)
downloadspark-ad1a8466e9c10fbe8b455dba17b16973f92ebc15.tar.gz
spark-ad1a8466e9c10fbe8b455dba17b16973f92ebc15.tar.bz2
spark-ad1a8466e9c10fbe8b455dba17b16973f92ebc15.zip
[SPARK-15141][EXAMPLE][DOC] Update OneVsRest Examples
## What changes were proposed in this pull request? 1. Add a Python example for OneVsRest. 2. Remove args-parsing. ## How was this patch tested? Manual tests: `./bin/spark-submit examples/src/main/python/ml/one_vs_rest_example.py` Author: Zheng RuiFeng <ruifengz@foxmail.com> Closes #12920 from zhengruifeng/ovr_pe.
Diffstat (limited to 'examples/src/main/java')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java214
1 file changed, 30 insertions, 184 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
index e0cb752224..5bf455ebfe 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
@@ -17,222 +17,68 @@
package org.apache.spark.examples.ml;
-import org.apache.commons.cli.*;
-
// $example on$
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.OneVsRest;
import org.apache.spark.ml.classification.OneVsRestModel;
-import org.apache.spark.ml.util.MetadataUtils;
-import org.apache.spark.mllib.evaluation.MulticlassMetrics;
-import org.apache.spark.mllib.linalg.Matrix;
-import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-import org.apache.spark.sql.types.StructField;
// $example off$
+import org.apache.spark.sql.SparkSession;
+
/**
- * An example runner for Multiclass to Binary Reduction with One Vs Rest.
- * The example uses Logistic Regression as the base classifier. All parameters that
- * can be specified on the base classifier can be passed in to the runner options.
+ * An example of Multiclass to Binary Reduction with One Vs Rest,
+ * using Logistic Regression as the base classifier.
* Run with
* <pre>
- * bin/run-example ml.JavaOneVsRestExample [options]
+ * bin/run-example ml.JavaOneVsRestExample
* </pre>
*/
public class JavaOneVsRestExample {
-
- private static class Params {
- String input;
- String testInput = null;
- Integer maxIter = 100;
- double tol = 1E-6;
- boolean fitIntercept = true;
- Double regParam = null;
- Double elasticNetParam = null;
- double fracTest = 0.2;
- }
-
public static void main(String[] args) {
- // parse the arguments
- Params params = parse(args);
SparkSession spark = SparkSession
.builder()
.appName("JavaOneVsRestExample")
.getOrCreate();
// $example on$
- // configure the base classifier
- LogisticRegression classifier = new LogisticRegression()
- .setMaxIter(params.maxIter)
- .setTol(params.tol)
- .setFitIntercept(params.fitIntercept);
+ // load data file.
+ Dataset<Row> inputData = spark.read().format("libsvm")
+ .load("data/mllib/sample_multiclass_classification_data.txt");
- if (params.regParam != null) {
- classifier.setRegParam(params.regParam);
- }
- if (params.elasticNetParam != null) {
- classifier.setElasticNetParam(params.elasticNetParam);
- }
+ // generate the train/test split.
+ Dataset<Row>[] tmp = inputData.randomSplit(new double[]{0.8, 0.2});
+ Dataset<Row> train = tmp[0];
+ Dataset<Row> test = tmp[1];
- // instantiate the One Vs Rest Classifier
- OneVsRest ovr = new OneVsRest().setClassifier(classifier);
-
- String input = params.input;
- Dataset<Row> inputData = spark.read().format("libsvm").load(input);
- Dataset<Row> train;
- Dataset<Row> test;
+ // configure the base classifier.
+ LogisticRegression classifier = new LogisticRegression()
+ .setMaxIter(10)
+ .setTol(1E-6)
+ .setFitIntercept(true);
- // compute the train/ test split: if testInput is not provided use part of input
- String testInput = params.testInput;
- if (testInput != null) {
- train = inputData;
- // compute the number of features in the training set.
- int numFeatures = inputData.first().<Vector>getAs(1).size();
- test = spark.read().format("libsvm").option("numFeatures",
- String.valueOf(numFeatures)).load(testInput);
- } else {
- double f = params.fracTest;
- Dataset<Row>[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345);
- train = tmp[0];
- test = tmp[1];
- }
+ // instantiate the One Vs Rest Classifier.
+ OneVsRest ovr = new OneVsRest().setClassifier(classifier);
- // train the multiclass model
- OneVsRestModel ovrModel = ovr.fit(train.cache());
+ // train the multiclass model.
+ OneVsRestModel ovrModel = ovr.fit(train);
- // score the model on test data
- Dataset<Row> predictions = ovrModel.transform(test.cache())
+ // score the model on test data.
+ Dataset<Row> predictions = ovrModel.transform(test)
.select("prediction", "label");
- // obtain metrics
- MulticlassMetrics metrics = new MulticlassMetrics(predictions);
- StructField predictionColSchema = predictions.schema().apply("prediction");
- Integer numClasses = (Integer) MetadataUtils.getNumClasses(predictionColSchema).get();
-
- // compute the false positive rate per label
- StringBuilder results = new StringBuilder();
- results.append("label\tfpr\n");
- for (int label = 0; label < numClasses; label++) {
- results.append(label);
- results.append("\t");
- results.append(metrics.falsePositiveRate((double) label));
- results.append("\n");
- }
+ // obtain evaluator.
+ MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
+ .setMetricName("precision");
- Matrix confusionMatrix = metrics.confusionMatrix();
- // output the Confusion Matrix
- System.out.println("Confusion Matrix");
- System.out.println(confusionMatrix);
- System.out.println();
- System.out.println(results);
+ // compute the classification error on test data.
+ double precision = evaluator.evaluate(predictions);
+ System.out.println("Test Error : " + (1 - precision));
// $example off$
spark.stop();
}
- private static Params parse(String[] args) {
- Options options = generateCommandlineOptions();
- CommandLineParser parser = new PosixParser();
- Params params = new Params();
-
- try {
- CommandLine cmd = parser.parse(options, args);
- String value;
- if (cmd.hasOption("input")) {
- params.input = cmd.getOptionValue("input");
- }
- if (cmd.hasOption("maxIter")) {
- value = cmd.getOptionValue("maxIter");
- params.maxIter = Integer.parseInt(value);
- }
- if (cmd.hasOption("tol")) {
- value = cmd.getOptionValue("tol");
- params.tol = Double.parseDouble(value);
- }
- if (cmd.hasOption("fitIntercept")) {
- value = cmd.getOptionValue("fitIntercept");
- params.fitIntercept = Boolean.parseBoolean(value);
- }
- if (cmd.hasOption("regParam")) {
- value = cmd.getOptionValue("regParam");
- params.regParam = Double.parseDouble(value);
- }
- if (cmd.hasOption("elasticNetParam")) {
- value = cmd.getOptionValue("elasticNetParam");
- params.elasticNetParam = Double.parseDouble(value);
- }
- if (cmd.hasOption("testInput")) {
- value = cmd.getOptionValue("testInput");
- params.testInput = value;
- }
- if (cmd.hasOption("fracTest")) {
- value = cmd.getOptionValue("fracTest");
- params.fracTest = Double.parseDouble(value);
- }
-
- } catch (ParseException e) {
- printHelpAndQuit(options);
- }
- return params;
- }
-
- @SuppressWarnings("static")
- private static Options generateCommandlineOptions() {
- Option input = OptionBuilder.withArgName("input")
- .hasArg()
- .isRequired()
- .withDescription("input path to labeled examples. This path must be specified")
- .create("input");
- Option testInput = OptionBuilder.withArgName("testInput")
- .hasArg()
- .withDescription("input path to test examples")
- .create("testInput");
- Option fracTest = OptionBuilder.withArgName("testInput")
- .hasArg()
- .withDescription("fraction of data to hold out for testing." +
- " If given option testInput, this option is ignored. default: 0.2")
- .create("fracTest");
- Option maxIter = OptionBuilder.withArgName("maxIter")
- .hasArg()
- .withDescription("maximum number of iterations for Logistic Regression. default:100")
- .create("maxIter");
- Option tol = OptionBuilder.withArgName("tol")
- .hasArg()
- .withDescription("the convergence tolerance of iterations " +
- "for Logistic Regression. default: 1E-6")
- .create("tol");
- Option fitIntercept = OptionBuilder.withArgName("fitIntercept")
- .hasArg()
- .withDescription("fit intercept for logistic regression. default true")
- .create("fitIntercept");
- Option regParam = OptionBuilder.withArgName( "regParam" )
- .hasArg()
- .withDescription("the regularization parameter for Logistic Regression.")
- .create("regParam");
- Option elasticNetParam = OptionBuilder.withArgName("elasticNetParam" )
- .hasArg()
- .withDescription("the ElasticNet mixing parameter for Logistic Regression.")
- .create("elasticNetParam");
-
- Options options = new Options()
- .addOption(input)
- .addOption(testInput)
- .addOption(fracTest)
- .addOption(maxIter)
- .addOption(tol)
- .addOption(fitIntercept)
- .addOption(regParam)
- .addOption(elasticNetParam);
-
- return options;
- }
-
- private static void printHelpAndQuit(Options options) {
- HelpFormatter formatter = new HelpFormatter();
- formatter.printHelp("JavaOneVsRestExample", options);
- System.exit(-1);
- }
}