author     Evan Sparks <evan.sparks@gmail.com>  2013-08-11 10:52:55 -0700
committer  Evan Sparks <evan.sparks@gmail.com>  2013-08-11 10:52:55 -0700
commit     ff9ebfabb47e3439c7b78cb4e3c33423a1467a9a (patch)
tree       8ce3ef1b12dbe1f0a8e9c64fc5a684bcab66e6bb /examples
parent     95c62ca3060c89a44aa19aaab1fc9a9fff5a1196 (diff)
parent     a65a6ed5140446651916aff1761a9a755194eaf4 (diff)
Merge pull request #762 from shivaram/sgd-cleanup
Refactor SGD options into a new class.
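
This refactor moves the SGD settings (iteration count, step size, mini-batch fraction) onto an optimizer object instead of threading them through train(...) as positional arguments. Below is a minimal, hedged sketch of that style, mirroring the commented-out block in JavaLR.java further down; the wrapper class and method names here (SgdOptionsSketch, trainWithOptions) are illustrative assumptions, and only the spark.mllib calls that appear in the diff are relied on:

import spark.api.java.JavaRDD;
import spark.mllib.classification.LogisticRegressionModel;
import spark.mllib.classification.LogisticRegressionWithSGD;
import spark.mllib.regression.LabeledPoint;

class SgdOptionsSketch {
  // Configure SGD through the optimizer object, then train on an RDD of LabeledPoint.
  static LogisticRegressionModel trainWithOptions(JavaRDD<LabeledPoint> points,
                                                  int iterations, double stepSize) {
    LogisticRegressionWithSGD lr = new LogisticRegressionWithSGD();
    lr.optimizer().setNumIterations(iterations)   // number of SGD passes
                  .setStepSize(stepSize)          // SGD learning rate
                  .setMiniBatchFraction(1.0);     // use the full data set each iteration
    lr.setIntercept(true);                        // also fit an intercept term
    return lr.train(points.rdd());
  }
}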
Diffstat (limited to 'examples')
-rw-r--r--  examples/src/main/java/spark/mllib/examples/JavaLR.java  |  85
1 file changed, 85 insertions(+), 0 deletions(-)
diff --git a/examples/src/main/java/spark/mllib/examples/JavaLR.java b/examples/src/main/java/spark/mllib/examples/JavaLR.java
new file mode 100644
index 0000000000..bf4aeaf40f
--- /dev/null
+++ b/examples/src/main/java/spark/mllib/examples/JavaLR.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.mllib.examples;
+
+
+import spark.api.java.JavaRDD;
+import spark.api.java.JavaSparkContext;
+import spark.api.java.function.Function;
+
+import spark.mllib.classification.LogisticRegressionWithSGD;
+import spark.mllib.classification.LogisticRegressionModel;
+import spark.mllib.regression.LabeledPoint;
+
+import java.util.Arrays;
+import java.util.StringTokenizer;
+
+/**
+ * Logistic regression-based classification using MLlib.
+ */
+public class JavaLR {
+
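+  // Parses an input line of the form "label,f1 f2 ... fn" into a LabeledPoint.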
+ static class ParsePoint extends Function<String, LabeledPoint> {
+ public LabeledPoint call(String line) {
+ String[] parts = line.split(",");
+ double y = Double.parseDouble(parts[0]);
+ StringTokenizer tok = new StringTokenizer(parts[1], " ");
+ int numTokens = tok.countTokens();
+ double[] x = new double[numTokens];
+ for (int i = 0; i < numTokens; ++i) {
+ x[i] = Double.parseDouble(tok.nextToken());
+ }
+ return new LabeledPoint(y, x);
+ }
+ }
+
+ public static void printWeights(double[] a) {
+ System.out.println(Arrays.toString(a));
+ }
+
+ public static void main(String[] args) {
+ if (args.length != 4) {
+ System.err.println("Usage: JavaLR <master> <input_dir> <step_size> <niters>");
+ System.exit(1);
+ }
+
+ JavaSparkContext sc = new JavaSparkContext(args[0], "JavaLR",
+ System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR"));
+ JavaRDD<String> lines = sc.textFile(args[1]);
+ JavaRDD<LabeledPoint> points = lines.map(new ParsePoint()).cache();
+ double stepSize = Double.parseDouble(args[2]);
+ int iterations = Integer.parseInt(args[3]);
+
+ // Another way to configure LogisticRegression
+ //
+ // LogisticRegressionWithSGD lr = new LogisticRegressionWithSGD();
+ // lr.optimizer().setNumIterations(iterations)
+ // .setStepSize(stepSize)
+ // .setMiniBatchFraction(1.0);
+ // lr.setIntercept(true);
+ // LogisticRegressionModel model = lr.train(points.rdd());
+
+ LogisticRegressionModel model = LogisticRegressionWithSGD.train(points.rdd(),
+ iterations, stepSize);
+
+ System.out.print("Final w: ");
+ printWeights(model.weights());
+
+ System.exit(0);
+ }
+}
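
For reference, ParsePoint above expects each input line as a comma-separated label followed by space-separated features, and the program takes <master> <input_dir> <step_size> <niters> on the command line, reading SPARK_HOME and SPARK_EXAMPLES_JAR from the environment. Illustrative arguments and an input line (the master, path, and parameter values below are assumptions, not part of this commit):

arguments:   local  /path/to/lr-data.txt  1.0  100
input line:  1.0,0.5 1.2 -0.3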