aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorOmede Firouz <ofirouz@palantir.com>2015-05-03 11:42:02 -0700
committerXiangrui Meng <meng@databricks.com>2015-05-03 11:42:02 -0700
commitf4af92550cb90e47a12d4625fa615dd2b1587d42 (patch)
tree59b1aaff420319f43363f5697118887057a7ac05 /python
parent49549d5a1a867c3ba25f5e4aec351d4102444bc0 (diff)
downloadspark-f4af92550cb90e47a12d4625fa615dd2b1587d42.tar.gz
spark-f4af92550cb90e47a12d4625fa615dd2b1587d42.tar.bz2
spark-f4af92550cb90e47a12d4625fa615dd2b1587d42.zip
[SPARK-7022] [PYSPARK] [ML] Add ML.Tuning.ParamGridBuilder to PySpark
Author: Omede Firouz <ofirouz@palantir.com> Author: Omede <omedefirouz@gmail.com> Closes #5601 from oefirouz/paramgrid and squashes the following commits: c9e2481 [Omede Firouz] Make test a doctest 9a8ce22 [Omede] Fix linter issues 8b8a6d2 [Omede Firouz] [SPARK-7022][PySpark][ML] Add ML.Tuning.ParamGridBuilder to PySpark
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/ml/tuning.py94
-rwxr-xr-xpython/run-tests1
2 files changed, 95 insertions, 0 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
new file mode 100644
index 0000000000..a383bd0c0d
--- /dev/null
+++ b/python/pyspark/ml/tuning.py
@@ -0,0 +1,94 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+__all__ = ['ParamGridBuilder']
+
+
+class ParamGridBuilder(object):
+ """
+ Builder for a param grid used in grid search-based model selection.
+
+ >>> from classification import LogisticRegression
+ >>> lr = LogisticRegression()
+ >>> output = ParamGridBuilder().baseOn({lr.labelCol: 'l'}) \
+ .baseOn([lr.predictionCol, 'p']) \
+ .addGrid(lr.regParam, [1.0, 2.0, 3.0]) \
+ .addGrid(lr.maxIter, [1, 5]) \
+ .addGrid(lr.featuresCol, ['f']) \
+ .build()
+ >>> expected = [ \
+{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
+{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
+{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
+{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
+{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
+{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
+ >>> fail_count = 0
+ >>> for e in expected:
+ ... if e not in output:
+ ... fail_count += 1
+ >>> if len(expected) != len(output):
+ ... fail_count += 1
+ >>> fail_count
+ 0
+ """
+
+ def __init__(self):
+ self._param_grid = {}
+
+ def addGrid(self, param, values):
+ """
+ Sets the given parameters in this grid to fixed values.
+ """
+ self._param_grid[param] = values
+
+ return self
+
+ def baseOn(self, *args):
+ """
+ Sets the given parameters in this grid to fixed values.
+ Accepts either a parameter dictionary or a list of (parameter, value) pairs.
+ """
+ if isinstance(args[0], dict):
+ self.baseOn(*args[0].items())
+ else:
+ for (param, value) in args:
+ self.addGrid(param, [value])
+
+ return self
+
+ def build(self):
+ """
+ Builds and returns all combinations of parameters specified
+ by the param grid.
+ """
+ param_maps = [{}]
+ for (param, values) in self._param_grid.items():
+ new_param_maps = []
+ for value in values:
+ for old_map in param_maps:
+ copied_map = old_map.copy()
+ copied_map[param] = value
+ new_param_maps.append(copied_map)
+ param_maps = new_param_maps
+
+ return param_maps
+
+
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod()
diff --git a/python/run-tests b/python/run-tests
index 88b63b84fd..0e0eee3564 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -98,6 +98,7 @@ function run_ml_tests() {
echo "Run ml tests ..."
run_test "pyspark/ml/feature.py"
run_test "pyspark/ml/classification.py"
+ run_test "pyspark/ml/tuning.py"
run_test "pyspark/ml/tests.py"
}