From ad1a8466e9c10fbe8b455dba17b16973f92ebc15 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Wed, 11 May 2016 09:53:36 +0200 Subject: [SPARK-15141][EXAMPLE][DOC] Update OneVsRest Examples ## What changes were proposed in this pull request? 1, Add python example for OneVsRest 2, remove args-parsing ## How was this patch tested? manual tests `./bin/spark-submit examples/src/main/python/ml/one_vs_rest_example.py` Author: Zheng RuiFeng Closes #12920 from zhengruifeng/ovr_pe. --- .../spark/examples/ml/JavaOneVsRestExample.java | 214 +++------------------ examples/src/main/python/ml/one_vs_rest_example.py | 68 +++++++ .../spark/examples/ml/OneVsRestExample.scala | 156 +++------------ 3 files changed, 122 insertions(+), 316 deletions(-) create mode 100644 examples/src/main/python/ml/one_vs_rest_example.py (limited to 'examples') diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java index e0cb752224..5bf455ebfe 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java @@ -17,222 +17,68 @@ package org.apache.spark.examples.ml; -import org.apache.commons.cli.*; - // $example on$ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.OneVsRest; import org.apache.spark.ml.classification.OneVsRestModel; -import org.apache.spark.ml.util.MetadataUtils; -import org.apache.spark.mllib.evaluation.MulticlassMetrics; -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; -import org.apache.spark.sql.types.StructField; // $example off$ +import org.apache.spark.sql.SparkSession; + /** - * An example runner for Multiclass to Binary Reduction with One Vs Rest. - * The example uses Logistic Regression as the base classifier. All parameters that - * can be specified on the base classifier can be passed in to the runner options. + * An example of Multiclass to Binary Reduction with One Vs Rest, + * using Logistic Regression as the base classifier. * Run with *
- * bin/run-example ml.JavaOneVsRestExample [options]
+ * bin/run-example ml.JavaOneVsRestExample
  * 
*/ public class JavaOneVsRestExample { - - private static class Params { - String input; - String testInput = null; - Integer maxIter = 100; - double tol = 1E-6; - boolean fitIntercept = true; - Double regParam = null; - Double elasticNetParam = null; - double fracTest = 0.2; - } - public static void main(String[] args) { - // parse the arguments - Params params = parse(args); SparkSession spark = SparkSession .builder() .appName("JavaOneVsRestExample") .getOrCreate(); // $example on$ - // configure the base classifier - LogisticRegression classifier = new LogisticRegression() - .setMaxIter(params.maxIter) - .setTol(params.tol) - .setFitIntercept(params.fitIntercept); + // load data file. + Dataset inputData = spark.read().format("libsvm") + .load("data/mllib/sample_multiclass_classification_data.txt"); - if (params.regParam != null) { - classifier.setRegParam(params.regParam); - } - if (params.elasticNetParam != null) { - classifier.setElasticNetParam(params.elasticNetParam); - } + // generate the train/test split. + Dataset[] tmp = inputData.randomSplit(new double[]{0.8, 0.2}); + Dataset train = tmp[0]; + Dataset test = tmp[1]; - // instantiate the One Vs Rest Classifier - OneVsRest ovr = new OneVsRest().setClassifier(classifier); - - String input = params.input; - Dataset inputData = spark.read().format("libsvm").load(input); - Dataset train; - Dataset test; + // configure the base classifier. + LogisticRegression classifier = new LogisticRegression() + .setMaxIter(10) + .setTol(1E-6) + .setFitIntercept(true); - // compute the train/ test split: if testInput is not provided use part of input - String testInput = params.testInput; - if (testInput != null) { - train = inputData; - // compute the number of features in the training set. - int numFeatures = inputData.first().getAs(1).size(); - test = spark.read().format("libsvm").option("numFeatures", - String.valueOf(numFeatures)).load(testInput); - } else { - double f = params.fracTest; - Dataset[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345); - train = tmp[0]; - test = tmp[1]; - } + // instantiate the One Vs Rest Classifier. + OneVsRest ovr = new OneVsRest().setClassifier(classifier); - // train the multiclass model - OneVsRestModel ovrModel = ovr.fit(train.cache()); + // train the multiclass model. + OneVsRestModel ovrModel = ovr.fit(train); - // score the model on test data - Dataset predictions = ovrModel.transform(test.cache()) + // score the model on test data. + Dataset predictions = ovrModel.transform(test) .select("prediction", "label"); - // obtain metrics - MulticlassMetrics metrics = new MulticlassMetrics(predictions); - StructField predictionColSchema = predictions.schema().apply("prediction"); - Integer numClasses = (Integer) MetadataUtils.getNumClasses(predictionColSchema).get(); - - // compute the false positive rate per label - StringBuilder results = new StringBuilder(); - results.append("label\tfpr\n"); - for (int label = 0; label < numClasses; label++) { - results.append(label); - results.append("\t"); - results.append(metrics.falsePositiveRate((double) label)); - results.append("\n"); - } + // obtain evaluator. + MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setMetricName("precision"); - Matrix confusionMatrix = metrics.confusionMatrix(); - // output the Confusion Matrix - System.out.println("Confusion Matrix"); - System.out.println(confusionMatrix); - System.out.println(); - System.out.println(results); + // compute the classification error on test data. + double precision = evaluator.evaluate(predictions); + System.out.println("Test Error : " + (1 - precision)); // $example off$ spark.stop(); } - private static Params parse(String[] args) { - Options options = generateCommandlineOptions(); - CommandLineParser parser = new PosixParser(); - Params params = new Params(); - - try { - CommandLine cmd = parser.parse(options, args); - String value; - if (cmd.hasOption("input")) { - params.input = cmd.getOptionValue("input"); - } - if (cmd.hasOption("maxIter")) { - value = cmd.getOptionValue("maxIter"); - params.maxIter = Integer.parseInt(value); - } - if (cmd.hasOption("tol")) { - value = cmd.getOptionValue("tol"); - params.tol = Double.parseDouble(value); - } - if (cmd.hasOption("fitIntercept")) { - value = cmd.getOptionValue("fitIntercept"); - params.fitIntercept = Boolean.parseBoolean(value); - } - if (cmd.hasOption("regParam")) { - value = cmd.getOptionValue("regParam"); - params.regParam = Double.parseDouble(value); - } - if (cmd.hasOption("elasticNetParam")) { - value = cmd.getOptionValue("elasticNetParam"); - params.elasticNetParam = Double.parseDouble(value); - } - if (cmd.hasOption("testInput")) { - value = cmd.getOptionValue("testInput"); - params.testInput = value; - } - if (cmd.hasOption("fracTest")) { - value = cmd.getOptionValue("fracTest"); - params.fracTest = Double.parseDouble(value); - } - - } catch (ParseException e) { - printHelpAndQuit(options); - } - return params; - } - - @SuppressWarnings("static") - private static Options generateCommandlineOptions() { - Option input = OptionBuilder.withArgName("input") - .hasArg() - .isRequired() - .withDescription("input path to labeled examples. This path must be specified") - .create("input"); - Option testInput = OptionBuilder.withArgName("testInput") - .hasArg() - .withDescription("input path to test examples") - .create("testInput"); - Option fracTest = OptionBuilder.withArgName("testInput") - .hasArg() - .withDescription("fraction of data to hold out for testing." + - " If given option testInput, this option is ignored. default: 0.2") - .create("fracTest"); - Option maxIter = OptionBuilder.withArgName("maxIter") - .hasArg() - .withDescription("maximum number of iterations for Logistic Regression. default:100") - .create("maxIter"); - Option tol = OptionBuilder.withArgName("tol") - .hasArg() - .withDescription("the convergence tolerance of iterations " + - "for Logistic Regression. default: 1E-6") - .create("tol"); - Option fitIntercept = OptionBuilder.withArgName("fitIntercept") - .hasArg() - .withDescription("fit intercept for logistic regression. default true") - .create("fitIntercept"); - Option regParam = OptionBuilder.withArgName( "regParam" ) - .hasArg() - .withDescription("the regularization parameter for Logistic Regression.") - .create("regParam"); - Option elasticNetParam = OptionBuilder.withArgName("elasticNetParam" ) - .hasArg() - .withDescription("the ElasticNet mixing parameter for Logistic Regression.") - .create("elasticNetParam"); - - Options options = new Options() - .addOption(input) - .addOption(testInput) - .addOption(fracTest) - .addOption(maxIter) - .addOption(tol) - .addOption(fitIntercept) - .addOption(regParam) - .addOption(elasticNetParam); - - return options; - } - - private static void printHelpAndQuit(Options options) { - HelpFormatter formatter = new HelpFormatter(); - formatter.printHelp("JavaOneVsRestExample", options); - System.exit(-1); - } } diff --git a/examples/src/main/python/ml/one_vs_rest_example.py b/examples/src/main/python/ml/one_vs_rest_example.py new file mode 100644 index 0000000000..971156d0dd --- /dev/null +++ b/examples/src/main/python/ml/one_vs_rest_example.py @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +# $example on$ +from pyspark.ml.classification import LogisticRegression, OneVsRest +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +# $example off$ +from pyspark.sql import SparkSession + +""" +An example of Multiclass to Binary Reduction with One Vs Rest, +using Logistic Regression as the base classifier. +Run with: + bin/spark-submit examples/src/main/python/ml/one_vs_rest_example.py +""" + + +if __name__ == "__main__": + spark = SparkSession \ + .builder \ + .appName("PythonOneVsRestExample") \ + .getOrCreate() + + # $example on$ + # load data file. + inputData = spark.read.format("libsvm") \ + .load("data/mllib/sample_multiclass_classification_data.txt") + + # generate the train/test split. + (train, test) = inputData.randomSplit([0.8, 0.2]) + + # instantiate the base classifier. + lr = LogisticRegression(maxIter=10, tol=1E-6, fitIntercept=True) + + # instantiate the One Vs Rest Classifier. + ovr = OneVsRest(classifier=lr) + + # train the multiclass model. + ovrModel = ovr.fit(train) + + # score the model on test data. + predictions = ovrModel.transform(test) + + # obtain evaluator. + evaluator = MulticlassClassificationEvaluator(metricName="precision") + + # compute the classification error on test data. + precision = evaluator.evaluate(predictions) + print("Test Error : " + str(1 - precision)) + # $example off$ + + spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index fc73ae07ff..0b333cf629 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -18,171 +18,63 @@ // scalastyle:off println package org.apache.spark.examples.ml -import java.util.concurrent.TimeUnit.{NANOSECONDS => NANO} - -import scopt.OptionParser - // $example on$ -import org.apache.spark.examples.mllib.AbstractParams import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest} -import org.apache.spark.ml.util.MetadataUtils -import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator import org.apache.spark.sql.DataFrame // $example off$ import org.apache.spark.sql.SparkSession /** - * An example runner for Multiclass to Binary Reduction with One Vs Rest. - * The example uses Logistic Regression as the base classifier. All parameters that - * can be specified on the base classifier can be passed in to the runner options. + * An example of Multiclass to Binary Reduction with One Vs Rest, + * using Logistic Regression as the base classifier. * Run with * {{{ - * ./bin/run-example ml.OneVsRestExample [options] - * }}} - * For local mode, run - * {{{ - * ./bin/spark-submit --class org.apache.spark.examples.ml.OneVsRestExample --driver-memory 1g - * [examples JAR path] [options] + * ./bin/run-example ml.OneVsRestExample * }}} - * If you use it as a template to create your own app, please use `spark-submit` to submit your app. */ -object OneVsRestExample { - - case class Params private[ml] ( - input: String = null, - testInput: Option[String] = None, - maxIter: Int = 100, - tol: Double = 1E-6, - fitIntercept: Boolean = true, - regParam: Option[Double] = None, - elasticNetParam: Option[Double] = None, - fracTest: Double = 0.2) extends AbstractParams[Params] +object OneVsRestExample { def main(args: Array[String]) { - val defaultParams = Params() - - val parser = new OptionParser[Params]("OneVsRest Example") { - head("OneVsRest Example: multiclass to binary reduction using OneVsRest") - opt[String]("input") - .text("input path to labeled examples. This path must be specified") - .required() - .action((x, c) => c.copy(input = x)) - opt[Double]("fracTest") - .text(s"fraction of data to hold out for testing. If given option testInput, " + - s"this option is ignored. default: ${defaultParams.fracTest}") - .action((x, c) => c.copy(fracTest = x)) - opt[String]("testInput") - .text("input path to test dataset. If given, option fracTest is ignored") - .action((x, c) => c.copy(testInput = Some(x))) - opt[Int]("maxIter") - .text(s"maximum number of iterations for Logistic Regression." + - s" default: ${defaultParams.maxIter}") - .action((x, c) => c.copy(maxIter = x)) - opt[Double]("tol") - .text(s"the convergence tolerance of iterations for Logistic Regression." + - s" default: ${defaultParams.tol}") - .action((x, c) => c.copy(tol = x)) - opt[Boolean]("fitIntercept") - .text(s"fit intercept for Logistic Regression." + - s" default: ${defaultParams.fitIntercept}") - .action((x, c) => c.copy(fitIntercept = x)) - opt[Double]("regParam") - .text(s"the regularization parameter for Logistic Regression.") - .action((x, c) => c.copy(regParam = Some(x))) - opt[Double]("elasticNetParam") - .text(s"the ElasticNet mixing parameter for Logistic Regression.") - .action((x, c) => c.copy(elasticNetParam = Some(x))) - checkConfig { params => - if (params.fracTest < 0 || params.fracTest >= 1) { - failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") - } else { - success - } - } - } - parser.parse(args, defaultParams).map { params => - run(params) - }.getOrElse { - sys.exit(1) - } - } - - private def run(params: Params) { val spark = SparkSession .builder - .appName(s"OneVsRestExample with $params") + .appName(s"OneVsRestExample") .getOrCreate() // $example on$ - val inputData = spark.read.format("libsvm").load(params.input) - // compute the train/test split: if testInput is not provided use part of input. - val data = params.testInput match { - case Some(t) => - // compute the number of features in the training set. - val numFeatures = inputData.first().getAs[Vector](1).size - val testData = spark.read.option("numFeatures", numFeatures.toString) - .format("libsvm").load(t) - Array[DataFrame](inputData, testData) - case None => - val f = params.fracTest - inputData.randomSplit(Array(1 - f, f), seed = 12345) - } - val Array(train, test) = data.map(_.cache()) + // load data file. + val inputData: DataFrame = spark.read.format("libsvm") + .load("data/mllib/sample_multiclass_classification_data.txt") + + // generate the train/test split. + val Array(train, test) = inputData.randomSplit(Array(0.8, 0.2)) // instantiate the base classifier val classifier = new LogisticRegression() - .setMaxIter(params.maxIter) - .setTol(params.tol) - .setFitIntercept(params.fitIntercept) - - // Set regParam, elasticNetParam if specified in params - params.regParam.foreach(classifier.setRegParam) - params.elasticNetParam.foreach(classifier.setElasticNetParam) + .setMaxIter(10) + .setTol(1E-6) + .setFitIntercept(true) // instantiate the One Vs Rest Classifier. - - val ovr = new OneVsRest() - ovr.setClassifier(classifier) + val ovr = new OneVsRest().setClassifier(classifier) // train the multiclass model. - val (trainingDuration, ovrModel) = time(ovr.fit(train)) + val ovrModel = ovr.fit(train) // score the model on test data. - val (predictionDuration, predictions) = time(ovrModel.transform(test)) - - // evaluate the model - val predictionsAndLabels = predictions.select("prediction", "label") - .rdd.map(row => (row.getDouble(0), row.getDouble(1))) - - val metrics = new MulticlassMetrics(predictionsAndLabels) - - val confusionMatrix = metrics.confusionMatrix + val predictions = ovrModel.transform(test) - // compute the false positive rate per label - val predictionColSchema = predictions.schema("prediction") - val numClasses = MetadataUtils.getNumClasses(predictionColSchema).get - val fprs = Range(0, numClasses).map(p => (p, metrics.falsePositiveRate(p.toDouble))) + // obtain evaluator. + val evaluator = new MulticlassClassificationEvaluator() + .setMetricName("precision") - println(s" Training Time ${trainingDuration} sec\n") - - println(s" Prediction Time ${predictionDuration} sec\n") - - println(s" Confusion Matrix\n ${confusionMatrix.toString}\n") - - println("label\tfpr") - - println(fprs.map {case (label, fpr) => label + "\t" + fpr}.mkString("\n")) + // compute the classification error on test data. + val precision = evaluator.evaluate(predictions) + println(s"Test Error : ${1 - precision}") // $example off$ spark.stop() } - private def time[R](block: => R): (Long, R) = { - val t0 = System.nanoTime() - val result = block // call-by-name - val t1 = System.nanoTime() - (NANO.toSeconds(t1 - t0), result) - } } // scalastyle:on println -- cgit v1.2.3