aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorfreeman <the.freeman.lab@gmail.com>2015-02-02 22:42:15 -0800
committerXiangrui Meng <meng@databricks.com>2015-02-02 22:42:15 -0800
commiteb0da6c4bd55aaab972c53eb934e68257b8994e5 (patch)
tree9c02a150c8b86cbda58790a636509538c751e760 /examples
parentc306555f491e45ef870f58938af397f9ec5f166a (diff)
downloadspark-eb0da6c4bd55aaab972c53eb934e68257b8994e5.tar.gz
spark-eb0da6c4bd55aaab972c53eb934e68257b8994e5.tar.bz2
spark-eb0da6c4bd55aaab972c53eb934e68257b8994e5.zip
[SPARK-4979][MLLIB] Streaming logisitic regression
This adds support for streaming logistic regression with stochastic gradient descent, in the same manner as the existing implementation of streaming linear regression. It is a relatively simple addition because most of the work is already done by the abstract class `StreamingLinearAlgorithm` and existing algorithms and models from MLlib. The PR includes - Streaming Logistic Regression algorithm - Unit tests for accuracy, streaming convergence, and streaming prediction - An example use cc mengxr tdas Author: freeman <the.freeman.lab@gmail.com> Closes #4306 from freeman-lab/streaming-logisitic-regression and squashes the following commits: 5c2c70b [freeman] Use Option on model 5cca2bc [freeman] Merge remote-tracking branch 'upstream/master' into streaming-logisitic-regression 275f8bd [freeman] Make private to mllib 3926e4e [freeman] Line formatting 5ee8694 [freeman] Experimental tag for docs 2fc68ac [freeman] Fix example formatting 85320b1 [freeman] Fixed line length d88f717 [freeman] Remove stray comment 59d7ecb [freeman] Add streaming logistic regression e78fe28 [freeman] Add streaming logistic regression example 321cc66 [freeman] Set private and protected within mllib
Diffstat (limited to 'examples')
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala3
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala73
2 files changed, 74 insertions, 2 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala
index c5bd5b0b17..1a95048bbf 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala
@@ -35,8 +35,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext}
*
* To run on your local machine using the two directories `trainingDir` and `testDir`,
* with updates every 5 seconds, and 2 features per data point, call:
- * $ bin/run-example \
- * org.apache.spark.examples.mllib.StreamingLinearRegression trainingDir testDir 5 2
+ * $ bin/run-example mllib.StreamingLinearRegression trainingDir testDir 5 2
*
* As you add text files to `trainingDir` the model will continuously update.
* Anytime you add text files to `testDir`, you'll see predictions from the current model.
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala
new file mode 100644
index 0000000000..e1998099c2
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD
+import org.apache.spark.SparkConf
+import org.apache.spark.streaming.{Seconds, StreamingContext}
+
+/**
+ * Train a logistic regression model on one stream of data and make predictions
+ * on another stream, where the data streams arrive as text files
+ * into two different directories.
+ *
+ * The rows of the text files must be labeled data points in the form
+ * `(y,[x1,x2,x3,...,xn])`
+ * Where n is the number of features, y is a binary label, and
+ * n must be the same for train and test.
+ *
+ * Usage: StreamingLogisticRegression <trainingDir> <testDir> <batchDuration> <numFeatures>
+ *
+ * To run on your local machine using the two directories `trainingDir` and `testDir`,
+ * with updates every 5 seconds, and 2 features per data point, call:
+ * $ bin/run-example mllib.StreamingLogisticRegression trainingDir testDir 5 2
+ *
+ * As you add text files to `trainingDir` the model will continuously update.
+ * Anytime you add text files to `testDir`, you'll see predictions from the current model.
+ *
+ */
+object StreamingLogisticRegression {
+
+ def main(args: Array[String]) {
+
+ if (args.length != 4) {
+ System.err.println(
+ "Usage: StreamingLogisticRegression <trainingDir> <testDir> <batchDuration> <numFeatures>")
+ System.exit(1)
+ }
+
+ val conf = new SparkConf().setMaster("local").setAppName("StreamingLogisticRegression")
+ val ssc = new StreamingContext(conf, Seconds(args(2).toLong))
+
+ val trainingData = ssc.textFileStream(args(0)).map(LabeledPoint.parse)
+ val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse)
+
+ val model = new StreamingLogisticRegressionWithSGD()
+ .setInitialWeights(Vectors.zeros(args(3).toInt))
+
+ model.trainOn(trainingData)
+ model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()
+
+ ssc.start()
+ ssc.awaitTermination()
+
+ }
+
+}