 python/pyspark/ml/tuning.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)
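
The patch is a mechanical rename inside pyspark/ml/tuning.py: the Java-interop helpers are now reached through pyspark.ml.wrapper.JavaParams instead of JavaWrapper, with no behavioural change to the _to_java/_from_java conversions. The sketch below (not part of the patch; it assumes a local SparkContext, relies on the internal pyspark.ml.wrapper API, and uses made-up uid strings) illustrates the two helpers the diff switches to:

    # Minimal sketch (not part of the patch): how the JavaParams helpers that
    # this diff switches to are used.  Assumes a local SparkContext so the JVM
    # gateway is available; _new_java_obj and _from_java are internal
    # pyspark.ml.wrapper APIs, and the uid strings here are hypothetical.
    from pyspark import SparkContext
    from pyspark.ml.wrapper import JavaParams

    sc = SparkContext("local[1]", "javaparams-sketch")

    # _new_java_obj reflectively constructs a JVM-side object, as
    # CrossValidator._to_java does in the hunk around line 329.
    java_cv = JavaParams._new_java_obj(
        "org.apache.spark.ml.tuning.CrossValidator", "example_cv_uid")

    # _from_java wraps an existing Java ML stage back into its Python
    # counterpart, as the _from_java_impl hunks do for the nested
    # estimator and evaluator.
    java_eval = JavaParams._new_java_obj(
        "org.apache.spark.ml.evaluation.RegressionEvaluator", "example_eval_uid")
    py_eval = JavaParams._from_java(java_eval)
    print(type(py_eval))  # <class 'pyspark.ml.evaluation.RegressionEvaluator'>

    sc.stop()
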
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index ea8c61b7ef..456d79d897 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -24,7 +24,7 @@ from pyspark.ml import Estimator, Model
from pyspark.ml.param import Params, Param, TypeConverters
from pyspark.ml.param.shared import HasSeed
from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
from pyspark.sql.functions import rand
from pyspark.mllib.common import inherit_doc, _py2java
@@ -148,8 +148,8 @@ class ValidatorParams(HasSeed):
"""
# Load information from java_stage to the instance.
- estimator = JavaWrapper._from_java(java_stage.getEstimator())
- evaluator = JavaWrapper._from_java(java_stage.getEvaluator())
+ estimator = JavaParams._from_java(java_stage.getEstimator())
+ evaluator = JavaParams._from_java(java_stage.getEvaluator())
epms = [estimator._transfer_param_map_from_java(epm)
for epm in java_stage.getEstimatorParamMaps()]
return estimator, epms, evaluator
@@ -329,7 +329,7 @@ class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable):
estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl()
- _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
_java_obj.setEstimatorParamMaps(epms)
_java_obj.setEvaluator(evaluator)
_java_obj.setEstimator(estimator)
@@ -393,7 +393,7 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
"""
# Load information from java_stage to the instance.
- bestModel = JavaWrapper._from_java(java_stage.bestModel())
+ bestModel = JavaParams._from_java(java_stage.bestModel())
estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
# Create a new instance of this stage.
py_stage = cls(bestModel=bestModel)\
@@ -410,10 +410,10 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
sc = SparkContext._active_spark_context
- _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
- self.uid,
- self.bestModel._to_java(),
- _py2java(sc, []))
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
+ self.uid,
+ self.bestModel._to_java(),
+ _py2java(sc, []))
estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()
_java_obj.set("evaluator", evaluator)
@@ -574,8 +574,8 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable):
estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl()
- _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
- self.uid)
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
+ self.uid)
_java_obj.setEstimatorParamMaps(epms)
_java_obj.setEvaluator(evaluator)
_java_obj.setEstimator(estimator)
@@ -639,7 +639,7 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
"""
# Load information from java_stage to the instance.
- bestModel = JavaWrapper._from_java(java_stage.bestModel())
+ bestModel = JavaParams._from_java(java_stage.bestModel())
estimator, epms, evaluator = \
super(TrainValidationSplitModel, cls)._from_java_impl(java_stage)
# Create a new instance of this stage.
@@ -657,7 +657,7 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
sc = SparkContext._active_spark_context
- _java_obj = JavaWrapper._new_java_obj(
+ _java_obj = JavaParams._new_java_obj(
"org.apache.spark.ml.tuning.TrainValidationSplitModel",
self.uid,
self.bestModel._to_java(),