#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from abc import ABCMeta

from pyspark import SparkContext
from pyspark.sql import DataFrame
from pyspark.ml.param import Params
from pyspark.ml.pipeline import Estimator, Transformer
from pyspark.mllib.common import inherit_doc


def _jvm():
    """
    Returns the JVM view associated with SparkContext. Must be called
    after SparkContext is initialized.
    """
    jvm = SparkContext._jvm
    if jvm:
        return jvm
    else:
        raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")


@inherit_doc
class JavaWrapper(Params):
    """
    Utility class to help create wrapper classes from Java/Scala
    implementations of pipeline components.
    """

    __metaclass__ = ABCMeta

    #: Fully-qualified class name of the wrapped Java component.
    _java_class = None

    def _java_obj(self):
        """
        Returns or creates a Java object.
        """
        # Walk the package path on the JVM view, then instantiate the class.
        java_obj = _jvm()
        for name in self._java_class.split("."):
            java_obj = getattr(java_obj, name)
        return java_obj()

    def _transfer_params_to_java(self, params, java_obj):
        """
        Transfers the embedded params and additional params to the
        input Java object.

        :param params: additional params (overwriting embedded values)
        :param java_obj: Java object to receive the params
        """
        paramMap = self._merge_params(params)
        for param in self.params:
            if param in paramMap:
                java_obj.set(param.name, paramMap[param])

    def _empty_java_param_map(self):
        """
        Returns an empty Java ParamMap reference.
        """
        return _jvm().org.apache.spark.ml.param.ParamMap()

    def _create_java_param_map(self, params, java_obj):
        """
        Creates a Java ParamMap from a Python dict of params, resolving
        each param name against the input Java object.
        """
        paramMap = self._empty_java_param_map()
        for param, value in params.items():
            if param.parent is self:
                paramMap.put(java_obj.getParam(param.name), value)
        return paramMap


@inherit_doc
class JavaEstimator(Estimator, JavaWrapper):
    """
    Base class for :py:class:`Estimator`s that wrap Java/Scala
    implementations.
    """

    __metaclass__ = ABCMeta

    def _create_model(self, java_model):
        """
        Creates a model from the input Java model reference.
        """
        return JavaModel(java_model)

    def _fit_java(self, dataset, params={}):
        """
        Fits a Java model to the input dataset.

        :param dataset: input dataset, which is an instance of
                        :py:class:`pyspark.sql.DataFrame`
        :param params: additional params (overwriting embedded values)
        :return: fitted Java model
        """
        java_obj = self._java_obj()
        self._transfer_params_to_java(params, java_obj)
        return java_obj.fit(dataset._jdf, self._empty_java_param_map())

    def fit(self, dataset, params={}):
        java_model = self._fit_java(dataset, params)
        return self._create_model(java_model)


@inherit_doc
class JavaTransformer(Transformer, JavaWrapper):
    """
    Base class for :py:class:`Transformer`s that wrap Java/Scala
    implementations.
""" __metaclass__ = ABCMeta def transform(self, dataset, params={}): java_obj = self._java_obj() self._transfer_params_to_java({}, java_obj) java_param_map = self._create_java_param_map(params, java_obj) return DataFrame(java_obj.transform(dataset._jdf, java_param_map), dataset.sql_ctx) @inherit_doc class JavaModel(JavaTransformer): """ Base class for :py:class:`Model`s that wrap Java/Scala implementations. """ __metaclass__ = ABCMeta def __init__(self, java_model): super(JavaTransformer, self).__init__() self._java_model = java_model def _java_obj(self): return self._java_model