author     Xiangrui Meng <meng@databricks.com>    2015-06-02 08:51:00 -0700
committer  Xiangrui Meng <meng@databricks.com>    2015-06-02 08:51:00 -0700
commit     bd97840d5ccc3f0bfde1e5cfc7abeac9681997ab (patch)
tree       5b5a665f0c60f72f97f0019ace0977d59ba8c1d3 /python
parent     445647a1a36e1e24076a9fe506492fac462c66ad (diff)
[SPARK-7432] [MLLIB] fix flaky CrossValidator doctest
The new test uses CV to compare `maxIter=0` and `maxIter=1`, and validates the evaluation result. jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #6572 from mengxr/SPARK-7432 and squashes the following commits:

c236bb8 [Xiangrui Meng] fix flacky cv doctest
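For context, here is a minimal standalone sketch of what the new doctest exercises, assuming a Spark 1.4-era setup where `SparkContext` and `SQLContext` are constructed by hand (the doctest runner supplies `sqlContext` automatically):

```python
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.mllib.linalg import Vectors

sc = SparkContext(appName="cv-doctest-sketch")
sqlContext = SQLContext(sc)

# Same 1-D dataset as the doctest: separable enough that one iteration of
# logistic regression (maxIter=1) reliably beats no training (maxIter=0).
dataset = sqlContext.createDataFrame(
    [(Vectors.dense([0.0]), 0.0),
     (Vectors.dense([0.4]), 1.0),
     (Vectors.dense([0.5]), 0.0),
     (Vectors.dense([0.6]), 1.0),
     (Vectors.dense([1.0]), 1.0)] * 10,
    ["features", "label"])

lr = LogisticRegression()
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
evaluator = BinaryClassificationEvaluator()  # areaUnderROC by default
cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)

cvModel = cv.fit(dataset)  # selects the better of maxIter=0 and maxIter=1
print(evaluator.evaluate(cvModel.transform(dataset)))  # ~0.8333
sc.stop()
```

Because the comparison is between "no training" and "one iteration of training", the winner is deterministic, which is what removes the flakiness of the old `maxIter=5` round-trip check.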
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/ml/tuning.py | 19
1 file changed, 9 insertions(+), 10 deletions(-)
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 497841b6c8..0bf988fd72 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -91,20 +91,19 @@ class CrossValidator(Estimator):
>>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
>>> from pyspark.mllib.linalg import Vectors
>>> dataset = sqlContext.createDataFrame(
- ... [(Vectors.dense([0.0, 1.0]), 0.0),
- ... (Vectors.dense([1.0, 2.0]), 1.0),
- ... (Vectors.dense([0.55, 3.0]), 0.0),
- ... (Vectors.dense([0.45, 4.0]), 1.0),
- ... (Vectors.dense([0.51, 5.0]), 1.0)] * 10,
+ ... [(Vectors.dense([0.0]), 0.0),
+ ... (Vectors.dense([0.4]), 1.0),
+ ... (Vectors.dense([0.5]), 0.0),
+ ... (Vectors.dense([0.6]), 1.0),
+ ... (Vectors.dense([1.0]), 1.0)] * 10,
... ["features", "label"])
>>> lr = LogisticRegression()
- >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build()
+ >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
>>> evaluator = BinaryClassificationEvaluator()
>>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
- >>> # SPARK-7432: The following test is flaky.
- >>> # cvModel = cv.fit(dataset)
- >>> # expected = lr.fit(dataset, {lr.maxIter: 5}).transform(dataset)
- >>> # cvModel.transform(dataset).collect() == expected.collect()
+ >>> cvModel = cv.fit(dataset)
+ >>> evaluator.evaluate(cvModel.transform(dataset))
+ 0.8333...
"""
# a placeholder to make it appear in the generated doc
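A note on the expected output: `0.8333...` only matches because the module's doctest runner enables `doctest.ELLIPSIS`, so the trailing `...` absorbs any platform-dependent remaining digits. A self-contained illustration of that matching behavior (the `auc_like` helper is hypothetical, standing in for the evaluator call above):

```python
# Hedged illustration of the ELLIPSIS matching that "0.8333..." relies on.
import doctest

def auc_like():
    """
    >>> auc_like()
    0.8333...
    """
    return 25.0 / 30.0  # prints as 0.8333333333333334

if __name__ == "__main__":
    # With ELLIPSIS enabled, "..." in the expected output matches the
    # remaining digits, so the comparison passes across platforms.
    doctest.testmod(optionflags=doctest.ELLIPSIS)
```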