aboutsummaryrefslogtreecommitdiff
path: root/beliefs/inference/belief_propagation.py
blob: e6e7b18aeedfa5791aa55c39ae639c931f61ab61 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import numpy as np
from collections import namedtuple
import logging

from beliefs.models.belief_update_node_model import (
    InvalidLambdaMsgToParent,
    BeliefUpdateNodeModel
)
from beliefs.utils.math_helper import is_kronecker_delta


logger = logging.getLogger(name=__name__)


# One pending message in belief propagation: the node that must be updated
# (msg_receiver) and the neighbour that triggered the update (msg_sender).
# msg_sender is None when the update comes from pinning a node to evidence.
MsgPassers = namedtuple(typename='MsgPassers', field_names='msg_receiver msg_sender')


class ConflictingEvidenceError(Exception):
    """Failed to run belief propagation on label graph because of conflicting evidence.

    Attributes
        evidence: the evidence object passed to the constructor, kept so that
            callers can inspect it programmatically instead of parsing the
            error message.
    """
    def __init__(self, evidence):
        message = (
            "Can't run belief propagation with conflicting evidence: {}"
            .format(evidence)
        )
        super().__init__(message)
        self.evidence = evidence


class BeliefPropagation:
    """Run Pearl-style belief propagation over a BeliefUpdateNodeModel."""

    def __init__(self, model, inplace=True):
        """
        Args
            model: an instance of BeliefUpdateNodeModel
            inplace: bool,
                modify in-place the nodes in the model during belief propagation.
                When exactly False, a copy obtained from model.copy() is used
                instead, so the caller's model is not touched.

        Raises
            TypeError: if model is not a BeliefUpdateNodeModel.
        """
        if not isinstance(model, BeliefUpdateNodeModel):
            raise TypeError("Model must be an instance of BeliefUpdateNodeModel")
        if inplace is False:
            self.model = model.copy()
        else:
            self.model = model

    def _belief_propagation(self, nodes_to_update, evidence):
        """
        Implementation of Pearl's belief propagation algorithm for polytrees.
        ref: "Fusion, Propagation, and Structuring in Belief Networks"
             Artificial Intelligence 29 (1986) 241-288

        Processes pending messages until none remain. For each popped
        (receiver, sender) pair, the receiver recomputes its pi/lambda
        aggregates (unless it is an evidence node), then sends lambda messages
        to every parent and pi messages to every child except the sender; each
        recipient is queued for its own update.

        Args
            nodes_to_update: set,
                 set of MsgPassers namedtuples. It is consumed (emptied) and
                 refilled in place while messages propagate.
            evidence: dict,
                 a dict key, value pair as {var: state_of_var observed}

        Raises
            ConflictingEvidenceError: if a node cannot compute a valid lambda
                message for one of its parents.
        """
        # Iterative work-list instead of one recursive call per message, so
        # large graphs cannot exhaust Python's recursion limit.
        while nodes_to_update:
            node_to_update_label_id, msg_sender_label_id = nodes_to_update.pop()
            logger.debug("Node: %s", node_to_update_label_id)

            node = self.model.nodes_dict[node_to_update_label_id]

            # exclude the message sender (either a parent or child) from getting an
            # outgoing msg from the node to update
            parent_ids = set(node.parents) - {msg_sender_label_id}
            child_ids = set(node.children) - {msg_sender_label_id}
            logger.debug("parent_ids: %s", str(parent_ids))
            logger.debug("child_ids: %s", str(child_ids))

            if msg_sender_label_id is not None:
                # update triggered by receiving a message, not pinning to evidence:
                # the sender must be exactly one of this node's neighbours
                assert len(node.parents) + len(node.children) - 1 == len(parent_ids) + len(child_ids)

            # Evidence nodes keep the lambda_agg that was pinned for them.
            if node_to_update_label_id not in evidence:
                node.compute_and_update_pi_agg()
                logger.debug("belief propagation pi_agg: %s", np.array2string(node.pi_agg.values))
                node.compute_and_update_lambda_agg()
                logger.debug("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg.values))

            for parent_id in parent_ids:
                try:
                    new_lambda_msg = node.compute_lambda_msg_to_parent(parent_k=parent_id)
                except InvalidLambdaMsgToParent as err:
                    raise ConflictingEvidenceError(evidence=evidence) from err

                parent_node = self.model.nodes_dict[parent_id]
                parent_node.update_lambda_msg_from_child(child=node_to_update_label_id,
                                                         new_value=new_lambda_msg)
                nodes_to_update.add(MsgPassers(msg_receiver=parent_id,
                                               msg_sender=node_to_update_label_id))

            for child_id in child_ids:
                new_pi_msg = node.compute_pi_msg_to_child(child_k=child_id)
                child_node = self.model.nodes_dict[child_id]
                child_node.update_pi_msg_from_parent(parent=node_to_update_label_id,
                                                     new_value=new_pi_msg)
                nodes_to_update.add(MsgPassers(msg_receiver=child_id,
                                               msg_sender=node_to_update_label_id))

    def initialize_model(self):
        """
        1. Apply boundary conditions:
            - Set pi_agg equal to prior probabilities for root nodes.
            - Set lambda_agg equal to vector of ones for leaf nodes.

        2. Set lambda_agg, lambda_received_msgs to vectors of ones (same effect as
           actually passing lambda messages up from leaf nodes to root nodes).
        3. Calculate pi_agg and pi_received_msgs for all nodes without evidence.
           (Without evidence, belief equals pi_agg.)
        """
        self.model.set_boundary_conditions()

        for node in self.model.nodes_dict.values():
            ones_vector = np.ones([node.cardinality])
            node.update_lambda_agg(ones_vector)

            for child in node.lambda_received_msgs.keys():
                node.update_lambda_msg_from_child(child=child,
                                                  new_value=ones_vector)
        logger.debug("Finished initializing Lambda(x) and lambda_received_msgs per node.")

        logger.debug("Start downward sweep from nodes.  Sending Pi messages only.")
        topdown_order = self.model.get_topologically_sorted_nodes(reverse=False)

        for node_id in topdown_order:
            logger.debug('label in iteration through top-down order: %s', str(node_id))

            node_sending_msg = self.model.nodes_dict[node_id]
            child_ids = node_sending_msg.children

            # Only compute pi_agg if nothing (e.g. boundary conditions) set it yet.
            if node_sending_msg.pi_agg.values is None:
                node_sending_msg.compute_and_update_pi_agg()

            for child_id in child_ids:
                logger.debug("child: %s", str(child_id))
                new_pi_msg = node_sending_msg.compute_pi_msg_to_child(child_k=child_id)
                logger.debug("new_pi_msg: %s", np.array2string(new_pi_msg))

                child_node = self.model.nodes_dict[child_id]
                child_node.update_pi_msg_from_parent(parent=node_id,
                                                     new_value=new_pi_msg)

    def _run_belief_propagation(self, evidence):
        """
        Sequentially perturb nodes with observed values, running belief propagation
        after each perturbation.

        Args
            evidence: dict,
                a dict key, value pair as {var: state_of_var observed}

        Raises
            KeyError: if any evidence label is not a node of the model. All
                labels are checked before any evidence is applied, so a bad key
                does not leave the model partially updated.
        """
        for evidence_id in evidence:
            if evidence_id not in self.model.nodes_dict:
                raise KeyError("Evidence supplied for non-existent label_id: {}"
                               .format(evidence_id))

        for evidence_id, observed_value in evidence.items():
            evidence_node = self.model.nodes_dict[evidence_id]
            if is_kronecker_delta(observed_value):
                # specific evidence: lambda_agg is replaced by the observation
                evidence_node.update_lambda_agg(observed_value)
            else:
                # virtual evidence: current lambda_agg is scaled elementwise
                evidence_node.update_lambda_agg(
                    evidence_node.lambda_agg.values * observed_value
                )
            nodes_to_update = {MsgPassers(msg_receiver=evidence_id, msg_sender=None)}
            self._belief_propagation(nodes_to_update=nodes_to_update, evidence=evidence)

    def query(self, evidence=None):
        """
        Run belief propagation given 0 or more pieces of evidence.

        Args
            evidence: dict or None,
                a dict key, value pair as {var: state_of_var observed},
                e.g. {'3': np.array([0,1])} if label '3' is True.
                None (the default) or an empty dict means no evidence.

        Returns
            a dict key, value pair as {var: belief}, where belief is an np.array of the
            marginal probability of each state of the variable given the evidence.

        Raises
            KeyError: if evidence names a label that is not in the model.
            ConflictingEvidenceError: if the evidence is contradictory.

        Example
        -------
        >> import numpy as np
        >> from beliefs.inference.belief_propagation import BeliefPropagation
        >> from beliefs.models.belief_update_node_model import BeliefUpdateNodeModel, BernoulliOrNode
        >> edges = [('1', '3'), ('2', '3'), ('3', '5')]
        >> model = BeliefUpdateNodeModel.init_from_edges(edges, BernoulliOrNode)
        >> infer = BeliefPropagation(model)
        >> result = infer.query(evidence={'2': np.array([0, 1])})
        """
        # None instead of a shared mutable {} default.
        if evidence is None:
            evidence = {}

        if not self.model.all_nodes_are_fully_initialized:
            self.initialize_model()

        if evidence:
            self._run_belief_propagation(evidence)

        return {label_id: node.belief for label_id, node in self.model.nodes_dict.items()}