aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/wrapper.py
diff options
context:
space:
mode:
authorBryan Cutler <cutlerb@gmail.com>2016-04-06 12:07:47 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-06 12:07:47 -0700
commit9c6556c5f8ab013b36312db4bf02c4c6d965a535 (patch)
treee4200c088c376f26f27de4f3a96c99006dd99b20 /python/pyspark/ml/wrapper.py
parentbb1fa5b2182f384cb711fc2be45b0f1a8c466ed6 (diff)
downloadspark-9c6556c5f8ab013b36312db4bf02c4c6d965a535.tar.gz
spark-9c6556c5f8ab013b36312db4bf02c4c6d965a535.tar.bz2
spark-9c6556c5f8ab013b36312db4bf02c4c6d965a535.zip
[SPARK-13430][PYSPARK][ML] Python API for training summaries of linear and logistic regression
## What changes were proposed in this pull request? Adding Python API for training summaries of LogisticRegression and LinearRegression in PySpark ML. ## How was this patch tested? Added unit tests to exercise the api calls for the summary classes. Also, manually verified values are expected and match those from Scala directly. Author: Bryan Cutler <cutlerb@gmail.com> Closes #11621 from BryanCutler/pyspark-ml-summary-SPARK-13430.
Diffstat (limited to 'python/pyspark/ml/wrapper.py')
-rw-r--r--python/pyspark/ml/wrapper.py30
1 files changed, 23 insertions, 7 deletions
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index ca93bf7d7d..a2cf2296fb 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -213,8 +213,30 @@ class JavaTransformer(Transformer, JavaWrapper):
return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx)
+class JavaCallable(object):
+ """
+ Wrapper for a plain object in JVM to make Java calls, can be used
+ as a mixin to another class that defines a _java_obj wrapper
+ """
+ def __init__(self, java_obj=None, sc=None):
+ super(JavaCallable, self).__init__()
+ self._sc = sc if sc is not None else SparkContext._active_spark_context
+ # if this class is a mixin and _java_obj is already defined then don't initialize
+ if java_obj is not None or not hasattr(self, "_java_obj"):
+ self._java_obj = java_obj
+
+ def __del__(self):
+ if self._java_obj is not None:
+ self._sc._gateway.detach(self._java_obj)
+
+ def _call_java(self, name, *args):
+ m = getattr(self._java_obj, name)
+ java_args = [_py2java(self._sc, arg) for arg in args]
+ return _java2py(self._sc, m(*java_args))
+
+
@inherit_doc
-class JavaModel(Model, JavaTransformer):
+class JavaModel(Model, JavaCallable, JavaTransformer):
"""
Base class for :py:class:`Model`s that wrap Java/Scala
implementations. Subclasses should inherit this class before
@@ -259,9 +281,3 @@ class JavaModel(Model, JavaTransformer):
that._java_obj = self._java_obj.copy(self._empty_java_param_map())
that._transfer_params_to_java()
return that
-
- def _call_java(self, name, *args):
- m = getattr(self._java_obj, name)
- sc = SparkContext._active_spark_context
- java_args = [_py2java(sc, arg) for arg in args]
- return _java2py(sc, m(*java_args))