author     Ram Sriharsha <rsriharsha@hw11853.local>    2015-06-02 18:53:04 -0700
committer  Joseph K. Bradley <joseph@databricks.com>   2015-06-02 18:53:04 -0700
commit     c3f4c3257194ba34ccd298d13ea1edcfc75f7552 (patch)
tree       dd1155697c003e0af5ab98c847473229bf071bab /examples
parent     5cd6a63d9692d153751747e0293dc030d73a6194 (diff)
download   spark-c3f4c3257194ba34ccd298d13ea1edcfc75f7552.tar.gz
           spark-c3f4c3257194ba34ccd298d13ea1edcfc75f7552.tar.bz2
           spark-c3f4c3257194ba34ccd298d13ea1edcfc75f7552.zip
[SPARK-7387] [ML] [DOC] CrossValidator example code in Python
Author: Ram Sriharsha <rsriharsha@hw11853.local>

Closes #6358 from harsha2010/SPARK-7387 and squashes the following commits:

63efda2 [Ram Sriharsha] more examples for classifier to distinguish mapreduce from spark properly
aeb6bb6 [Ram Sriharsha] Python Style Fix
54a500c [Ram Sriharsha] Merge branch 'master' into SPARK-7387
615e91c [Ram Sriharsha] cleanup
204c4e3 [Ram Sriharsha] Merge branch 'master' into SPARK-7387
7246d35 [Ram Sriharsha] [SPARK-7387][ml][doc] CrossValidator example code in Python
Diffstat (limited to 'examples')
-rw-r--r--  examples/src/main/python/ml/cross_validator.py        96
-rw-r--r--  examples/src/main/python/ml/simple_params_example.py   4
2 files changed, 98 insertions(+), 2 deletions(-)
diff --git a/examples/src/main/python/ml/cross_validator.py b/examples/src/main/python/ml/cross_validator.py
new file mode 100644
index 0000000000..f0ca97c724
--- /dev/null
+++ b/examples/src/main/python/ml/cross_validator.py
@@ -0,0 +1,96 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+from pyspark import SparkContext
+from pyspark.ml import Pipeline
+from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.evaluation import BinaryClassificationEvaluator
+from pyspark.ml.feature import HashingTF, Tokenizer
+from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
+from pyspark.sql import Row, SQLContext
+
+"""
+A simple example demonstrating model selection using CrossValidator.
+This example also demonstrates how Pipelines are Estimators.
+Run with:
+
+ bin/spark-submit examples/src/main/python/ml/cross_validator.py
+"""
+
+if __name__ == "__main__":
+ sc = SparkContext(appName="CrossValidatorExample")
+ sqlContext = SQLContext(sc)
+
+ # Prepare training documents, which are labeled.
+ LabeledDocument = Row("id", "text", "label")
+ training = sc.parallelize([(0, "a b c d e spark", 1.0),
+ (1, "b d", 0.0),
+ (2, "spark f g h", 1.0),
+ (3, "hadoop mapreduce", 0.0),
+ (4, "b spark who", 1.0),
+ (5, "g d a y", 0.0),
+ (6, "spark fly", 1.0),
+ (7, "was mapreduce", 0.0),
+ (8, "e spark program", 1.0),
+ (9, "a e c l", 0.0),
+ (10, "spark compile", 1.0),
+ (11, "hadoop software", 0.0)
+ ]) \
+ .map(lambda x: LabeledDocument(*x)).toDF()
+
+    # Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
+ tokenizer = Tokenizer(inputCol="text", outputCol="words")
+ hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
+ lr = LogisticRegression(maxIter=10)
+ pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
+
+ # We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
+ # This will allow us to jointly choose parameters for all Pipeline stages.
+ # A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
+ # We use a ParamGridBuilder to construct a grid of parameters to search over.
+ # With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
+ # this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
+ paramGrid = ParamGridBuilder() \
+ .addGrid(hashingTF.numFeatures, [10, 100, 1000]) \
+ .addGrid(lr.regParam, [0.1, 0.01]) \
+ .build()
+
+ crossval = CrossValidator(estimator=pipeline,
+ estimatorParamMaps=paramGrid,
+ evaluator=BinaryClassificationEvaluator(),
+ numFolds=2) # use 3+ folds in practice
+
+ # Run cross-validation, and choose the best set of parameters.
+ cvModel = crossval.fit(training)
+
+ # Prepare test documents, which are unlabeled.
+ Document = Row("id", "text")
+    test = sc.parallelize([(4, "spark i j k"),
+                           (5, "l m n"),
+                           (6, "mapreduce spark"),
+                           (7, "apache hadoop")]) \
+ .map(lambda x: Document(*x)).toDF()
+
+    # Make predictions on test documents. cvModel uses the best model found during cross-validation.
+ prediction = cvModel.transform(test)
+ selected = prediction.select("id", "text", "probability", "prediction")
+ for row in selected.collect():
+ print(row)
+
+ sc.stop()
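
A natural follow-up when adapting this example is to inspect what cross-validation
actually selected. The sketch below is not part of the patch; it assumes the
CrossValidatorModel/PipelineModel API of this era, where the chosen pipeline is
exposed as cvModel.bestModel and its fitted stages as a stages list. It would run
inside the script above, after crossval.fit(training):

    # The winning pipeline; its stages mirror Pipeline(stages=[tokenizer, hashingTF, lr]).
    bestPipeline = cvModel.bestModel
    # Stage index 2 follows the stage order defined above, so this should be the
    # fitted LogisticRegressionModel (an assumption about the pipeline layout).
    bestLrModel = bestPipeline.stages[2]
    print(bestLrModel)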
diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py
index 3933d59b52..a9f29dab2d 100644
--- a/examples/src/main/python/ml/simple_params_example.py
+++ b/examples/src/main/python/ml/simple_params_example.py
@@ -41,8 +41,8 @@ if __name__ == "__main__":
# prepare training data.
# We create an RDD of LabeledPoints and convert them into a DataFrame.
- # Spark DataFrames can automatically infer the schema from named tuples
- # and LabeledPoint implements __reduce__ to behave like a named tuple.
+    # A LabeledPoint is an object with two fields named label and features,
+    # and Spark SQL identifies these fields and creates the schema appropriately.
training = sc.parallelize([
LabeledPoint(1.0, DenseVector([0.0, 1.1, 0.1])),
LabeledPoint(0.0, DenseVector([2.0, 1.0, -1.0])),
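
To see the schema inference the rewritten comment describes, here is a minimal
sketch (not part of the patch), assuming a live SparkContext sc with a SQLContext
registered, as in the surrounding example:

    from pyspark.mllib.linalg import DenseVector
    from pyspark.mllib.regression import LabeledPoint

    # Spark SQL reads LabeledPoint's label and features fields to build the schema.
    df = sc.parallelize([LabeledPoint(1.0, DenseVector([0.0, 1.1, 0.1]))]).toDF()
    df.printSchema()  # expect a label column (double) and a features column (vector)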