aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-05-08 11:16:04 -0700
committerXiangrui Meng <meng@databricks.com>2015-05-08 11:16:04 -0700
commit65afd3ce8b8a0b00f4ea8294eac14b72e964872d (patch)
treed0d210b436b9f0bb9a0c9eff382bc9c077224a42 /python
parentf5ff4a84c4c75143086aae7d38730156bee35933 (diff)
downloadspark-65afd3ce8b8a0b00f4ea8294eac14b72e964872d.tar.gz
spark-65afd3ce8b8a0b00f4ea8294eac14b72e964872d.tar.bz2
spark-65afd3ce8b8a0b00f4ea8294eac14b72e964872d.zip
[SPARK-7474] [MLLIB] update ParamGridBuilder doctest
Multiline commands are properly handled in this PR. oefirouz ![screen shot 2015-05-07 at 10 53 25 pm](https://cloud.githubusercontent.com/assets/829644/7531290/02ad2fd4-f50c-11e4-8c04-e58d1a61ad69.png) Author: Xiangrui Meng <meng@databricks.com> Closes #6001 from mengxr/SPARK-7474 and squashes the following commits: b94b11d [Xiangrui Meng] update ParamGridBuilder doctest
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/ml/tuning.py28
1 files changed, 13 insertions, 15 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 1e04c37fca..28e3727f2c 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -27,24 +27,22 @@ __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel']
class ParamGridBuilder(object):
- """
+ r"""
Builder for a param grid used in grid search-based model selection.
- >>> from classification import LogisticRegression
+ >>> from pyspark.ml.classification import LogisticRegression
>>> lr = LogisticRegression()
- >>> output = ParamGridBuilder().baseOn({lr.labelCol: 'l'}) \
- .baseOn([lr.predictionCol, 'p']) \
- .addGrid(lr.regParam, [1.0, 2.0, 3.0]) \
- .addGrid(lr.maxIter, [1, 5]) \
- .addGrid(lr.featuresCol, ['f']) \
- .build()
- >>> expected = [ \
-{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
-{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
-{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
-{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
-{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
-{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
+ >>> output = ParamGridBuilder() \
+ ... .baseOn({lr.labelCol: 'l'}) \
+ ... .baseOn([lr.predictionCol, 'p']) \
+ ... .addGrid(lr.regParam, [1.0, 2.0]) \
+ ... .addGrid(lr.maxIter, [1, 5]) \
+ ... .build()
+ >>> expected = [
+ ... {lr.regParam: 1.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
+ ... {lr.regParam: 2.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
+ ... {lr.regParam: 1.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'},
+ ... {lr.regParam: 2.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
>>> len(output) == len(expected)
True
>>> all([m in expected for m in output])