about summary refs log tree commit diff
path: root/examples
diff options
context:
space:
mode:
authorZheng RuiFeng <ruifengz@foxmail.com>2016-05-11 09:53:36 +0200
committerNick Pentreath <nickp@za.ibm.com>2016-05-11 09:53:36 +0200
commitad1a8466e9c10fbe8b455dba17b16973f92ebc15 (patch)
tree7da9c4df44d1774c2834b9d11cbc4f55aa3c8309 /examples
parent875ef764280428acd095aec1834fee0ddad08611 (diff)
downloadspark-ad1a8466e9c10fbe8b455dba17b16973f92ebc15.tar.gz
spark-ad1a8466e9c10fbe8b455dba17b16973f92ebc15.tar.bz2
spark-ad1a8466e9c10fbe8b455dba17b16973f92ebc15.zip
[SPARK-15141][EXAMPLE][DOC] Update OneVsRest Examples
## What changes were proposed in this pull request? 1, Add python example for OneVsRest 2, remove args-parsing ## How was this patch tested? manual tests `./bin/spark-submit examples/src/main/python/ml/one_vs_rest_example.py` Author: Zheng RuiFeng <ruifengz@foxmail.com> Closes #12920 from zhengruifeng/ovr_pe.
Diffstat (limited to 'examples')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java214
-rw-r--r--examples/src/main/python/ml/one_vs_rest_example.py68
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala156
3 files changed, 122 insertions, 316 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
index e0cb752224..5bf455ebfe 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
@@ -17,222 +17,68 @@
package org.apache.spark.examples.ml;
-import org.apache.commons.cli.*;
-
// $example on$
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.OneVsRest;
import org.apache.spark.ml.classification.OneVsRestModel;
-import org.apache.spark.ml.util.MetadataUtils;
-import org.apache.spark.mllib.evaluation.MulticlassMetrics;
-import org.apache.spark.mllib.linalg.Matrix;
-import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-import org.apache.spark.sql.types.StructField;
// $example off$
+import org.apache.spark.sql.SparkSession;
+
/**
- * An example runner for Multiclass to Binary Reduction with One Vs Rest.
- * The example uses Logistic Regression as the base classifier. All parameters that
- * can be specified on the base classifier can be passed in to the runner options.
+ * An example of Multiclass to Binary Reduction with One Vs Rest,
+ * using Logistic Regression as the base classifier.
* Run with
* <pre>
- * bin/run-example ml.JavaOneVsRestExample [options]
+ * bin/run-example ml.JavaOneVsRestExample
* </pre>
*/
public class JavaOneVsRestExample {
-
- private static class Params {
- String input;
- String testInput = null;
- Integer maxIter = 100;
- double tol = 1E-6;
- boolean fitIntercept = true;
- Double regParam = null;
- Double elasticNetParam = null;
- double fracTest = 0.2;
- }
-
public static void main(String[] args) {
- // parse the arguments
- Params params = parse(args);
SparkSession spark = SparkSession
.builder()
.appName("JavaOneVsRestExample")
.getOrCreate();
// $example on$
- // configure the base classifier
- LogisticRegression classifier = new LogisticRegression()
- .setMaxIter(params.maxIter)
- .setTol(params.tol)
- .setFitIntercept(params.fitIntercept);
+ // load data file.
+ Dataset<Row> inputData = spark.read().format("libsvm")
+ .load("data/mllib/sample_multiclass_classification_data.txt");
- if (params.regParam != null) {
- classifier.setRegParam(params.regParam);
- }
- if (params.elasticNetParam != null) {
- classifier.setElasticNetParam(params.elasticNetParam);
- }
+ // generate the train/test split.
+ Dataset<Row>[] tmp = inputData.randomSplit(new double[]{0.8, 0.2});
+ Dataset<Row> train = tmp[0];
+ Dataset<Row> test = tmp[1];
- // instantiate the One Vs Rest Classifier
- OneVsRest ovr = new OneVsRest().setClassifier(classifier);
-
- String input = params.input;
- Dataset<Row> inputData = spark.read().format("libsvm").load(input);
- Dataset<Row> train;
- Dataset<Row> test;
+ // configure the base classifier.
+ LogisticRegression classifier = new LogisticRegression()
+ .setMaxIter(10)
+ .setTol(1E-6)
+ .setFitIntercept(true);
- // compute the train/ test split: if testInput is not provided use part of input
- String testInput = params.testInput;
- if (testInput != null) {
- train = inputData;
- // compute the number of features in the training set.
- int numFeatures = inputData.first().<Vector>getAs(1).size();
- test = spark.read().format("libsvm").option("numFeatures",
- String.valueOf(numFeatures)).load(testInput);
- } else {
- double f = params.fracTest;
- Dataset<Row>[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345);
- train = tmp[0];
- test = tmp[1];
- }
+ // instantiate the One Vs Rest Classifier.
+ OneVsRest ovr = new OneVsRest().setClassifier(classifier);
- // train the multiclass model
- OneVsRestModel ovrModel = ovr.fit(train.cache());
+ // train the multiclass model.
+ OneVsRestModel ovrModel = ovr.fit(train);
- // score the model on test data
- Dataset<Row> predictions = ovrModel.transform(test.cache())
+ // score the model on test data.
+ Dataset<Row> predictions = ovrModel.transform(test)
.select("prediction", "label");
- // obtain metrics
- MulticlassMetrics metrics = new MulticlassMetrics(predictions);
- StructField predictionColSchema = predictions.schema().apply("prediction");
- Integer numClasses = (Integer) MetadataUtils.getNumClasses(predictionColSchema).get();
-
- // compute the false positive rate per label
- StringBuilder results = new StringBuilder();
- results.append("label\tfpr\n");
- for (int label = 0; label < numClasses; label++) {
- results.append(label);
- results.append("\t");
- results.append(metrics.falsePositiveRate((double) label));
- results.append("\n");
- }
+ // obtain evaluator.
+ MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
+ .setMetricName("precision");
- Matrix confusionMatrix = metrics.confusionMatrix();
- // output the Confusion Matrix
- System.out.println("Confusion Matrix");
- System.out.println(confusionMatrix);
- System.out.println();
- System.out.println(results);
+ // compute the classification error on test data.
+ double precision = evaluator.evaluate(predictions);
+ System.out.println("Test Error : " + (1 - precision));
// $example off$
spark.stop();
}
- private static Params parse(String[] args) {
- Options options = generateCommandlineOptions();
- CommandLineParser parser = new PosixParser();
- Params params = new Params();
-
- try {
- CommandLine cmd = parser.parse(options, args);
- String value;
- if (cmd.hasOption("input")) {
- params.input = cmd.getOptionValue("input");
- }
- if (cmd.hasOption("maxIter")) {
- value = cmd.getOptionValue("maxIter");
- params.maxIter = Integer.parseInt(value);
- }
- if (cmd.hasOption("tol")) {
- value = cmd.getOptionValue("tol");
- params.tol = Double.parseDouble(value);
- }
- if (cmd.hasOption("fitIntercept")) {
- value = cmd.getOptionValue("fitIntercept");
- params.fitIntercept = Boolean.parseBoolean(value);
- }
- if (cmd.hasOption("regParam")) {
- value = cmd.getOptionValue("regParam");
- params.regParam = Double.parseDouble(value);
- }
- if (cmd.hasOption("elasticNetParam")) {
- value = cmd.getOptionValue("elasticNetParam");
- params.elasticNetParam = Double.parseDouble(value);
- }
- if (cmd.hasOption("testInput")) {
- value = cmd.getOptionValue("testInput");
- params.testInput = value;
- }
- if (cmd.hasOption("fracTest")) {
- value = cmd.getOptionValue("fracTest");
- params.fracTest = Double.parseDouble(value);
- }
-
- } catch (ParseException e) {
- printHelpAndQuit(options);
- }
- return params;
- }
-
- @SuppressWarnings("static")
- private static Options generateCommandlineOptions() {
- Option input = OptionBuilder.withArgName("input")
- .hasArg()
- .isRequired()
- .withDescription("input path to labeled examples. This path must be specified")
- .create("input");
- Option testInput = OptionBuilder.withArgName("testInput")
- .hasArg()
- .withDescription("input path to test examples")
- .create("testInput");
- Option fracTest = OptionBuilder.withArgName("testInput")
- .hasArg()
- .withDescription("fraction of data to hold out for testing." +
- " If given option testInput, this option is ignored. default: 0.2")
- .create("fracTest");
- Option maxIter = OptionBuilder.withArgName("maxIter")
- .hasArg()
- .withDescription("maximum number of iterations for Logistic Regression. default:100")
- .create("maxIter");
- Option tol = OptionBuilder.withArgName("tol")
- .hasArg()
- .withDescription("the convergence tolerance of iterations " +
- "for Logistic Regression. default: 1E-6")
- .create("tol");
- Option fitIntercept = OptionBuilder.withArgName("fitIntercept")
- .hasArg()
- .withDescription("fit intercept for logistic regression. default true")
- .create("fitIntercept");
- Option regParam = OptionBuilder.withArgName( "regParam" )
- .hasArg()
- .withDescription("the regularization parameter for Logistic Regression.")
- .create("regParam");
- Option elasticNetParam = OptionBuilder.withArgName("elasticNetParam" )
- .hasArg()
- .withDescription("the ElasticNet mixing parameter for Logistic Regression.")
- .create("elasticNetParam");
-
- Options options = new Options()
- .addOption(input)
- .addOption(testInput)
- .addOption(fracTest)
- .addOption(maxIter)
- .addOption(tol)
- .addOption(fitIntercept)
- .addOption(regParam)
- .addOption(elasticNetParam);
-
- return options;
- }
-
- private static void printHelpAndQuit(Options options) {
- HelpFormatter formatter = new HelpFormatter();
- formatter.printHelp("JavaOneVsRestExample", options);
- System.exit(-1);
- }
}
diff --git a/examples/src/main/python/ml/one_vs_rest_example.py b/examples/src/main/python/ml/one_vs_rest_example.py
new file mode 100644
index 0000000000..971156d0dd
--- /dev/null
+++ b/examples/src/main/python/ml/one_vs_rest_example.py
@@ -0,0 +1,68 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+# $example on$
+from pyspark.ml.classification import LogisticRegression, OneVsRest
+from pyspark.ml.evaluation import MulticlassClassificationEvaluator
+# $example off$
+from pyspark.sql import SparkSession
+
+"""
+An example of Multiclass to Binary Reduction with One Vs Rest,
+using Logistic Regression as the base classifier.
+Run with:
+ bin/spark-submit examples/src/main/python/ml/one_vs_rest_example.py
+"""
+
+
+if __name__ == "__main__":
+ spark = SparkSession \
+ .builder \
+ .appName("PythonOneVsRestExample") \
+ .getOrCreate()
+
+ # $example on$
+ # load data file.
+ inputData = spark.read.format("libsvm") \
+ .load("data/mllib/sample_multiclass_classification_data.txt")
+
+ # generate the train/test split.
+ (train, test) = inputData.randomSplit([0.8, 0.2])
+
+ # instantiate the base classifier.
+ lr = LogisticRegression(maxIter=10, tol=1E-6, fitIntercept=True)
+
+ # instantiate the One Vs Rest Classifier.
+ ovr = OneVsRest(classifier=lr)
+
+ # train the multiclass model.
+ ovrModel = ovr.fit(train)
+
+ # score the model on test data.
+ predictions = ovrModel.transform(test)
+
+ # obtain evaluator.
+ evaluator = MulticlassClassificationEvaluator(metricName="precision")
+
+ # compute the classification error on test data.
+ precision = evaluator.evaluate(predictions)
+ print("Test Error : " + str(1 - precision))
+ # $example off$
+
+ spark.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
index fc73ae07ff..0b333cf629 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
@@ -18,171 +18,63 @@
// scalastyle:off println
package org.apache.spark.examples.ml
-import java.util.concurrent.TimeUnit.{NANOSECONDS => NANO}
-
-import scopt.OptionParser
-
// $example on$
-import org.apache.spark.examples.mllib.AbstractParams
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}
-import org.apache.spark.ml.util.MetadataUtils
-import org.apache.spark.mllib.evaluation.MulticlassMetrics
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.sql.DataFrame
// $example off$
import org.apache.spark.sql.SparkSession
/**
- * An example runner for Multiclass to Binary Reduction with One Vs Rest.
- * The example uses Logistic Regression as the base classifier. All parameters that
- * can be specified on the base classifier can be passed in to the runner options.
+ * An example of Multiclass to Binary Reduction with One Vs Rest,
+ * using Logistic Regression as the base classifier.
* Run with
* {{{
- * ./bin/run-example ml.OneVsRestExample [options]
- * }}}
- * For local mode, run
- * {{{
- * ./bin/spark-submit --class org.apache.spark.examples.ml.OneVsRestExample --driver-memory 1g
- * [examples JAR path] [options]
+ * ./bin/run-example ml.OneVsRestExample
* }}}
- * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
*/
-object OneVsRestExample {
-
- case class Params private[ml] (
- input: String = null,
- testInput: Option[String] = None,
- maxIter: Int = 100,
- tol: Double = 1E-6,
- fitIntercept: Boolean = true,
- regParam: Option[Double] = None,
- elasticNetParam: Option[Double] = None,
- fracTest: Double = 0.2) extends AbstractParams[Params]
+object OneVsRestExample {
def main(args: Array[String]) {
- val defaultParams = Params()
-
- val parser = new OptionParser[Params]("OneVsRest Example") {
- head("OneVsRest Example: multiclass to binary reduction using OneVsRest")
- opt[String]("input")
- .text("input path to labeled examples. This path must be specified")
- .required()
- .action((x, c) => c.copy(input = x))
- opt[Double]("fracTest")
- .text(s"fraction of data to hold out for testing. If given option testInput, " +
- s"this option is ignored. default: ${defaultParams.fracTest}")
- .action((x, c) => c.copy(fracTest = x))
- opt[String]("testInput")
- .text("input path to test dataset. If given, option fracTest is ignored")
- .action((x, c) => c.copy(testInput = Some(x)))
- opt[Int]("maxIter")
- .text(s"maximum number of iterations for Logistic Regression." +
- s" default: ${defaultParams.maxIter}")
- .action((x, c) => c.copy(maxIter = x))
- opt[Double]("tol")
- .text(s"the convergence tolerance of iterations for Logistic Regression." +
- s" default: ${defaultParams.tol}")
- .action((x, c) => c.copy(tol = x))
- opt[Boolean]("fitIntercept")
- .text(s"fit intercept for Logistic Regression." +
- s" default: ${defaultParams.fitIntercept}")
- .action((x, c) => c.copy(fitIntercept = x))
- opt[Double]("regParam")
- .text(s"the regularization parameter for Logistic Regression.")
- .action((x, c) => c.copy(regParam = Some(x)))
- opt[Double]("elasticNetParam")
- .text(s"the ElasticNet mixing parameter for Logistic Regression.")
- .action((x, c) => c.copy(elasticNetParam = Some(x)))
- checkConfig { params =>
- if (params.fracTest < 0 || params.fracTest >= 1) {
- failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
- } else {
- success
- }
- }
- }
- parser.parse(args, defaultParams).map { params =>
- run(params)
- }.getOrElse {
- sys.exit(1)
- }
- }
-
- private def run(params: Params) {
val spark = SparkSession
.builder
- .appName(s"OneVsRestExample with $params")
+ .appName(s"OneVsRestExample")
.getOrCreate()
// $example on$
- val inputData = spark.read.format("libsvm").load(params.input)
- // compute the train/test split: if testInput is not provided use part of input.
- val data = params.testInput match {
- case Some(t) =>
- // compute the number of features in the training set.
- val numFeatures = inputData.first().getAs[Vector](1).size
- val testData = spark.read.option("numFeatures", numFeatures.toString)
- .format("libsvm").load(t)
- Array[DataFrame](inputData, testData)
- case None =>
- val f = params.fracTest
- inputData.randomSplit(Array(1 - f, f), seed = 12345)
- }
- val Array(train, test) = data.map(_.cache())
+ // load data file.
+ val inputData: DataFrame = spark.read.format("libsvm")
+ .load("data/mllib/sample_multiclass_classification_data.txt")
+
+ // generate the train/test split.
+ val Array(train, test) = inputData.randomSplit(Array(0.8, 0.2))
// instantiate the base classifier
val classifier = new LogisticRegression()
- .setMaxIter(params.maxIter)
- .setTol(params.tol)
- .setFitIntercept(params.fitIntercept)
-
- // Set regParam, elasticNetParam if specified in params
- params.regParam.foreach(classifier.setRegParam)
- params.elasticNetParam.foreach(classifier.setElasticNetParam)
+ .setMaxIter(10)
+ .setTol(1E-6)
+ .setFitIntercept(true)
// instantiate the One Vs Rest Classifier.
-
- val ovr = new OneVsRest()
- ovr.setClassifier(classifier)
+ val ovr = new OneVsRest().setClassifier(classifier)
// train the multiclass model.
- val (trainingDuration, ovrModel) = time(ovr.fit(train))
+ val ovrModel = ovr.fit(train)
// score the model on test data.
- val (predictionDuration, predictions) = time(ovrModel.transform(test))
-
- // evaluate the model
- val predictionsAndLabels = predictions.select("prediction", "label")
- .rdd.map(row => (row.getDouble(0), row.getDouble(1)))
-
- val metrics = new MulticlassMetrics(predictionsAndLabels)
-
- val confusionMatrix = metrics.confusionMatrix
+ val predictions = ovrModel.transform(test)
- // compute the false positive rate per label
- val predictionColSchema = predictions.schema("prediction")
- val numClasses = MetadataUtils.getNumClasses(predictionColSchema).get
- val fprs = Range(0, numClasses).map(p => (p, metrics.falsePositiveRate(p.toDouble)))
+ // obtain evaluator.
+ val evaluator = new MulticlassClassificationEvaluator()
+ .setMetricName("precision")
- println(s" Training Time ${trainingDuration} sec\n")
-
- println(s" Prediction Time ${predictionDuration} sec\n")
-
- println(s" Confusion Matrix\n ${confusionMatrix.toString}\n")
-
- println("label\tfpr")
-
- println(fprs.map {case (label, fpr) => label + "\t" + fpr}.mkString("\n"))
+ // compute the classification error on test data.
+ val precision = evaluator.evaluate(predictions)
+ println(s"Test Error : ${1 - precision}")
// $example off$
spark.stop()
}
- private def time[R](block: => R): (Long, R) = {
- val t0 = System.nanoTime()
- val result = block // call-by-name
- val t1 = System.nanoTime()
- (NANO.toSeconds(t1 - t0), result)
- }
}
// scalastyle:on println