about summary refs log tree commit diff
path: root/python/pyspark/ml/pipeline.py
diff options
context:
space:
mode:
author: Bryan Cutler <cutlerb@gmail.com> 2016-04-13 14:08:57 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2016-04-13 14:08:57 -0700
commitfc3cd2f5090b3ba1cfde0fca3b3ce632d0b2f9c4 (patch)
tree6ee38d5b95cb6bc5548c6bb1b8da528aa46ce1a9 /python/pyspark/ml/pipeline.py
parent781df499836e4216939e0febdcd5f89d30645759 (diff)
downloadspark-fc3cd2f5090b3ba1cfde0fca3b3ce632d0b2f9c4.tar.gz
spark-fc3cd2f5090b3ba1cfde0fca3b3ce632d0b2f9c4.tar.bz2
spark-fc3cd2f5090b3ba1cfde0fca3b3ce632d0b2f9c4.zip
[SPARK-14472][PYSPARK][ML] Cleanup ML JavaWrapper and related class hierarchy
Currently, JavaWrapper is only a wrapper class for pipeline classes that have Params and JavaCallable is a separate mixin that provides methods to make Java calls. This change simplifies the class structure and to define the Java wrapper in a plain base class along with methods to make Java calls. Also, renames Java wrapper classes to better reflect their purpose. Ran existing Python ml tests and generated documentation to test this change. Author: Bryan Cutler <cutlerb@gmail.com> Closes #12304 from BryanCutler/pyspark-cleanup-JavaWrapper-SPARK-14472.
Diffstat (limited to 'python/pyspark/ml/pipeline.py')
-rw-r--r-- python/pyspark/ml/pipeline.py | 10
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 2b5504bc29..9d654e8b0f 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -25,7 +25,7 @@ from pyspark import since
from pyspark.ml import Estimator, Model, Transformer
from pyspark.ml.param import Param, Params
from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
from pyspark.mllib.common import inherit_doc
@@ -177,7 +177,7 @@ class Pipeline(Estimator, MLReadable, MLWritable):
# Create a new instance of this stage.
py_stage = cls()
# Load information from java_stage to the instance.
- py_stages = [JavaWrapper._from_java(s) for s in java_stage.getStages()]
+ py_stages = [JavaParams._from_java(s) for s in java_stage.getStages()]
py_stage.setStages(py_stages)
py_stage._resetUid(java_stage.uid())
return py_stage
@@ -195,7 +195,7 @@ class Pipeline(Estimator, MLReadable, MLWritable):
for idx, stage in enumerate(self.getStages()):
java_stages[idx] = stage._to_java()
- _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
_java_obj.setStages(java_stages)
return _java_obj
@@ -275,7 +275,7 @@ class PipelineModel(Model, MLReadable, MLWritable):
Used for ML persistence.
"""
# Load information from java_stage to the instance.
- py_stages = [JavaWrapper._from_java(s) for s in java_stage.stages()]
+ py_stages = [JavaParams._from_java(s) for s in java_stage.stages()]
# Create a new instance of this stage.
py_stage = cls(py_stages)
py_stage._resetUid(java_stage.uid())
@@ -295,6 +295,6 @@ class PipelineModel(Model, MLReadable, MLWritable):
java_stages[idx] = stage._to_java()
_java_obj =\
- JavaWrapper._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
+ JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
return _java_obj