import logging
from collections import namedtuple

import numpy as np

from beliefs.models.belief_update_node_model import (
    InvalidLambdaMsgToParent,
    BeliefUpdateNodeModel
)
from beliefs.utils.math_helper import is_kronecker_delta


logger = logging.getLogger(__name__)

# One pending update: `msg_receiver` is the label id of the node to refresh,
# `msg_sender` is the label id of the neighbour whose message triggered the
# refresh (None when the refresh was triggered by pinning evidence).
MsgPassers = namedtuple('MsgPassers', ['msg_receiver', 'msg_sender'])


class ConflictingEvidenceError(Exception):
    """Failed to run belief propagation on label graph because of conflicting evidence."""

    def __init__(self, evidence):
        """
        Args
            evidence: the evidence passed to belief propagation; it is
                formatted into the exception message.
        """
        message = (
            "Can't run belief propagation with conflicting evidence: {}"
            .format(evidence)
        )
        super().__init__(message)


class BeliefPropagation:
    """Runs belief propagation over the nodes of a BeliefUpdateNodeModel."""

    def __init__(self, model, inplace=True):
        """
        Args
            model: an instance of BeliefUpdateNodeModel
            inplace: bool, modify in-place the nodes in the model during
                     belief propagation

        Raises
            TypeError: if `model` is not a BeliefUpdateNodeModel.
        """
        if not isinstance(model, BeliefUpdateNodeModel):
            raise TypeError("Model must be an instance of BeliefUpdateNodeModel")
        # Only the literal False triggers a copy; any other value keeps the
        # caller's model and mutates it in place.
        self.model = model.copy() if inplace is False else model

    def _belief_propagation(self, nodes_to_update, evidence):
        """
        Implementation of Pearl's belief propagation algorithm for polytrees.

        ref: "Fusion, Propagation, and Structuring in Belief Networks"
             Artificial Intelligence 29 (1986) 241-288

        Pending updates are drained iteratively rather than by recursion, so
        the number of messages is not bounded by the interpreter's recursion
        limit.

        Args
            nodes_to_update: set of MsgPassers namedtuples.  It is consumed
                (emptied) by this method, and newly triggered updates are
                added to it while it runs.
            evidence: dict, a dict key, value pair as {var: state_of_var observed}

        Raises
            ConflictingEvidenceError: if a node raises InvalidLambdaMsgToParent
                while computing a lambda message for one of its parents.
        """
        while nodes_to_update:
            node_to_update_label_id, msg_sender_label_id = nodes_to_update.pop()
            logger.debug("Node: %s", node_to_update_label_id)

            node = self.model.nodes_dict[node_to_update_label_id]

            # exclude the message sender (either a parent or child) from
            # getting an outgoing msg from the node to update
            parent_ids = set(node.parents) - {msg_sender_label_id}
            child_ids = set(node.children) - {msg_sender_label_id}
            logger.debug("parent_ids: %s", str(parent_ids))
            logger.debug("child_ids: %s", str(child_ids))

            if msg_sender_label_id is not None:
                # update triggered by receiving a message, not pinning to
                # evidence: exactly one neighbour (the sender) must have been
                # excluded above.  Internal invariant, hence an assert.
                assert len(node.parents) + len(node.children) - 1 == len(parent_ids) + len(child_ids)

            if node_to_update_label_id not in evidence:
                # Nodes listed in the evidence are not recomputed here.
                node.compute_pi_agg()
                logger.debug("belief propagation pi_agg: %s",
                             np.array2string(node.pi_agg.values))
                node.compute_lambda_agg()
                logger.debug("belief propagation lambda_agg: %s",
                             np.array2string(node.lambda_agg.values))

            for parent_id in parent_ids:
                try:
                    new_lambda_msg = node.compute_lambda_msg_to_parent(parent_k=parent_id)
                except InvalidLambdaMsgToParent as err:
                    raise ConflictingEvidenceError(evidence=evidence) from err

                parent_node = self.model.nodes_dict[parent_id]
                parent_node.update_lambda_msg_from_child(child=node_to_update_label_id,
                                                         new_value=new_lambda_msg)
                nodes_to_update.add(MsgPassers(msg_receiver=parent_id,
                                               msg_sender=node_to_update_label_id))

            for child_id in child_ids:
                new_pi_msg = node.compute_pi_msg_to_child(child_k=child_id)

                child_node = self.model.nodes_dict[child_id]
                child_node.update_pi_msg_from_parent(parent=node_to_update_label_id,
                                                     new_value=new_pi_msg)
                nodes_to_update.add(MsgPassers(msg_receiver=child_id,
                                               msg_sender=node_to_update_label_id))

    def initialize_model(self):
        """
        1. Apply boundary conditions:
            - Set pi_agg equal to prior probabilities for root nodes.
            - Set lambda_agg equal to vector of ones for leaf nodes.

        2. Set lambda_agg, lambda_received_msgs to vectors of ones (same effect as
           actually passing lambda messages up from leaf nodes to root nodes).

        3. Calculate pi_agg and pi_received_msgs for all nodes without evidence.
           (Without evidence, belief equals pi_agg.)
        """
        self.model.set_boundary_conditions()

        for node in self.model.nodes_dict.values():
            ones_vector = np.ones([node.cardinality])

            node.update_lambda_agg(ones_vector)

            for child in node.lambda_received_msgs.keys():
                node.update_lambda_msg_from_child(child=child,
                                                  new_value=ones_vector)
        logger.debug("Finished initializing Lambda(x) and lambda_received_msgs per node.")

        logger.debug("Start downward sweep from nodes. Sending Pi messages only.")
        topdown_order = self.model.get_topologically_sorted_nodes(reverse=False)

        for node_id in topdown_order:
            logger.debug('label in iteration through top-down order: %s', str(node_id))

            node_sending_msg = self.model.nodes_dict[node_id]
            child_ids = node_sending_msg.children

            # Only compute pi_agg where it has not been set yet.
            if node_sending_msg.pi_agg.values is None:
                node_sending_msg.compute_pi_agg()

            for child_id in child_ids:
                logger.debug("child: %s", str(child_id))
                new_pi_msg = node_sending_msg.compute_pi_msg_to_child(child_k=child_id)
                logger.debug("new_pi_msg: %s", np.array2string(new_pi_msg))

                child_node = self.model.nodes_dict[child_id]
                child_node.update_pi_msg_from_parent(parent=node_id,
                                                     new_value=new_pi_msg)

    def _run_belief_propagation(self, evidence):
        """
        Sequentially perturb nodes with observed values, running belief propagation
        after each perturbation.

        Args
            evidence: dict, a dict key, value pair as {var: state_of_var observed}

        Raises
            KeyError: if any evidence key is not a label id in the model.  All
                keys are checked before any node is modified, so a bad key
                never leaves the model partially updated.
            ConflictingEvidenceError: propagated from _belief_propagation.
        """
        nodes_dict = self.model.nodes_dict

        # Validate every key up front, before mutating any node.
        for evidence_id in evidence:
            if evidence_id not in nodes_dict:
                raise KeyError("Evidence supplied for non-existent label_id: {}"
                               .format(evidence_id))

        for evidence_id, observed_value in evidence.items():
            evidence_node = nodes_dict[evidence_id]

            if is_kronecker_delta(observed_value):
                # specific evidence
                evidence_node.update_lambda_agg(observed_value)
            else:
                # virtual evidence
                evidence_node.update_lambda_agg(
                    evidence_node.lambda_agg.values * observed_value
                )

            # msg_sender=None marks an update triggered by evidence rather
            # than by a message received from a neighbour.
            nodes_to_update = {MsgPassers(msg_receiver=evidence_id, msg_sender=None)}

            self._belief_propagation(nodes_to_update=nodes_to_update,
                                     evidence=evidence)

    def query(self, evidence=None):
        """
        Run belief propagation given 0 or more pieces of evidence.

        Args
            evidence: dict, a dict key, value pair as {var: state_of_var observed},
                e.g. {'3': np.array([0,1])} if label '3' is True.  None or an
                empty dict means no evidence.

        Returns
            a dict key, value pair as {var: belief}, where belief is an np.array of the
            marginal probability of each state of the variable given the evidence.

        Example
        -------
        >> import numpy as np
        >> from beliefs.inference.belief_propagation import BeliefPropagation
        >> from beliefs.models.belief_update_node_model import BeliefUpdateNodeModel, BernoulliOrNode
        >> edges = [('1', '3'), ('2', '3'), ('3', '5')]
        >> model = BeliefUpdateNodeModel.init_from_edges(edges, BernoulliOrNode)
        >> infer = BeliefPropagation(model)
        >> result = infer.query(evidence={'2': np.array([0, 1])})
        """
        if not self.model.all_nodes_are_fully_initialized:
            self.initialize_model()

        if evidence:
            self._run_belief_propagation(evidence)

        return {label_id: node.belief for label_id, node in self.model.nodes_dict.items()}