aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/wrapper.py
blob: 9e12ddc3d9b8f2e2557a66bdb73ebb074e5afcd8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from abc import ABCMeta

from pyspark import SparkContext
from pyspark.sql import DataFrame
from pyspark.ml.param import Params
from pyspark.ml.pipeline import Estimator, Transformer
from pyspark.ml.util import inherit_doc


def _jvm():
    """
    Returns the JVM view associated with SparkContext. Must be called
    after SparkContext is initialized.
    """
    # Guard clause: fail fast when the context has not been started yet.
    jvm = SparkContext._jvm
    if not jvm:
        raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
    return jvm


@inherit_doc
class JavaWrapper(Params):
    """
    Utility class to help create wrapper classes from Java/Scala
    implementations of pipeline components.
    """

    __metaclass__ = ABCMeta

    #: Fully-qualified class name of the wrapped Java component.
    _java_class = None

    def _java_obj(self):
        """
        Returns or creates a Java object.
        """
        # Walk the dotted class name attribute-by-attribute from the JVM
        # root, then instantiate the resolved class.
        target = _jvm()
        for piece in self._java_class.split("."):
            target = getattr(target, piece)
        return target()

    def _transfer_params_to_java(self, params, java_obj):
        """
        Transforms the embedded params and additional params to the
        input Java object.
        :param params: additional params (overwriting embedded values)
        :param java_obj: Java object to receive the params
        """
        merged = self._merge_params(params)
        for p in self.params:
            if p not in merged:
                continue
            java_obj.set(p.name, merged[p])

    def _empty_java_param_map(self):
        """
        Returns an empty Java ParamMap reference.
        """
        return _jvm().org.apache.spark.ml.param.ParamMap()

    def _create_java_param_map(self, params, java_obj):
        """
        Builds a Java ParamMap from the given Python params, keeping only
        those owned by this instance.
        """
        jmap = self._empty_java_param_map()
        for p, value in params.items():
            if p.parent is self:
                jmap.put(java_obj.getParam(p.name), value)
        return jmap


@inherit_doc
class JavaEstimator(Estimator, JavaWrapper):
    """
    Base class for :py:class:`Estimator`s that wrap Java/Scala
    implementations.
    """

    __metaclass__ = ABCMeta

    def _create_model(self, java_model):
        """
        Creates a model from the input Java model reference.
        """
        return JavaModel(java_model)

    def _fit_java(self, dataset, params=None):
        """
        Fits a Java model to the input dataset.
        :param dataset: input dataset, which is an instance of
                        :py:class:`pyspark.sql.DataFrame`
        :param params: additional params (overwriting embedded values)
        :return: fitted Java model
        """
        # `None` (not `{}`) as the default avoids the shared mutable
        # default-argument pitfall; an empty dict means "no extra params".
        params = params if params is not None else {}
        java_obj = self._java_obj()
        self._transfer_params_to_java(params, java_obj)
        return java_obj.fit(dataset._jdf, self._empty_java_param_map())

    def fit(self, dataset, params=None):
        """
        Fits a Java model to the input dataset and wraps it via
        :py:meth:`_create_model`.
        :param dataset: input dataset, which is an instance of
                        :py:class:`pyspark.sql.DataFrame`
        :param params: additional params (overwriting embedded values)
        :return: fitted model
        """
        # _fit_java normalizes a None params itself.
        java_model = self._fit_java(dataset, params)
        return self._create_model(java_model)


@inherit_doc
class JavaTransformer(Transformer, JavaWrapper):
    """
    Base class for :py:class:`Transformer`s that wrap Java/Scala
    implementations.
    """

    __metaclass__ = ABCMeta

    def transform(self, dataset, params=None):
        """
        Transforms the input dataset with the wrapped Java transformer.
        :param dataset: input dataset, which is an instance of
                        :py:class:`pyspark.sql.DataFrame`
        :param params: additional params (overwriting embedded values)
        :return: transformed dataset as a :py:class:`pyspark.sql.DataFrame`
        """
        # `None` (not `{}`) as the default avoids the shared mutable
        # default-argument pitfall; an empty dict means "no extra params".
        params = params if params is not None else {}
        java_obj = self._java_obj()
        # Embedded params are applied via setters; the extra params are
        # passed to the Java side as a separate ParamMap.
        self._transfer_params_to_java({}, java_obj)
        java_param_map = self._create_java_param_map(params, java_obj)
        return DataFrame(java_obj.transform(dataset._jdf, java_param_map),
                         dataset.sql_ctx)


@inherit_doc
class JavaModel(JavaTransformer):
    """
    Base class for :py:class:`Model`s that wrap Java/Scala
    implementations.
    """

    __metaclass__ = ABCMeta

    def __init__(self, java_model):
        """
        :param java_model: reference to the fitted Java model to wrap
        """
        # Bug fix: was `super(JavaTransformer, self)`, which starts the MRO
        # lookup *after* JavaTransformer and therefore skips its __init__.
        super(JavaModel, self).__init__()
        self._java_model = java_model

    def _java_obj(self):
        """
        Returns the wrapped (already fitted) Java model instead of
        constructing a fresh Java object as the base class does.
        """
        return self._java_model