author     Zheng RuiFeng <ruifengz@foxmail.com>    2016-05-11 09:53:36 +0200
committer  Nick Pentreath <nickp@za.ibm.com>       2016-05-11 09:53:36 +0200
commit     ad1a8466e9c10fbe8b455dba17b16973f92ebc15
tree       7da9c4df44d1774c2834b9d11cbc4f55aa3c8309 /examples/src/main/scala
parent     875ef764280428acd095aec1834fee0ddad08611
[SPARK-15141][EXAMPLE][DOC] Update OneVsRest Examples
## What changes were proposed in this pull request?

1. Add a Python example for OneVsRest.
2. Remove args-parsing.

## How was this patch tested?

Manual tests:
`./bin/spark-submit examples/src/main/python/ml/one_vs_rest_example.py`

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #12920 from zhengruifeng/ovr_pe.
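The Scala side of the change swaps the RDD-based `MulticlassMetrics` for the DataFrame-based `MulticlassClassificationEvaluator`. A minimal sketch of the new evaluation step, extracted from the diff below (the `testError` helper name is mine; the `"precision"` metric name and the shape of `predictions`, with `label` and `prediction` columns, are taken from the patch):

```scala
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.sql.DataFrame

// predictions: the output of ovrModel.transform(test), which carries the
// "label" column from the input plus the "prediction" column added by the
// fitted OneVsRest model.
def testError(predictions: DataFrame): Double = {
  val evaluator = new MulticlassClassificationEvaluator()
    .setMetricName("precision") // the metric name this patch uses
  1 - evaluator.evaluate(predictions)
}
```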
Diffstat (limited to 'examples/src/main/scala')
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala  156
1 file changed, 24 insertions(+), 132 deletions(-)
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
index fc73ae07ff..0b333cf629 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
@@ -18,171 +18,63 @@
// scalastyle:off println
package org.apache.spark.examples.ml
-import java.util.concurrent.TimeUnit.{NANOSECONDS => NANO}
-
-import scopt.OptionParser
-
// $example on$
-import org.apache.spark.examples.mllib.AbstractParams
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}
-import org.apache.spark.ml.util.MetadataUtils
-import org.apache.spark.mllib.evaluation.MulticlassMetrics
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.sql.DataFrame
// $example off$
import org.apache.spark.sql.SparkSession
/**
- * An example runner for Multiclass to Binary Reduction with One Vs Rest.
- * The example uses Logistic Regression as the base classifier. All parameters that
- * can be specified on the base classifier can be passed in to the runner options.
+ * An example of Multiclass to Binary Reduction with One Vs Rest,
+ * using Logistic Regression as the base classifier.
* Run with
* {{{
- * ./bin/run-example ml.OneVsRestExample [options]
- * }}}
- * For local mode, run
- * {{{
- * ./bin/spark-submit --class org.apache.spark.examples.ml.OneVsRestExample --driver-memory 1g
- * [examples JAR path] [options]
+ * ./bin/run-example ml.OneVsRestExample
* }}}
- * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
*/
-object OneVsRestExample {
-
- case class Params private[ml] (
- input: String = null,
- testInput: Option[String] = None,
- maxIter: Int = 100,
- tol: Double = 1E-6,
- fitIntercept: Boolean = true,
- regParam: Option[Double] = None,
- elasticNetParam: Option[Double] = None,
- fracTest: Double = 0.2) extends AbstractParams[Params]
+object OneVsRestExample {
def main(args: Array[String]) {
- val defaultParams = Params()
-
- val parser = new OptionParser[Params]("OneVsRest Example") {
- head("OneVsRest Example: multiclass to binary reduction using OneVsRest")
- opt[String]("input")
- .text("input path to labeled examples. This path must be specified")
- .required()
- .action((x, c) => c.copy(input = x))
- opt[Double]("fracTest")
- .text(s"fraction of data to hold out for testing. If given option testInput, " +
- s"this option is ignored. default: ${defaultParams.fracTest}")
- .action((x, c) => c.copy(fracTest = x))
- opt[String]("testInput")
- .text("input path to test dataset. If given, option fracTest is ignored")
- .action((x, c) => c.copy(testInput = Some(x)))
- opt[Int]("maxIter")
- .text(s"maximum number of iterations for Logistic Regression." +
- s" default: ${defaultParams.maxIter}")
- .action((x, c) => c.copy(maxIter = x))
- opt[Double]("tol")
- .text(s"the convergence tolerance of iterations for Logistic Regression." +
- s" default: ${defaultParams.tol}")
- .action((x, c) => c.copy(tol = x))
- opt[Boolean]("fitIntercept")
- .text(s"fit intercept for Logistic Regression." +
- s" default: ${defaultParams.fitIntercept}")
- .action((x, c) => c.copy(fitIntercept = x))
- opt[Double]("regParam")
- .text(s"the regularization parameter for Logistic Regression.")
- .action((x, c) => c.copy(regParam = Some(x)))
- opt[Double]("elasticNetParam")
- .text(s"the ElasticNet mixing parameter for Logistic Regression.")
- .action((x, c) => c.copy(elasticNetParam = Some(x)))
- checkConfig { params =>
- if (params.fracTest < 0 || params.fracTest >= 1) {
- failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
- } else {
- success
- }
- }
- }
- parser.parse(args, defaultParams).map { params =>
- run(params)
- }.getOrElse {
- sys.exit(1)
- }
- }
-
- private def run(params: Params) {
val spark = SparkSession
.builder
- .appName(s"OneVsRestExample with $params")
+ .appName(s"OneVsRestExample")
.getOrCreate()
// $example on$
- val inputData = spark.read.format("libsvm").load(params.input)
- // compute the train/test split: if testInput is not provided use part of input.
- val data = params.testInput match {
- case Some(t) =>
- // compute the number of features in the training set.
- val numFeatures = inputData.first().getAs[Vector](1).size
- val testData = spark.read.option("numFeatures", numFeatures.toString)
- .format("libsvm").load(t)
- Array[DataFrame](inputData, testData)
- case None =>
- val f = params.fracTest
- inputData.randomSplit(Array(1 - f, f), seed = 12345)
- }
- val Array(train, test) = data.map(_.cache())
+ // load data file.
+ val inputData: DataFrame = spark.read.format("libsvm")
+ .load("data/mllib/sample_multiclass_classification_data.txt")
+
+ // generate the train/test split.
+ val Array(train, test) = inputData.randomSplit(Array(0.8, 0.2))
// instantiate the base classifier
val classifier = new LogisticRegression()
- .setMaxIter(params.maxIter)
- .setTol(params.tol)
- .setFitIntercept(params.fitIntercept)
-
- // Set regParam, elasticNetParam if specified in params
- params.regParam.foreach(classifier.setRegParam)
- params.elasticNetParam.foreach(classifier.setElasticNetParam)
+ .setMaxIter(10)
+ .setTol(1E-6)
+ .setFitIntercept(true)
// instantiate the One Vs Rest Classifier.
-
- val ovr = new OneVsRest()
- ovr.setClassifier(classifier)
+ val ovr = new OneVsRest().setClassifier(classifier)
// train the multiclass model.
- val (trainingDuration, ovrModel) = time(ovr.fit(train))
+ val ovrModel = ovr.fit(train)
// score the model on test data.
- val (predictionDuration, predictions) = time(ovrModel.transform(test))
-
- // evaluate the model
- val predictionsAndLabels = predictions.select("prediction", "label")
- .rdd.map(row => (row.getDouble(0), row.getDouble(1)))
-
- val metrics = new MulticlassMetrics(predictionsAndLabels)
-
- val confusionMatrix = metrics.confusionMatrix
+ val predictions = ovrModel.transform(test)
- // compute the false positive rate per label
- val predictionColSchema = predictions.schema("prediction")
- val numClasses = MetadataUtils.getNumClasses(predictionColSchema).get
- val fprs = Range(0, numClasses).map(p => (p, metrics.falsePositiveRate(p.toDouble)))
+ // obtain evaluator.
+ val evaluator = new MulticlassClassificationEvaluator()
+ .setMetricName("precision")
- println(s" Training Time ${trainingDuration} sec\n")
-
- println(s" Prediction Time ${predictionDuration} sec\n")
-
- println(s" Confusion Matrix\n ${confusionMatrix.toString}\n")
-
- println("label\tfpr")
-
- println(fprs.map {case (label, fpr) => label + "\t" + fpr}.mkString("\n"))
+ // compute the classification error on test data.
+ val precision = evaluator.evaluate(predictions)
+ println(s"Test Error : ${1 - precision}")
// $example off$
spark.stop()
}
- private def time[R](block: => R): (Long, R) = {
- val t0 = System.nanoTime()
- val result = block // call-by-name
- val t1 = System.nanoTime()
- (NANO.toSeconds(t1 - t0), result)
- }
}
// scalastyle:on println
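For reference, a reconstruction of the complete example after this patch, assembled from the added (`+`) and unchanged context lines in the diff above; the data path, parameter values, and output strings are exactly those in the diff:

```scala
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.sql.DataFrame
// $example off$
import org.apache.spark.sql.SparkSession

/**
 * An example of Multiclass to Binary Reduction with One Vs Rest,
 * using Logistic Regression as the base classifier.
 * Run with
 * {{{
 * ./bin/run-example ml.OneVsRestExample
 * }}}
 */
object OneVsRestExample {

  def main(args: Array[String]) {
    val spark = SparkSession
      .builder
      .appName(s"OneVsRestExample")
      .getOrCreate()

    // $example on$
    // load data file.
    val inputData: DataFrame = spark.read.format("libsvm")
      .load("data/mllib/sample_multiclass_classification_data.txt")

    // generate the train/test split.
    val Array(train, test) = inputData.randomSplit(Array(0.8, 0.2))

    // instantiate the base classifier
    val classifier = new LogisticRegression()
      .setMaxIter(10)
      .setTol(1E-6)
      .setFitIntercept(true)

    // instantiate the One Vs Rest Classifier.
    val ovr = new OneVsRest().setClassifier(classifier)

    // train the multiclass model.
    val ovrModel = ovr.fit(train)

    // score the model on test data.
    val predictions = ovrModel.transform(test)

    // obtain evaluator.
    val evaluator = new MulticlassClassificationEvaluator()
      .setMetricName("precision")

    // compute the classification error on test data.
    val precision = evaluator.evaluate(predictions)
    println(s"Test Error : ${1 - precision}")
    // $example off$

    spark.stop()
  }
}
// scalastyle:on println
```

As the updated scaladoc notes, the example now takes no command-line options and can be run with `./bin/run-example ml.OneVsRestExample`.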