aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorRam Sriharsha <rsriharsha@hw11853.local>2015-05-15 19:33:20 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-05-15 19:33:20 -0700
commitcc12a86fb049f2be1f45baf461d202ec356ccf8f (patch)
tree0a1dfff2ad9971aaaff535f238d486dff7574354 /examples
parent2c04c8a1aed34cce420b3d30d9e885daa6e03d74 (diff)
downloadspark-cc12a86fb049f2be1f45baf461d202ec356ccf8f.tar.gz
spark-cc12a86fb049f2be1f45baf461d202ec356ccf8f.tar.bz2
spark-cc12a86fb049f2be1f45baf461d202ec356ccf8f.zip
[SPARK-7575] [ML] [DOC] Example code for OneVsRest
Java and Scala examples for OneVsRest. Fixes the base classifier to be Logistic Regression and accepts the configuration parameters of the base classifier. Author: Ram Sriharsha <rsriharsha@hw11853.local> Closes #6115 from harsha2010/SPARK-7575 and squashes the following commits: 87ad3c7 [Ram Sriharsha] extra line f5d9891 [Ram Sriharsha] Merge branch 'master' into SPARK-7575 7076084 [Ram Sriharsha] cleanup dfd660c [Ram Sriharsha] cleanup 8703e4f [Ram Sriharsha] update doc cb23995 [Ram Sriharsha] fix commandline options for JavaOneVsRestExample 69e91f8 [Ram Sriharsha] cleanup 7f4e127 [Ram Sriharsha] cleanup d4c40d0 [Ram Sriharsha] Code Review fixes 461eb38 [Ram Sriharsha] cleanup e0106d9 [Ram Sriharsha] Fix typo 935cf56 [Ram Sriharsha] Try to match Java and Scala Example Commandline options 5323ff9 [Ram Sriharsha] cleanup 196a59a [Ram Sriharsha] cleanup 6adfa0c [Ram Sriharsha] Style Fix 8cfc5d5 [Ram Sriharsha] [SPARK-7575] Example code for OneVsRest
Diffstat (limited to 'examples')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java236
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala185
2 files changed, 421 insertions, 0 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
new file mode 100644
index 0000000000..75063dbf80
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import org.apache.commons.cli.*;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.classification.LogisticRegression;
+import org.apache.spark.ml.classification.OneVsRest;
+import org.apache.spark.ml.classification.OneVsRestModel;
+import org.apache.spark.ml.util.MetadataUtils;
+import org.apache.spark.mllib.evaluation.MulticlassMetrics;
+import org.apache.spark.mllib.linalg.Matrix;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.StructField;
+
+/**
+ * An example runner for Multiclass to Binary Reduction with One Vs Rest.
+ * The example uses Logistic Regression as the base classifier. All parameters that
+ * can be specified on the base classifier can be passed in to the runner options.
+ * Run with
+ * <pre>
+ * bin/run-example ml.JavaOneVsRestExample [options]
+ * </pre>
+ */
+public class JavaOneVsRestExample {
+
+ private static class Params {
+ String input;
+ String testInput = null;
+ Integer maxIter = 100;
+ double tol = 1E-6;
+ boolean fitIntercept = true;
+ Double regParam = null;
+ Double elasticNetParam = null;
+ double fracTest = 0.2;
+ }
+
+ public static void main(String[] args) {
+ // parse the arguments
+ Params params = parse(args);
+ SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample");
+ JavaSparkContext jsc = new JavaSparkContext(conf);
+ SQLContext jsql = new SQLContext(jsc);
+
+ // configure the base classifier
+ LogisticRegression classifier = new LogisticRegression()
+ .setMaxIter(params.maxIter)
+ .setTol(params.tol)
+ .setFitIntercept(params.fitIntercept);
+
+ if (params.regParam != null) {
+ classifier.setRegParam(params.regParam);
+ }
+ if (params.elasticNetParam != null) {
+ classifier.setElasticNetParam(params.elasticNetParam);
+ }
+
+ // instantiate the One Vs Rest Classifier
+ OneVsRest ovr = new OneVsRest().setClassifier(classifier);
+
+ String input = params.input;
+ RDD<LabeledPoint> inputData = MLUtils.loadLibSVMFile(jsc.sc(), input);
+ RDD<LabeledPoint> train;
+ RDD<LabeledPoint> test;
+
+ // compute the train/ test split: if testInput is not provided use part of input
+ String testInput = params.testInput;
+ if (testInput != null) {
+ train = inputData;
+ // compute the number of features in the training set.
+ int numFeatures = inputData.first().features().size();
+ test = MLUtils.loadLibSVMFile(jsc.sc(), testInput, numFeatures);
+ } else {
+ double f = params.fracTest;
+ RDD<LabeledPoint>[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345);
+ train = tmp[0];
+ test = tmp[1];
+ }
+
+ // train the multiclass model
+ DataFrame trainingDataFrame = jsql.createDataFrame(train, LabeledPoint.class);
+ OneVsRestModel ovrModel = ovr.fit(trainingDataFrame.cache());
+
+ // score the model on test data
+ DataFrame testDataFrame = jsql.createDataFrame(test, LabeledPoint.class);
+ DataFrame predictions = ovrModel.transform(testDataFrame.cache())
+ .select("prediction", "label");
+
+ // obtain metrics
+ MulticlassMetrics metrics = new MulticlassMetrics(predictions);
+ StructField predictionColSchema = predictions.schema().apply("prediction");
+ Integer numClasses = (Integer) MetadataUtils.getNumClasses(predictionColSchema).get();
+
+ // compute the false positive rate per label
+ StringBuilder results = new StringBuilder();
+ results.append("label\tfpr\n");
+ for (int label = 0; label < numClasses; label++) {
+ results.append(label);
+ results.append("\t");
+ results.append(metrics.falsePositiveRate((double) label));
+ results.append("\n");
+ }
+
+ Matrix confusionMatrix = metrics.confusionMatrix();
+ // output the Confusion Matrix
+ System.out.println("Confusion Matrix");
+ System.out.println(confusionMatrix);
+ System.out.println();
+ System.out.println(results);
+
+ jsc.stop();
+ }
+
+ private static Params parse(String[] args) {
+ Options options = generateCommandlineOptions();
+ CommandLineParser parser = new PosixParser();
+ Params params = new Params();
+
+ try {
+ CommandLine cmd = parser.parse(options, args);
+ String value;
+ if (cmd.hasOption("input")) {
+ params.input = cmd.getOptionValue("input");
+ }
+ if (cmd.hasOption("maxIter")) {
+ value = cmd.getOptionValue("maxIter");
+ params.maxIter = Integer.parseInt(value);
+ }
+ if (cmd.hasOption("tol")) {
+ value = cmd.getOptionValue("tol");
+ params.tol = Double.parseDouble(value);
+ }
+ if (cmd.hasOption("fitIntercept")) {
+ value = cmd.getOptionValue("fitIntercept");
+ params.fitIntercept = Boolean.parseBoolean(value);
+ }
+ if (cmd.hasOption("regParam")) {
+ value = cmd.getOptionValue("regParam");
+ params.regParam = Double.parseDouble(value);
+ }
+ if (cmd.hasOption("elasticNetParam")) {
+ value = cmd.getOptionValue("elasticNetParam");
+ params.elasticNetParam = Double.parseDouble(value);
+ }
+ if (cmd.hasOption("testInput")) {
+ value = cmd.getOptionValue("testInput");
+ params.testInput = value;
+ }
+ if (cmd.hasOption("fracTest")) {
+ value = cmd.getOptionValue("fracTest");
+ params.fracTest = Double.parseDouble(value);
+ }
+
+ } catch (ParseException e) {
+ printHelpAndQuit(options);
+ }
+ return params;
+ }
+
+ private static Options generateCommandlineOptions() {
+ Option input = OptionBuilder.withArgName("input")
+ .hasArg()
+ .isRequired()
+ .withDescription("input path to labeled examples. This path must be specified")
+ .create("input");
+ Option testInput = OptionBuilder.withArgName("testInput")
+ .hasArg()
+ .withDescription("input path to test examples")
+ .create("testInput");
+ Option fracTest = OptionBuilder.withArgName("testInput")
+ .hasArg()
+ .withDescription("fraction of data to hold out for testing." +
+ " If given option testInput, this option is ignored. default: 0.2")
+ .create("fracTest");
+ Option maxIter = OptionBuilder.withArgName("maxIter")
+ .hasArg()
+ .withDescription("maximum number of iterations for Logistic Regression. default:100")
+ .create("maxIter");
+ Option tol = OptionBuilder.withArgName("tol")
+ .hasArg()
+ .withDescription("the convergence tolerance of iterations " +
+ "for Logistic Regression. default: 1E-6")
+ .create("tol");
+ Option fitIntercept = OptionBuilder.withArgName("fitIntercept")
+ .hasArg()
+ .withDescription("fit intercept for logistic regression. default true")
+ .create("fitIntercept");
+ Option regParam = OptionBuilder.withArgName( "regParam" )
+ .hasArg()
+ .withDescription("the regularization parameter for Logistic Regression.")
+ .create("regParam");
+ Option elasticNetParam = OptionBuilder.withArgName("elasticNetParam" )
+ .hasArg()
+ .withDescription("the ElasticNet mixing parameter for Logistic Regression.")
+ .create("elasticNetParam");
+
+ Options options = new Options()
+ .addOption(input)
+ .addOption(testInput)
+ .addOption(fracTest)
+ .addOption(maxIter)
+ .addOption(tol)
+ .addOption(fitIntercept)
+ .addOption(regParam)
+ .addOption(elasticNetParam);
+
+ return options;
+ }
+
+ private static void printHelpAndQuit(Options options) {
+ HelpFormatter formatter = new HelpFormatter();
+ formatter.printHelp("JavaOneVsRestExample", options);
+ System.exit(-1);
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
new file mode 100644
index 0000000000..b99d0a1246
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import java.util.concurrent.TimeUnit.{NANOSECONDS => NANO}
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkContext, SparkConf}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.classification.{OneVsRest, LogisticRegression}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
+
+/**
+ * An example runner for Multiclass to Binary Reduction with One Vs Rest.
+ * The example uses Logistic Regression as the base classifier. All parameters that
+ * can be specified on the base classifier can be passed in to the runner options.
+ * Run with
+ * {{{
+ * ./bin/run-example ml.OneVsRestExample [options]
+ * }}}
+ * For local mode, run
+ * {{{
+ * ./bin/spark-submit --class org.apache.spark.examples.ml.OneVsRestExample --driver-memory 1g
+ * [examples JAR path] [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object OneVsRestExample {
+
+ case class Params private[ml] (
+ input: String = null,
+ testInput: Option[String] = None,
+ maxIter: Int = 100,
+ tol: Double = 1E-6,
+ fitIntercept: Boolean = true,
+ regParam: Option[Double] = None,
+ elasticNetParam: Option[Double] = None,
+ fracTest: Double = 0.2) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("OneVsRest Example") {
+ head("OneVsRest Example: multiclass to binary reduction using OneVsRest")
+ opt[String]("input")
+ .text("input path to labeled examples. This path must be specified")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[String]("testInput")
+ .text("input path to test dataset. If given, option fracTest is ignored")
+ .action((x,c) => c.copy(testInput = Some(x)))
+ opt[Int]("maxIter")
+ .text(s"maximum number of iterations for Logistic Regression." +
+ s" default: ${defaultParams.maxIter}")
+ .action((x, c) => c.copy(maxIter = x))
+ opt[Double]("tol")
+ .text(s"the convergence tolerance of iterations for Logistic Regression." +
+ s" default: ${defaultParams.tol}")
+ .action((x, c) => c.copy(tol = x))
+ opt[Boolean]("fitIntercept")
+ .text(s"fit intercept for Logistic Regression." +
+ s" default: ${defaultParams.fitIntercept}")
+ .action((x, c) => c.copy(fitIntercept = x))
+ opt[Double]("regParam")
+ .text(s"the regularization parameter for Logistic Regression.")
+ .action((x,c) => c.copy(regParam = Some(x)))
+ opt[Double]("elasticNetParam")
+ .text(s"the ElasticNet mixing parameter for Logistic Regression.")
+ .action((x,c) => c.copy(elasticNetParam = Some(x)))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest >= 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
+ } else {
+ success
+ }
+ }
+ }
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ private def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"OneVsRestExample with $params")
+ val sc = new SparkContext(conf)
+ val inputData = MLUtils.loadLibSVMFile(sc, params.input)
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ // compute the train/test split: if testInput is not provided use part of input.
+ val data = params.testInput match {
+ case Some(t) => {
+ // compute the number of features in the training set.
+ val numFeatures = inputData.first().features.size
+ val testData = MLUtils.loadLibSVMFile(sc, t, numFeatures)
+ Array[RDD[LabeledPoint]](inputData, testData)
+ }
+ case None => {
+ val f = params.fracTest
+ inputData.randomSplit(Array(1 - f, f), seed = 12345)
+ }
+ }
+ val Array(train, test) = data.map(_.toDF().cache())
+
+ // instantiate the base classifier
+ val classifier = new LogisticRegression()
+ .setMaxIter(params.maxIter)
+ .setTol(params.tol)
+ .setFitIntercept(params.fitIntercept)
+
+ // Set regParam, elasticNetParam if specified in params
+ params.regParam.foreach(classifier.setRegParam)
+ params.elasticNetParam.foreach(classifier.setElasticNetParam)
+
+ // instantiate the One Vs Rest Classifier.
+
+ val ovr = new OneVsRest()
+ ovr.setClassifier(classifier)
+
+ // train the multiclass model.
+ val (trainingDuration, ovrModel) = time(ovr.fit(train))
+
+ // score the model on test data.
+ val (predictionDuration, predictions) = time(ovrModel.transform(test))
+
+ // evaluate the model
+ val predictionsAndLabels = predictions.select("prediction", "label")
+ .map(row => (row.getDouble(0), row.getDouble(1)))
+
+ val metrics = new MulticlassMetrics(predictionsAndLabels)
+
+ val confusionMatrix = metrics.confusionMatrix
+
+ // compute the false positive rate per label
+ val predictionColSchema = predictions.schema("prediction")
+ val numClasses = MetadataUtils.getNumClasses(predictionColSchema).get
+ val fprs = Range(0, numClasses).map(p => (p, metrics.falsePositiveRate(p.toDouble)))
+
+ println(s" Training Time ${trainingDuration} sec\n")
+
+ println(s" Prediction Time ${predictionDuration} sec\n")
+
+ println(s" Confusion Matrix\n ${confusionMatrix.toString}\n")
+
+ println("label\tfpr")
+
+ println(fprs.map {case (label, fpr) => label + "\t" + fpr}.mkString("\n"))
+
+ sc.stop()
+ }
+
+ private def time[R](block: => R): (Long, R) = {
+ val t0 = System.nanoTime()
+ val result = block // call-by-name
+ val t1 = System.nanoTime()
+ (NANO.toSeconds(t1 - t0), result)
+ }
+}