about | summary | refs | log | tree | commit | diff
path: root/examples
diff options
context:
space:
mode:
author: Shuo Xiang <shuoxiangpub@gmail.com> 2015-06-29 23:50:34 -0700
committer: DB Tsai <dbt@netflix.com> 2015-06-29 23:50:34 -0700
commit: 5452457410ffe881773f2f2cdcdc752467b19720 (patch)
tree: 9ef58df0d78b0e742dc1f8d8278990d20e288251 /examples
parent: 12671dd5e468beedc2681ff2bdf95fba81f8f29c (diff)
download: spark-5452457410ffe881773f2f2cdcdc752467b19720.tar.gz
          spark-5452457410ffe881773f2f2cdcdc752467b19720.tar.bz2
          spark-5452457410ffe881773f2f2cdcdc752467b19720.zip
[SPARK-8551] [ML] Elastic net python code example
Author: Shuo Xiang <shuoxiangpub@gmail.com>

Closes #6946 from coderxiang/en-java-code-example and squashes the following commits:

7a4bdf8 [Shuo Xiang] address comments
cddb02b [Shuo Xiang] add elastic net python example code
f4fa534 [Shuo Xiang] Merge remote-tracking branch 'upstream/master'
6ad4865 [Shuo Xiang] Merge remote-tracking branch 'upstream/master'
180b496 [Shuo Xiang] Merge remote-tracking branch 'upstream/master'
aa0717d [Shuo Xiang] Merge remote-tracking branch 'upstream/master'
5f109b4 [Shuo Xiang] Merge remote-tracking branch 'upstream/master'
c5c5bfe [Shuo Xiang] Merge remote-tracking branch 'upstream/master'
98804c9 [Shuo Xiang] fix bug in topBykey and update test
Diffstat (limited to 'examples')
-rw-r--r-- examples/src/main/python/ml/logistic_regression.py 67
1 files changed, 67 insertions, 0 deletions
diff --git a/examples/src/main/python/ml/logistic_regression.py b/examples/src/main/python/ml/logistic_regression.py
new file mode 100644
index 0000000000..55afe1b207
--- /dev/null
+++ b/examples/src/main/python/ml/logistic_regression.py
@@ -0,0 +1,67 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.ml.classification import LogisticRegression
+from pyspark.mllib.evaluation import MulticlassMetrics
+from pyspark.ml.feature import StringIndexer
+from pyspark.mllib.util import MLUtils
+from pyspark.sql import SQLContext
+
+"""
+A simple example demonstrating a logistic regression with elastic net regularization Pipeline.
+Run with:
+ bin/spark-submit examples/src/main/python/ml/logistic_regression.py
+"""
+
+if __name__ == "__main__":
+
+ if len(sys.argv) > 1:
+ print("Usage: logistic_regression", file=sys.stderr)
+ exit(-1)
+
+ sc = SparkContext(appName="PythonLogisticRegressionExample")
+ sqlContext = SQLContext(sc)
+
+ # Load and parse the data file into a dataframe.
+ df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
+
+ # Map labels into an indexed column of labels in [0, numLabels)
+ stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel")
+ si_model = stringIndexer.fit(df)
+ td = si_model.transform(df)
+ [training, test] = td.randomSplit([0.7, 0.3])
+
+ lr = LogisticRegression(maxIter=100, regParam=0.3).setLabelCol("indexedLabel")
+ lr.setElasticNetParam(0.8)
+
+ # Fit the model
+ lrModel = lr.fit(training)
+
+ predictionAndLabels = lrModel.transform(test).select("prediction", "indexedLabel") \
+ .map(lambda x: (x.prediction, x.indexedLabel))
+
+ metrics = MulticlassMetrics(predictionAndLabels)
+ print("weighted f-measure %.3f" % metrics.weightedFMeasure())
+ print("precision %s" % metrics.precision())
+ print("recall %s" % metrics.recall())
+
+ sc.stop()