author     Ram Sriharsha <rsriharsha@hw11853.local>   2015-05-29 15:22:26 -0700
committer  Joseph K. Bradley <joseph@databricks.com>  2015-05-29 15:22:26 -0700
commit     dbf8ff38de0f95f467b874a5b527dcf59439efe8 (patch)
tree       af580fb9fffa5f339d7470b2fe47b783335a999e /examples
parent     5fb97dca9bcfc29ac33823554c8783997e811b99 (diff)
[SPARK-6013] [ML] Add more Python ML examples for spark.ml
Author: Ram Sriharsha <rsriharsha@hw11853.local>

Closes #6443 from harsha2010/SPARK-6013 and squashes the following commits:

732506e [Ram Sriharsha] Code Review Feedback
121c211 [Ram Sriharsha] python style fix
5f9b8c3 [Ram Sriharsha] python style fixes
925ca86 [Ram Sriharsha] Simple Params Example
8b372b1 [Ram Sriharsha] GBT Example
965ec14 [Ram Sriharsha] Random Forest Example
Diffstat (limited to 'examples')
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java  |  2
-rw-r--r--  examples/src/main/python/ml/gradient_boosted_trees.py                             | 83
-rw-r--r--  examples/src/main/python/ml/random_forest_example.py                              | 87
-rw-r--r--  examples/src/main/python/ml/simple_params_example.py                              | 98
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala    |  2
5 files changed, 270 insertions(+), 2 deletions(-)
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index 29158d5c85..dac649d1d5 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -97,7 +97,7 @@ public class JavaSimpleParamsExample {
DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);
// Make predictions on test documents using the Transformer.transform() method.
- // LogisticRegression.transform will only use the 'features' column.
+ // LogisticRegressionModel.transform will only use the 'features' column.
// Note that model2.transform() outputs a 'myProbability' column instead of the usual
// 'probability' column since we renamed the lr.probabilityCol parameter previously.
DataFrame results = model2.transform(test);
diff --git a/examples/src/main/python/ml/gradient_boosted_trees.py b/examples/src/main/python/ml/gradient_boosted_trees.py
new file mode 100644
index 0000000000..6446f0fe5e
--- /dev/null
+++ b/examples/src/main/python/ml/gradient_boosted_trees.py
@@ -0,0 +1,83 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.ml.classification import GBTClassifier
+from pyspark.ml.feature import StringIndexer
+from pyspark.ml.regression import GBTRegressor
+from pyspark.mllib.evaluation import BinaryClassificationMetrics, RegressionMetrics
+from pyspark.mllib.util import MLUtils
+from pyspark.sql import SQLContext
+
+"""
+A simple example demonstrating a Gradient Boosted Trees Classification/Regression Pipeline.
+Note: GBTClassifier currently supports only binary classification.
+Run with:
+ bin/spark-submit examples/src/main/python/ml/gradient_boosted_trees.py
+"""
+
+
+def testClassification(train, test):
+    # Train a GBTClassifier model.
+
+    gbt = GBTClassifier(maxIter=30, maxDepth=4, labelCol="indexedLabel")
+
+    model = gbt.fit(train)
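+    # Convert the predictions into an RDD of (prediction, label) pairs,
+    # which is the input format the mllib evaluation metrics expect.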
+ predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \
+ .map(lambda x: (x.prediction, x.indexedLabel))
+
+ metrics = BinaryClassificationMetrics(predictionAndLabels)
+ print("AUC %.3f" % metrics.areaUnderROC)
+
+
+def testRegression(train, test):
+    # Train a GBTRegressor model.
+
+    gbt = GBTRegressor(maxIter=30, maxDepth=4, labelCol="indexedLabel")
+
+    model = gbt.fit(train)
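+    # RegressionMetrics likewise takes an RDD of (prediction, label) pairs.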
+ predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \
+ .map(lambda x: (x.prediction, x.indexedLabel))
+
+ metrics = RegressionMetrics(predictionAndLabels)
+ print("rmse %.3f" % metrics.rootMeanSquaredError)
+ print("r2 %.3f" % metrics.r2)
+ print("mae %.3f" % metrics.meanAbsoluteError)
+
+
+if __name__ == "__main__":
+    if len(sys.argv) > 1:
+        print("Usage: gradient_boosted_trees", file=sys.stderr)
+        sys.exit(1)
+ sc = SparkContext(appName="PythonGBTExample")
+ sqlContext = SQLContext(sc)
+
+ # Load and parse the data file into a dataframe.
+ df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
+
+ # Map labels into an indexed column of labels in [0, numLabels)
+ stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel")
+ si_model = stringIndexer.fit(df)
+ td = si_model.transform(df)
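+    # Randomly split the data into training (70%) and test (30%) sets.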
+ [train, test] = td.randomSplit([0.7, 0.3])
+ testClassification(train, test)
+ testRegression(train, test)
+ sc.stop()
diff --git a/examples/src/main/python/ml/random_forest_example.py b/examples/src/main/python/ml/random_forest_example.py
new file mode 100644
index 0000000000..c7730e1bfa
--- /dev/null
+++ b/examples/src/main/python/ml/random_forest_example.py
@@ -0,0 +1,87 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.ml.classification import RandomForestClassifier
+from pyspark.ml.feature import StringIndexer
+from pyspark.ml.regression import RandomForestRegressor
+from pyspark.mllib.evaluation import MulticlassMetrics, RegressionMetrics
+from pyspark.mllib.util import MLUtils
+from pyspark.sql import SQLContext
+
+"""
+A simple example demonstrating a RandomForest Classification/Regression Pipeline.
+Run with:
+ bin/spark-submit examples/src/main/python/ml/random_forest_example.py
+"""
+
+
+def testClassification(train, test):
+    # Train a RandomForest model.
+    # featureSubsetStrategy="auto" (the default) lets the algorithm choose.
+    # Note: Use larger numTrees in practice.
+
+    rf = RandomForestClassifier(labelCol="indexedLabel", numTrees=3, maxDepth=4)
+
+ model = rf.fit(train)
+ predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \
+ .map(lambda x: (x.prediction, x.indexedLabel))
+
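+    # MulticlassMetrics computes summary statistics from (prediction, label) pairs.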
+ metrics = MulticlassMetrics(predictionAndLabels)
+ print("weighted f-measure %.3f" % metrics.weightedFMeasure())
+ print("precision %s" % metrics.precision())
+ print("recall %s" % metrics.recall())
+
+
+def testRegression(train, test):
+ # Train a RandomForest model.
+ # Note: Use larger numTrees in practice.
+
+ rf = RandomForestRegressor(labelCol="indexedLabel", numTrees=3, maxDepth=4)
+
+ model = rf.fit(train)
+ predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \
+ .map(lambda x: (x.prediction, x.indexedLabel))
+
+ metrics = RegressionMetrics(predictionAndLabels)
+ print("rmse %.3f" % metrics.rootMeanSquaredError)
+ print("r2 %.3f" % metrics.r2)
+ print("mae %.3f" % metrics.meanAbsoluteError)
+
+
+if __name__ == "__main__":
+    if len(sys.argv) > 1:
+        print("Usage: random_forest_example", file=sys.stderr)
+        sys.exit(1)
+ sc = SparkContext(appName="PythonRandomForestExample")
+ sqlContext = SQLContext(sc)
+
+ # Load and parse the data file into a dataframe.
+ df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
+
+ # Map labels into an indexed column of labels in [0, numLabels)
+ stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel")
+ si_model = stringIndexer.fit(df)
+ td = si_model.transform(df)
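+    # Randomly split the data into training (70%) and test (30%) sets.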
+ [train, test] = td.randomSplit([0.7, 0.3])
+ testClassification(train, test)
+ testRegression(train, test)
+ sc.stop()
diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py
new file mode 100644
index 0000000000..3933d59b52
--- /dev/null
+++ b/examples/src/main/python/ml/simple_params_example.py
@@ -0,0 +1,98 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+import pprint
+import sys
+
+from pyspark import SparkContext
+from pyspark.ml.classification import LogisticRegression
+from pyspark.mllib.linalg import DenseVector
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.sql import SQLContext
+
+"""
+A simple example demonstrating ways to specify parameters for Estimators and Transformers.
+Run with:
+ bin/spark-submit examples/src/main/python/ml/simple_params_example.py
+"""
+
+if __name__ == "__main__":
+    if len(sys.argv) > 1:
+        print("Usage: simple_params_example", file=sys.stderr)
+        sys.exit(1)
+ sc = SparkContext(appName="PythonSimpleParamsExample")
+ sqlContext = SQLContext(sc)
+
+    # Prepare training data.
+ # We create an RDD of LabeledPoints and convert them into a DataFrame.
+ # Spark DataFrames can automatically infer the schema from named tuples
+ # and LabeledPoint implements __reduce__ to behave like a named tuple.
+ training = sc.parallelize([
+ LabeledPoint(1.0, DenseVector([0.0, 1.1, 0.1])),
+ LabeledPoint(0.0, DenseVector([2.0, 1.0, -1.0])),
+ LabeledPoint(0.0, DenseVector([2.0, 1.3, 1.0])),
+ LabeledPoint(1.0, DenseVector([0.0, 1.2, -0.5]))]).toDF()
+
+ # Create a LogisticRegression instance with maxIter = 10.
+ # This instance is an Estimator.
+ lr = LogisticRegression(maxIter=10)
+ # Print out the parameters, documentation, and any default values.
+ print("LogisticRegression parameters:\n" + lr.explainParams() + "\n")
+
+ # We may also set parameters using setter methods.
+ lr.setRegParam(0.01)
+
+ # Learn a LogisticRegression model. This uses the parameters stored in lr.
+ model1 = lr.fit(training)
+
+ # Since model1 is a Model (i.e., a Transformer produced by an Estimator),
+ # we can view the parameters it used during fit().
+ # This prints the parameter (name: value) pairs, where names are unique IDs for this
+ # LogisticRegression instance.
+ print("Model 1 was fit using parameters:\n")
+ pprint.pprint(model1.extractParamMap())
+
+ # We may alternatively specify parameters using a parameter map.
+ # paramMap overrides all lr parameters set earlier.
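+    # In Python, a param map is simply a dict mapping Param objects to values.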
+ paramMap = {lr.maxIter: 20, lr.threshold: 0.55, lr.probabilityCol: "myProbability"}
+
+ # Now learn a new model using the new parameters.
+ model2 = lr.fit(training, paramMap)
+ print("Model 2 was fit using parameters:\n")
+ pprint.pprint(model2.extractParamMap())
+
+    # Prepare test data.
+ test = sc.parallelize([
+ LabeledPoint(1.0, DenseVector([-1.0, 1.5, 1.3])),
+ LabeledPoint(0.0, DenseVector([3.0, 2.0, -0.1])),
+ LabeledPoint(0.0, DenseVector([0.0, 2.2, -1.5]))]).toDF()
+
+ # Make predictions on test data using the Transformer.transform() method.
+ # LogisticRegressionModel.transform will only use the 'features' column.
+ # Note that model2.transform() outputs a 'myProbability' column instead of the usual
+ # 'probability' column since we renamed the lr.probabilityCol parameter previously.
+ result = model2.transform(test) \
+ .select("features", "label", "myProbability", "prediction") \
+ .collect()
+
+ for row in result:
+ print("features=%s,label=%s -> prob=%s, prediction=%s"
+ % (row.features, row.label, row.myProbability, row.prediction))
+
+ sc.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index e8a991f50e..a0561e2573 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -87,7 +87,7 @@ object SimpleParamsExample {
LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))
// Make predictions on test data using the Transformer.transform() method.
- // LogisticRegression.transform will only use the 'features' column.
+ // LogisticRegressionModel.transform will only use the 'features' column.
// Note that model2.transform() outputs a 'myProbability' column instead of the usual
// 'probability' column since we renamed the lr.probabilityCol parameter previously.
model2.transform(test.toDF())