diff options
Diffstat (limited to 'python/pyspark/ml/wrapper.py')
-rw-r--r-- | python/pyspark/ml/wrapper.py | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index f5ac2a3986..dda6c6aba3 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -21,7 +21,7 @@ from pyspark import SparkContext from pyspark.sql import DataFrame from pyspark.ml.param import Params from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model -from pyspark.mllib.common import inherit_doc +from pyspark.mllib.common import inherit_doc, _java2py, _py2java def _jvm(): @@ -149,6 +149,12 @@ class JavaModel(Model, JavaTransformer): def _java_obj(self): return self._java_model + def _call_java(self, name, *args): + m = getattr(self._java_model, name) + sc = SparkContext._active_spark_context + java_args = [_py2java(sc, arg) for arg in args] + return _java2py(sc, m(*java_args)) + @inherit_doc class JavaEvaluator(Evaluator, JavaWrapper): |