diff options
author | Bryan Cutler <cutlerb@gmail.com> | 2016-04-13 14:08:57 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-04-13 14:08:57 -0700 |
commit | fc3cd2f5090b3ba1cfde0fca3b3ce632d0b2f9c4 (patch) | |
tree | 6ee38d5b95cb6bc5548c6bb1b8da528aa46ce1a9 /python/pyspark/ml/tuning.py | |
parent | 781df499836e4216939e0febdcd5f89d30645759 (diff) | |
download | spark-fc3cd2f5090b3ba1cfde0fca3b3ce632d0b2f9c4.tar.gz spark-fc3cd2f5090b3ba1cfde0fca3b3ce632d0b2f9c4.tar.bz2 spark-fc3cd2f5090b3ba1cfde0fca3b3ce632d0b2f9c4.zip |
[SPARK-14472][PYSPARK][ML] Cleanup ML JavaWrapper and related class hierarchy
Currently, JavaWrapper is only a wrapper class for pipeline classes that have Params, and JavaCallable is a separate mixin that provides methods to make Java calls. This change simplifies the class structure by defining the Java wrapper in a plain base class along with methods to make Java calls. It also renames the Java wrapper classes to better reflect their purpose.
Ran existing Python ml tests and generated documentation to test this change.
Author: Bryan Cutler <cutlerb@gmail.com>
Closes #12304 from BryanCutler/pyspark-cleanup-JavaWrapper-SPARK-14472.
Diffstat (limited to 'python/pyspark/ml/tuning.py')
-rw-r--r-- | python/pyspark/ml/tuning.py | 26 |
1 files changed, 13 insertions, 13 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index ea8c61b7ef..456d79d897 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -24,7 +24,7 @@ from pyspark.ml import Estimator, Model from pyspark.ml.param import Params, Param, TypeConverters from pyspark.ml.param.shared import HasSeed from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable -from pyspark.ml.wrapper import JavaWrapper +from pyspark.ml.wrapper import JavaParams from pyspark.sql.functions import rand from pyspark.mllib.common import inherit_doc, _py2java @@ -148,8 +148,8 @@ class ValidatorParams(HasSeed): """ # Load information from java_stage to the instance. - estimator = JavaWrapper._from_java(java_stage.getEstimator()) - evaluator = JavaWrapper._from_java(java_stage.getEvaluator()) + estimator = JavaParams._from_java(java_stage.getEstimator()) + evaluator = JavaParams._from_java(java_stage.getEvaluator()) epms = [estimator._transfer_param_map_from_java(epm) for epm in java_stage.getEstimatorParamMaps()] return estimator, epms, evaluator @@ -329,7 +329,7 @@ class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable): estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl() - _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid) + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid) _java_obj.setEstimatorParamMaps(epms) _java_obj.setEvaluator(evaluator) _java_obj.setEstimator(estimator) @@ -393,7 +393,7 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable): """ # Load information from java_stage to the instance. - bestModel = JavaWrapper._from_java(java_stage.bestModel()) + bestModel = JavaParams._from_java(java_stage.bestModel()) estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage) # Create a new instance of this stage. 
py_stage = cls(bestModel=bestModel)\ @@ -410,10 +410,10 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable): sc = SparkContext._active_spark_context - _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel", - self.uid, - self.bestModel._to_java(), - _py2java(sc, [])) + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel", + self.uid, + self.bestModel._to_java(), + _py2java(sc, [])) estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl() _java_obj.set("evaluator", evaluator) @@ -574,8 +574,8 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable): estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl() - _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit", - self.uid) + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit", + self.uid) _java_obj.setEstimatorParamMaps(epms) _java_obj.setEvaluator(evaluator) _java_obj.setEstimator(estimator) @@ -639,7 +639,7 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable): """ # Load information from java_stage to the instance. - bestModel = JavaWrapper._from_java(java_stage.bestModel()) + bestModel = JavaParams._from_java(java_stage.bestModel()) estimator, epms, evaluator = \ super(TrainValidationSplitModel, cls)._from_java_impl(java_stage) # Create a new instance of this stage. @@ -657,7 +657,7 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable): sc = SparkContext._active_spark_context - _java_obj = JavaWrapper._new_java_obj( + _java_obj = JavaParams._new_java_obj( "org.apache.spark.ml.tuning.TrainValidationSplitModel", self.uid, self.bestModel._to_java(), |