aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/scala
diff options
context:
space:
mode:
author: Pravin Gadakh <pravingadakh177@gmail.com> 2015-11-10 14:47:04 -0800
committer: Xiangrui Meng <meng@databricks.com> 2015-11-10 14:47:04 -0800
commit638c51d9380081b3b8182be2c2460bd53b8b0a4f (patch)
treeffece4d1f0a6832fd600a04e53638cdddf29c9aa /examples/src/main/scala
parent724cf7a38c551bf2a79b87a8158bbe1725f9f888 (diff)
downloadspark-638c51d9380081b3b8182be2c2460bd53b8b0a4f.tar.gz
spark-638c51d9380081b3b8182be2c2460bd53b8b0a4f.tar.bz2
spark-638c51d9380081b3b8182be2c2460bd53b8b0a4f.zip
[SPARK-11550][DOCS] Replace example code in mllib-optimization.md using include_example
Author: Pravin Gadakh <pravingadakh177@gmail.com> Closes #9516 from pravingadakh/SPARK-11550.
Diffstat (limited to 'examples/src/main/scala')
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala90
1 file changed, 90 insertions, 0 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala
new file mode 100644
index 0000000000..61d2e7715f
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.mllib
+
+// $example on$
+import org.apache.spark.mllib.classification.LogisticRegressionModel
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater}
+import org.apache.spark.mllib.util.MLUtils
+// $example off$
+
+import org.apache.spark.{SparkConf, SparkContext}
+
/**
 * Example showing how to train a binary logistic regression model with the
 * low-level L-BFGS optimizer (`LBFGS.runLBFGS`) instead of the higher-level
 * `LogisticRegressionWithLBFGS` API, then evaluate it with area under ROC.
 *
 * Reads `data/mllib/sample_libsvm_data.txt` relative to the working
 * directory; run via `spark-submit`. The `$example on$`/`$example off$`
 * markers are consumed by the `include_example` documentation tooling and
 * must be preserved.
 */
object LBFGSExample {

  def main(args: Array[String]): Unit = {

    val conf = new SparkConf().setAppName("LBFGSExample")
    val sc = new SparkContext(conf)

    // $example on$
    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
    val numFeatures = data.take(1)(0).features.size

    // Split data into training (60%) and test (40%).
    val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)

    // Append 1 into the training data as intercept.
    val training = splits(0).map(x => (x.label, MLUtils.appendBias(x.features))).cache()

    val test = splits(1)

    // Run training algorithm to build the model.
    val numCorrections = 10
    val convergenceTol = 1e-4
    val maxNumIterations = 20
    val regParam = 0.1
    // All-zero initial weights; +1 slot for the intercept appended above.
    val initialWeightsWithIntercept = Vectors.dense(new Array[Double](numFeatures + 1))

    val (weightsWithIntercept, loss) = LBFGS.runLBFGS(
      training,
      new LogisticGradient(),
      new SquaredL2Updater(),
      numCorrections,
      convergenceTol,
      maxNumIterations,
      regParam,
      initialWeightsWithIntercept)

    // The final element of the learned vector is the intercept (it lines up
    // with the bias column added by MLUtils.appendBias); split it back out.
    val model = new LogisticRegressionModel(
      Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)),
      weightsWithIntercept(weightsWithIntercept.size - 1))

    // Clear the default threshold so predict() yields raw scores rather than
    // 0/1 labels, as required by BinaryClassificationMetrics.
    model.clearThreshold()

    // Compute raw scores on the test set.
    val scoreAndLabels = test.map { point =>
      val score = model.predict(point.features)
      (score, point.label)
    }

    // Get evaluation metrics.
    val metrics = new BinaryClassificationMetrics(scoreAndLabels)
    val auROC = metrics.areaUnderROC()

    println("Loss of each step in training process")
    loss.foreach(println)
    println("Area under ROC = " + auROC)
    // $example off$

    // Release the SparkContext's resources (UI, executors) before exiting;
    // the original example leaked the context by never stopping it.
    sc.stop()
  }
}
+// scalastyle:on println