diff options
Diffstat (limited to 'beliefs/inference/belief_propagation.py')
-rw-r--r-- | beliefs/inference/belief_propagation.py | 201 |
1 files changed, 201 insertions, 0 deletions
diff --git a/beliefs/inference/belief_propagation.py b/beliefs/inference/belief_propagation.py new file mode 100644 index 0000000..7ec648d --- /dev/null +++ b/beliefs/inference/belief_propagation.py @@ -0,0 +1,201 @@ +import numpy as np +from collections import namedtuple +import logging + +from beliefs.models.belief_update_node_model import ( + InvalidLambdaMsgToParent, + BeliefUpdateNodeModel +) +from beliefs.utils.math_helper import is_kronecker_delta + + +logger = logging.getLogger(__name__) + + +MsgPassers = namedtuple('MsgPassers', ['msg_receiver', 'msg_sender']) + + +class ConflictingEvidenceError(Exception): + """Failed to run belief propagation on label graph because of conflicting evidence.""" + def __init__(self, evidence): + message = ( + "Can't run belief propagation with conflicting evidence: {}" + .format(evidence) + ) + super().__init__(message) + + +class BeliefPropagation: + def __init__(self, model, inplace=True): + """ + Input: + model: an instance of BeliefUpdateNodeModel + inplace: bool + modify in-place the nodes in the model during belief propagation + """ + if not isinstance(model, BeliefUpdateNodeModel): + raise TypeError("Model must be an instance of BeliefUpdateNodeModel") + if inplace is False: + self.model = model.copy() + else: + self.model = model + + def _belief_propagation(self, nodes_to_update, evidence): + """ + Implementation of Pearl's belief propagation algorithm for polytrees. + + ref: "Fusion, Propagation, and Structuring in Belief Networks" + Artificial Intelligence 29 (1986) 241-288 + + Input: + nodes_to_update: list + list of MsgPasser namedtuples. + evidence: dict, + a dict key, value pair as {var: state_of_var observed} + """ + if len(nodes_to_update) == 0: + return + + node_to_update_label_id, msg_sender_label_id = nodes_to_update.pop() + logging.info("Node: %s", node_to_update_label_id) + + node = self.model.nodes_dict[node_to_update_label_id] + + # exclude the message sender (either a parent or child) from getting an + # outgoing msg from the node to update + parent_ids = set(node.parents) - set([msg_sender_label_id]) + child_ids = set(node.children) - set([msg_sender_label_id]) + logging.info("parent_ids: %s", str(parent_ids)) + logging.info("child_ids: %s", str(child_ids)) + + if msg_sender_label_id is not None: + # update triggered by receiving a message, not pinning to evidence + assert len(node.parents) + len(node.children) - 1 == len(parent_ids) + len(child_ids) + + if node_to_update_label_id not in evidence: + node.compute_pi_agg() + logging.info("belief propagation pi_agg: %s", np.array2string(node.pi_agg)) + node.compute_lambda_agg() + logging.info("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg)) + + for parent_id in parent_ids: + try: + new_lambda_msg = node.compute_lambda_msg_to_parent(parent_k=parent_id) + except InvalidLambdaMsgToParent: + raise ConflictingEvidenceError(evidence=evidence) + + parent_node = self.model.nodes_dict[parent_id] + parent_node.update_lambda_msg_from_child(child=node_to_update_label_id, + new_value=new_lambda_msg) + nodes_to_update.add(MsgPassers(msg_receiver=parent_id, + msg_sender=node_to_update_label_id)) + + for child_id in child_ids: + new_pi_msg = node.compute_pi_msg_to_child(child_k=child_id) + child_node = self.model.nodes_dict[child_id] + child_node.update_pi_msg_from_parent(parent=node_to_update_label_id, + new_value=new_pi_msg) + nodes_to_update.add(MsgPassers(msg_receiver=child_id, + msg_sender=node_to_update_label_id)) + + self._belief_propagation(nodes_to_update, evidence) + + def initialize_model(self): + """ + Apply boundary conditions: + - Set pi_agg equal to prior probabilities for root nodes. + - Set lambda_agg equal to vector of ones for leaf nodes. + + - Set lambda_agg, lambda_received_msgs to vectors of ones (same effect as + actually passing lambda messages up from leaf nodes to root nodes). + - Calculate pi_agg and pi_received_msgs for all nodes without evidence. + (Without evidence, belief equals pi_agg.) + """ + self.model.set_boundary_conditions() + + for node in self.model.nodes_dict.values(): + ones_vector = np.ones([node.cardinality]) + + node.lambda_agg = ones_vector + for child in node.lambda_received_msgs.keys(): + node.update_lambda_msg_from_child(child=child, + new_value=ones_vector) + logging.info("Finished initializing Lambda(x) and lambda_received_msgs per node.") + + logging.info("Start downward sweep from nodes. Sending Pi messages only.") + topdown_order = self.model.get_topologically_sorted_nodes(reverse=False) + + for node_id in topdown_order: + logging.info('label in iteration through top-down order: %s', str(node_id)) + + node_sending_msg = self.model.nodes_dict[node_id] + child_ids = node_sending_msg.children + + if node_sending_msg.pi_agg is None: + node_sending_msg.compute_pi_agg() + + for child_id in child_ids: + logging.info("child: %s", str(child_id)) + new_pi_msg = node_sending_msg.compute_pi_msg_to_child(child_k=child_id) + logging.info("new_pi_msg: %s", np.array2string(new_pi_msg)) + + child_node = self.model.nodes_dict[child_id] + child_node.update_pi_msg_from_parent(parent=node_id, + new_value=new_pi_msg) + + def _run_belief_propagation(self, evidence): + """ + Input: + evidence: dict + a dict key, value pair as {var: state_of_var observed} + """ + for evidence_id, observed_value in evidence.items(): + nodes_to_update = set() + + if evidence_id not in self.model.nodes_dict.keys(): + raise KeyError("Evidence supplied for non-existent label_id: {}" + .format(evidence_id)) + + if is_kronecker_delta(observed_value): + # specific evidence + self.model.nodes_dict[evidence_id].lambda_agg = observed_value + else: + # virtual evidence + self.model.nodes_dict[evidence_id].lambda_agg = \ + self.model.nodes_dict[evidence_id].lambda_agg * observed_value + + nodes_to_update = [MsgPassers(msg_receiver=evidence_id, msg_sender=None)] + + self._belief_propagation(nodes_to_update=set(nodes_to_update), + evidence=evidence) + + def query(self, evidence={}): + """ + Run belief propagation given evidence. + + Input: + evidence: dict + a dict key, value pair as {var: state_of_var observed}, + e.g. {'3': np.array([0,1])} if label '3' is True. + + Returns: + beliefs: dict + a dict key, value pair as {var: belief} + + Example + ------- + >> import numpy as np + >> from beliefs.inference.belief_propagation import BeliefPropagation + >> from beliefs.models.belief_update_node_model import BeliefUpdateNodeModel, BernoulliOrNode + >> edges = [('1', '3'), ('2', '3'), ('3', '5')] + >> model = BeliefUpdateNodeModel.init_from_edges(edges, BernoulliOrNode) + >> infer = BeliefPropagation(model) + >> result = infer.query(evidence={'2': np.array([0, 1])}) + """ + if not self.model.all_nodes_are_fully_initialized: + self.initialize_model() + + if evidence: + self._run_belief_propagation(evidence) + + return {label_id: node.belief for label_id, node in self.model.nodes_dict.items()} |