about | summary | refs | log | tree | commit | diff
path: root/python/pyspark/ml
diff options
context:
space:
mode:
author: Xiangrui Meng <meng@databricks.com> 2015-05-03 18:06:48 -0700
committer: Xiangrui Meng <meng@databricks.com> 2015-05-03 18:06:48 -0700
commit: 1ffa8cb91f8badf12a8aa190dc25920715a00db7 (patch)
tree: 48b39ad4080062acb1d69c2c290cbc6f063e4df8 /python/pyspark/ml
parent: 9e25b09f8809378777ae8bbe75dca12d2c45ff4c (diff)
download: spark-1ffa8cb91f8badf12a8aa190dc25920715a00db7.tar.gz
          spark-1ffa8cb91f8badf12a8aa190dc25920715a00db7.tar.bz2
          spark-1ffa8cb91f8badf12a8aa190dc25920715a00db7.zip
[SPARK-7329] [MLLIB] simplify ParamGridBuilder impl
as suggested by justinuang on #5601.

Author: Xiangrui Meng <meng@databricks.com>

Closes #5873 from mengxr/SPARK-7329 and squashes the following commits:

d08f9cf [Xiangrui Meng] simplify tests
b7a7b9b [Xiangrui Meng] simplify grid build
Diffstat (limited to 'python/pyspark/ml')
-rw-r--r--  python/pyspark/ml/tuning.py  28
1 file changed, 9 insertions, 19 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index a383bd0c0d..1773ab5bdc 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -15,6 +15,8 @@
# limitations under the License.
#
+import itertools
+
__all__ = ['ParamGridBuilder']
@@ -37,14 +39,10 @@ class ParamGridBuilder(object):
{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
- >>> fail_count = 0
- >>> for e in expected:
- ... if e not in output:
- ... fail_count += 1
- >>> if len(expected) != len(output):
- ... fail_count += 1
- >>> fail_count
- 0
+ >>> len(output) == len(expected)
+ True
+ >>> all([m in expected for m in output])
+ True
"""
def __init__(self):
@@ -76,17 +74,9 @@ class ParamGridBuilder(object):
Builds and returns all combinations of parameters specified
by the param grid.
"""
- param_maps = [{}]
- for (param, values) in self._param_grid.items():
- new_param_maps = []
- for value in values:
- for old_map in param_maps:
- copied_map = old_map.copy()
- copied_map[param] = value
- new_param_maps.append(copied_map)
- param_maps = new_param_maps
-
- return param_maps
+ keys = self._param_grid.keys()
+ grid_values = self._param_grid.values()
+ return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
if __name__ == "__main__":