diff options
author | Zheng RuiFeng <ruifengz@foxmail.com> | 2016-05-11 09:53:36 +0200 |
---|---|---|
committer | Nick Pentreath <nickp@za.ibm.com> | 2016-05-11 09:53:36 +0200 |
commit | ad1a8466e9c10fbe8b455dba17b16973f92ebc15 (patch) | |
tree | 7da9c4df44d1774c2834b9d11cbc4f55aa3c8309 /examples/src/main/java | |
parent | 875ef764280428acd095aec1834fee0ddad08611 (diff) | |
download | spark-ad1a8466e9c10fbe8b455dba17b16973f92ebc15.tar.gz spark-ad1a8466e9c10fbe8b455dba17b16973f92ebc15.tar.bz2 spark-ad1a8466e9c10fbe8b455dba17b16973f92ebc15.zip |
[SPARK-15141][EXAMPLE][DOC] Update OneVsRest Examples
## What changes were proposed in this pull request?
1, Add python example for OneVsRest
2, remove args-parsing
## How was this patch tested?
manual tests
`./bin/spark-submit examples/src/main/python/ml/one_vs_rest_example.py`
Author: Zheng RuiFeng <ruifengz@foxmail.com>
Closes #12920 from zhengruifeng/ovr_pe.
Diffstat (limited to 'examples/src/main/java')
-rw-r--r-- | examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java | 214 |
1 files changed, 30 insertions, 184 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java index e0cb752224..5bf455ebfe 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java @@ -17,222 +17,68 @@ package org.apache.spark.examples.ml; -import org.apache.commons.cli.*; - // $example on$ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.OneVsRest; import org.apache.spark.ml.classification.OneVsRestModel; -import org.apache.spark.ml.util.MetadataUtils; -import org.apache.spark.mllib.evaluation.MulticlassMetrics; -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; -import org.apache.spark.sql.types.StructField; // $example off$ +import org.apache.spark.sql.SparkSession; + /** - * An example runner for Multiclass to Binary Reduction with One Vs Rest. - * The example uses Logistic Regression as the base classifier. All parameters that - * can be specified on the base classifier can be passed in to the runner options. + * An example of Multiclass to Binary Reduction with One Vs Rest, + * using Logistic Regression as the base classifier. * Run with * <pre> - * bin/run-example ml.JavaOneVsRestExample [options] + * bin/run-example ml.JavaOneVsRestExample * </pre> */ public class JavaOneVsRestExample { - - private static class Params { - String input; - String testInput = null; - Integer maxIter = 100; - double tol = 1E-6; - boolean fitIntercept = true; - Double regParam = null; - Double elasticNetParam = null; - double fracTest = 0.2; - } - public static void main(String[] args) { - // parse the arguments - Params params = parse(args); SparkSession spark = SparkSession .builder() .appName("JavaOneVsRestExample") .getOrCreate(); // $example on$ - // configure the base classifier - LogisticRegression classifier = new LogisticRegression() - .setMaxIter(params.maxIter) - .setTol(params.tol) - .setFitIntercept(params.fitIntercept); + // load data file. + Dataset<Row> inputData = spark.read().format("libsvm") + .load("data/mllib/sample_multiclass_classification_data.txt"); - if (params.regParam != null) { - classifier.setRegParam(params.regParam); - } - if (params.elasticNetParam != null) { - classifier.setElasticNetParam(params.elasticNetParam); - } + // generate the train/test split. + Dataset<Row>[] tmp = inputData.randomSplit(new double[]{0.8, 0.2}); + Dataset<Row> train = tmp[0]; + Dataset<Row> test = tmp[1]; - // instantiate the One Vs Rest Classifier - OneVsRest ovr = new OneVsRest().setClassifier(classifier); - - String input = params.input; - Dataset<Row> inputData = spark.read().format("libsvm").load(input); - Dataset<Row> train; - Dataset<Row> test; + // configure the base classifier. + LogisticRegression classifier = new LogisticRegression() + .setMaxIter(10) + .setTol(1E-6) + .setFitIntercept(true); - // compute the train/ test split: if testInput is not provided use part of input - String testInput = params.testInput; - if (testInput != null) { - train = inputData; - // compute the number of features in the training set. - int numFeatures = inputData.first().<Vector>getAs(1).size(); - test = spark.read().format("libsvm").option("numFeatures", - String.valueOf(numFeatures)).load(testInput); - } else { - double f = params.fracTest; - Dataset<Row>[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345); - train = tmp[0]; - test = tmp[1]; - } + // instantiate the One Vs Rest Classifier. + OneVsRest ovr = new OneVsRest().setClassifier(classifier); - // train the multiclass model - OneVsRestModel ovrModel = ovr.fit(train.cache()); + // train the multiclass model. + OneVsRestModel ovrModel = ovr.fit(train); - // score the model on test data - Dataset<Row> predictions = ovrModel.transform(test.cache()) + // score the model on test data. + Dataset<Row> predictions = ovrModel.transform(test) .select("prediction", "label"); - // obtain metrics - MulticlassMetrics metrics = new MulticlassMetrics(predictions); - StructField predictionColSchema = predictions.schema().apply("prediction"); - Integer numClasses = (Integer) MetadataUtils.getNumClasses(predictionColSchema).get(); - - // compute the false positive rate per label - StringBuilder results = new StringBuilder(); - results.append("label\tfpr\n"); - for (int label = 0; label < numClasses; label++) { - results.append(label); - results.append("\t"); - results.append(metrics.falsePositiveRate((double) label)); - results.append("\n"); - } + // obtain evaluator. + MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setMetricName("precision"); - Matrix confusionMatrix = metrics.confusionMatrix(); - // output the Confusion Matrix - System.out.println("Confusion Matrix"); - System.out.println(confusionMatrix); - System.out.println(); - System.out.println(results); + // compute the classification error on test data. + double precision = evaluator.evaluate(predictions); + System.out.println("Test Error : " + (1 - precision)); // $example off$ spark.stop(); } - private static Params parse(String[] args) { - Options options = generateCommandlineOptions(); - CommandLineParser parser = new PosixParser(); - Params params = new Params(); - - try { - CommandLine cmd = parser.parse(options, args); - String value; - if (cmd.hasOption("input")) { - params.input = cmd.getOptionValue("input"); - } - if (cmd.hasOption("maxIter")) { - value = cmd.getOptionValue("maxIter"); - params.maxIter = Integer.parseInt(value); - } - if (cmd.hasOption("tol")) { - value = cmd.getOptionValue("tol"); - params.tol = Double.parseDouble(value); - } - if (cmd.hasOption("fitIntercept")) { - value = cmd.getOptionValue("fitIntercept"); - params.fitIntercept = Boolean.parseBoolean(value); - } - if (cmd.hasOption("regParam")) { - value = cmd.getOptionValue("regParam"); - params.regParam = Double.parseDouble(value); - } - if (cmd.hasOption("elasticNetParam")) { - value = cmd.getOptionValue("elasticNetParam"); - params.elasticNetParam = Double.parseDouble(value); - } - if (cmd.hasOption("testInput")) { - value = cmd.getOptionValue("testInput"); - params.testInput = value; - } - if (cmd.hasOption("fracTest")) { - value = cmd.getOptionValue("fracTest"); - params.fracTest = Double.parseDouble(value); - } - - } catch (ParseException e) { - printHelpAndQuit(options); - } - return params; - } - - @SuppressWarnings("static") - private static Options generateCommandlineOptions() { - Option input = OptionBuilder.withArgName("input") - .hasArg() - .isRequired() - .withDescription("input path to labeled examples. This path must be specified") - .create("input"); - Option testInput = OptionBuilder.withArgName("testInput") - .hasArg() - .withDescription("input path to test examples") - .create("testInput"); - Option fracTest = OptionBuilder.withArgName("testInput") - .hasArg() - .withDescription("fraction of data to hold out for testing." + - " If given option testInput, this option is ignored. default: 0.2") - .create("fracTest"); - Option maxIter = OptionBuilder.withArgName("maxIter") - .hasArg() - .withDescription("maximum number of iterations for Logistic Regression. default:100") - .create("maxIter"); - Option tol = OptionBuilder.withArgName("tol") - .hasArg() - .withDescription("the convergence tolerance of iterations " + - "for Logistic Regression. default: 1E-6") - .create("tol"); - Option fitIntercept = OptionBuilder.withArgName("fitIntercept") - .hasArg() - .withDescription("fit intercept for logistic regression. default true") - .create("fitIntercept"); - Option regParam = OptionBuilder.withArgName( "regParam" ) - .hasArg() - .withDescription("the regularization parameter for Logistic Regression.") - .create("regParam"); - Option elasticNetParam = OptionBuilder.withArgName("elasticNetParam" ) - .hasArg() - .withDescription("the ElasticNet mixing parameter for Logistic Regression.") - .create("elasticNetParam"); - - Options options = new Options() - .addOption(input) - .addOption(testInput) - .addOption(fracTest) - .addOption(maxIter) - .addOption(tol) - .addOption(fitIntercept) - .addOption(regParam) - .addOption(elasticNetParam); - - return options; - } - - private static void printHelpAndQuit(Options options) { - HelpFormatter formatter = new HelpFormatter(); - formatter.printHelp("JavaOneVsRestExample", options); - System.exit(-1); - } } |