diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-05-18 12:02:18 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-05-18 12:02:18 -0700 |
commit | 9c7e802a5a2b8cd3eb77642f84c54a8e976fc996 (patch) | |
tree | 2e3b7e367f57b64ef46733ee8b64aa258e58cca8 /python/pyspark/ml/tuning.py | |
parent | 56ede88485cfca90974425fcb603b257be47229b (diff) | |
download | spark-9c7e802a5a2b8cd3eb77642f84c54a8e976fc996.tar.gz spark-9c7e802a5a2b8cd3eb77642f84c54a8e976fc996.tar.bz2 spark-9c7e802a5a2b8cd3eb77642f84c54a8e976fc996.zip |
[SPARK-7380] [MLLIB] pipeline stages should be copyable in Python
This PR makes pipeline stages in Python copyable and hence simplifies some implementations. It also includes the following changes:
1. Rename `paramMap` and `defaultParamMap` to `_paramMap` and `_defaultParamMap`, respectively.
2. Accept a list of param maps in `fit`.
3. Use parent uid and name to identify param.
jkbradley
Author: Xiangrui Meng <meng@databricks.com>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #6088 from mengxr/SPARK-7380 and squashes the following commits:
413c463 [Xiangrui Meng] remove unnecessary doc
4159f35 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7380
611c719 [Xiangrui Meng] fix python style
68862b8 [Xiangrui Meng] update _java_obj initialization
927ad19 [Xiangrui Meng] fix ml/tests.py
0138fc3 [Xiangrui Meng] update feature transformers and fix a bug in RegexTokenizer
9ca44fb [Xiangrui Meng] simplify Java wrappers and add tests
c7d84ef [Xiangrui Meng] update ml/tests.py to test copy params
7e0d27f [Xiangrui Meng] merge master
46840fb [Xiangrui Meng] update wrappers
b6db1ed [Xiangrui Meng] update all self.paramMap to self._paramMap
46cb6ed [Xiangrui Meng] merge master
a163413 [Xiangrui Meng] fix style
1042e80 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7380
9630eae [Xiangrui Meng] fix Identifiable._randomUID
13bd70a [Xiangrui Meng] update ml/tests.py
64a536c [Xiangrui Meng] use _fit/_transform/_evaluate to simplify the impl
02abf13 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into copyable-python
66ce18c [Joseph K. Bradley] some cleanups before sending to Xiangrui
7431272 [Joseph K. Bradley] Rebased with master
Diffstat (limited to 'python/pyspark/ml/tuning.py')
-rw-r--r-- | python/pyspark/ml/tuning.py | 43 |
1 files changed, 31 insertions, 12 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 86f4dc7368..497841b6c8 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -155,7 +155,7 @@ class CrossValidator(Estimator): """ Sets the value of :py:attr:`estimator`. """ - self.paramMap[self.estimator] = value + self._paramMap[self.estimator] = value return self def getEstimator(self): @@ -168,7 +168,7 @@ class CrossValidator(Estimator): """ Sets the value of :py:attr:`estimatorParamMaps`. """ - self.paramMap[self.estimatorParamMaps] = value + self._paramMap[self.estimatorParamMaps] = value return self def getEstimatorParamMaps(self): @@ -181,7 +181,7 @@ class CrossValidator(Estimator): """ Sets the value of :py:attr:`evaluator`. """ - self.paramMap[self.evaluator] = value + self._paramMap[self.evaluator] = value return self def getEvaluator(self): @@ -194,7 +194,7 @@ class CrossValidator(Estimator): """ Sets the value of :py:attr:`numFolds`. """ - self.paramMap[self.numFolds] = value + self._paramMap[self.numFolds] = value return self def getNumFolds(self): @@ -203,13 +203,12 @@ class CrossValidator(Estimator): """ return self.getOrDefault(self.numFolds) - def fit(self, dataset, params={}): - paramMap = self.extractParamMap(params) - est = paramMap[self.estimator] - epm = paramMap[self.estimatorParamMaps] + def _fit(self, dataset): + est = self.getOrDefault(self.estimator) + epm = self.getOrDefault(self.estimatorParamMaps) numModels = len(epm) - eva = paramMap[self.evaluator] - nFolds = paramMap[self.numFolds] + eva = self.getOrDefault(self.evaluator) + nFolds = self.getOrDefault(self.numFolds) h = 1.0 / nFolds randCol = self.uid + "_rand" df = dataset.select("*", rand(0).alias(randCol)) @@ -229,6 +228,15 @@ class CrossValidator(Estimator): bestModel = est.fit(dataset, epm[bestIndex]) return CrossValidatorModel(bestModel) + def copy(self, extra={}): + newCV = Params.copy(self, extra) + if self.isSet(self.estimator): + newCV.setEstimator(self.getEstimator().copy(extra)) + # estimatorParamMaps remain the same + if self.isSet(self.evaluator): + newCV.setEvaluator(self.getEvaluator().copy(extra)) + return newCV + class CrossValidatorModel(Model): """ @@ -240,8 +248,19 @@ class CrossValidatorModel(Model): #: best model from cross validation self.bestModel = bestModel - def transform(self, dataset, params={}): - return self.bestModel.transform(dataset, params) + def _transform(self, dataset): + return self.bestModel.transform(dataset) + + def copy(self, extra={}): + """ + Creates a copy of this instance with a randomly generated uid + and some extra params. This copies the underlying bestModel, + creates a deep copy of the embedded paramMap, and + copies the embedded and extra parameters over. + :param extra: Extra parameters to copy to the new instance + :return: Copy of this instance + """ + return CrossValidatorModel(self.bestModel.copy(extra)) if __name__ == "__main__": |