diff options
author | Bryan Cutler <cutlerb@gmail.com> | 2016-04-06 12:07:47 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-04-06 12:07:47 -0700 |
commit | 9c6556c5f8ab013b36312db4bf02c4c6d965a535 (patch) | |
tree | e4200c088c376f26f27de4f3a96c99006dd99b20 /python/pyspark/ml/wrapper.py | |
parent | bb1fa5b2182f384cb711fc2be45b0f1a8c466ed6 (diff) | |
download | spark-9c6556c5f8ab013b36312db4bf02c4c6d965a535.tar.gz spark-9c6556c5f8ab013b36312db4bf02c4c6d965a535.tar.bz2 spark-9c6556c5f8ab013b36312db4bf02c4c6d965a535.zip |
[SPARK-13430][PYSPARK][ML] Python API for training summaries of linear and logistic regression
## What changes were proposed in this pull request?
Adding Python API for training summaries of LogisticRegression and LinearRegression in PySpark ML.
## How was this patch tested?
Added unit tests to exercise the api calls for the summary classes. Also, manually verified values are expected and match those from Scala directly.
Author: Bryan Cutler <cutlerb@gmail.com>
Closes #11621 from BryanCutler/pyspark-ml-summary-SPARK-13430.
Diffstat (limited to 'python/pyspark/ml/wrapper.py')
-rw-r--r-- | python/pyspark/ml/wrapper.py | 30 |
1 files changed, 23 insertions, 7 deletions
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index ca93bf7d7d..a2cf2296fb 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -213,8 +213,30 @@ class JavaTransformer(Transformer, JavaWrapper): return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx) +class JavaCallable(object): + """ + Wrapper for a plain object in JVM to make Java calls, can be used + as a mixin to another class that defines a _java_obj wrapper + """ + def __init__(self, java_obj=None, sc=None): + super(JavaCallable, self).__init__() + self._sc = sc if sc is not None else SparkContext._active_spark_context + # if this class is a mixin and _java_obj is already defined then don't initialize + if java_obj is not None or not hasattr(self, "_java_obj"): + self._java_obj = java_obj + + def __del__(self): + if self._java_obj is not None: + self._sc._gateway.detach(self._java_obj) + + def _call_java(self, name, *args): + m = getattr(self._java_obj, name) + java_args = [_py2java(self._sc, arg) for arg in args] + return _java2py(self._sc, m(*java_args)) + + @inherit_doc -class JavaModel(Model, JavaTransformer): +class JavaModel(Model, JavaCallable, JavaTransformer): """ Base class for :py:class:`Model`s that wrap Java/Scala implementations. Subclasses should inherit this class before @@ -259,9 +281,3 @@ class JavaModel(Model, JavaTransformer): that._java_obj = self._java_obj.copy(self._empty_java_param_map()) that._transfer_params_to_java() return that - - def _call_java(self, name, *args): - m = getattr(self._java_obj, name) - sc = SparkContext._active_spark_context - java_args = [_py2java(sc, arg) for arg in args] - return _java2py(sc, m(*java_args)) |