From 7053fefc6f9e43b1e252d1f551401a7a70b52e93 Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Wed, 13 Dec 2017 18:47:32 -0800 Subject: cleanup print statements, stale comments, minor TODOs --- beliefs/inference/belief_propagation.py | 77 +++++++++++++++++---------------- 1 file changed, 39 insertions(+), 38 deletions(-) (limited to 'beliefs/inference/belief_propagation.py') diff --git a/beliefs/inference/belief_propagation.py b/beliefs/inference/belief_propagation.py index 128f645..acd93d4 100644 --- a/beliefs/inference/belief_propagation.py +++ b/beliefs/inference/belief_propagation.py @@ -28,10 +28,10 @@ class ConflictingEvidenceError(Exception): class BeliefPropagation: def __init__(self, model, inplace=True): """ - Input: - model: an instance of BeliefUpdateNodeModel - inplace: bool - modify in-place the nodes in the model during belief propagation + Args + model: an instance of BeliefUpdateNodeModel + inplace: bool, + modify in-place the nodes in the model during belief propagation """ if not isinstance(model, BeliefUpdateNodeModel): raise TypeError("Model must be an instance of BeliefUpdateNodeModel") @@ -43,21 +43,20 @@ class BeliefPropagation: def _belief_propagation(self, nodes_to_update, evidence): """ Implementation of Pearl's belief propagation algorithm for polytrees. - ref: "Fusion, Propagation, and Structuring in Belief Networks" Artificial Intelligence 29 (1986) 241-288 - Input: - nodes_to_update: list - list of MsgPasser namedtuples. - evidence: dict, - a dict key, value pair as {var: state_of_var observed} + Args + nodes_to_update: list, + list of MsgPasser namedtuples. + evidence: dict, + a dict key, value pair as {var: state_of_var observed} """ if len(nodes_to_update) == 0: return node_to_update_label_id, msg_sender_label_id = nodes_to_update.pop() - logging.info("Node: %s", node_to_update_label_id) + logging.debug("Node: %s", node_to_update_label_id) node = self.model.nodes_dict[node_to_update_label_id] @@ -65,8 +64,8 @@ class BeliefPropagation: # outgoing msg from the node to update parent_ids = set(node.parents) - set([msg_sender_label_id]) child_ids = set(node.children) - set([msg_sender_label_id]) - logging.info("parent_ids: %s", str(parent_ids)) - logging.info("child_ids: %s", str(child_ids)) + logging.debug("parent_ids: %s", str(parent_ids)) + logging.debug("child_ids: %s", str(child_ids)) if msg_sender_label_id is not None: # update triggered by receiving a message, not pinning to evidence @@ -74,9 +73,9 @@ class BeliefPropagation: if node_to_update_label_id not in evidence: node.compute_pi_agg() - logging.info("belief propagation pi_agg: %s", np.array2string(node.pi_agg.values)) + logging.debug("belief propagation pi_agg: %s", np.array2string(node.pi_agg.values)) node.compute_lambda_agg() - logging.info("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg.values)) + logging.debug("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg.values)) for parent_id in parent_ids: try: @@ -101,14 +100,14 @@ class BeliefPropagation: def initialize_model(self): """ - Apply boundary conditions: + 1. Apply boundary conditions: - Set pi_agg equal to prior probabilities for root nodes. - Set lambda_agg equal to vector of ones for leaf nodes. - - Set lambda_agg, lambda_received_msgs to vectors of ones (same effect as - actually passing lambda messages up from leaf nodes to root nodes). - - Calculate pi_agg and pi_received_msgs for all nodes without evidence. - (Without evidence, belief equals pi_agg.) + 2. Set lambda_agg, lambda_received_msgs to vectors of ones (same effect as + actually passing lambda messages up from leaf nodes to root nodes). + 3. Calculate pi_agg and pi_received_msgs for all nodes without evidence. + (Without evidence, belief equals pi_agg.) """ self.model.set_boundary_conditions() @@ -119,13 +118,13 @@ class BeliefPropagation: for child in node.lambda_received_msgs.keys(): node.update_lambda_msg_from_child(child=child, new_value=ones_vector) - logging.info("Finished initializing Lambda(x) and lambda_received_msgs per node.") + logging.debug("Finished initializing Lambda(x) and lambda_received_msgs per node.") - logging.info("Start downward sweep from nodes. Sending Pi messages only.") + logging.debug("Start downward sweep from nodes. Sending Pi messages only.") topdown_order = self.model.get_topologically_sorted_nodes(reverse=False) for node_id in topdown_order: - logging.info('label in iteration through top-down order: %s', str(node_id)) + logging.debug('label in iteration through top-down order: %s', str(node_id)) node_sending_msg = self.model.nodes_dict[node_id] child_ids = node_sending_msg.children @@ -134,9 +133,9 @@ class BeliefPropagation: node_sending_msg.compute_pi_agg() for child_id in child_ids: - logging.info("child: %s", str(child_id)) + logging.debug("child: %s", str(child_id)) new_pi_msg = node_sending_msg.compute_pi_msg_to_child(child_k=child_id) - logging.info("new_pi_msg: %s", np.array2string(new_pi_msg)) + logging.debug("new_pi_msg: %s", np.array2string(new_pi_msg)) child_node = self.model.nodes_dict[child_id] child_node.update_pi_msg_from_parent(parent=node_id, @@ -144,9 +143,12 @@ class BeliefPropagation: def _run_belief_propagation(self, evidence): """ - Input: - evidence: dict - a dict key, value pair as {var: state_of_var observed} + Sequentially perturb nodes with observed values, running belief propagation + after each perturbation. + + Args + evidence: dict, + a dict key, value pair as {var: state_of_var observed} """ for evidence_id, observed_value in evidence.items(): if evidence_id not in self.model.nodes_dict.keys(): @@ -162,21 +164,20 @@ class BeliefPropagation: self.model.nodes_dict[evidence_id].lambda_agg.values * observed_value ) nodes_to_update = [MsgPassers(msg_receiver=evidence_id, msg_sender=None)] - self._belief_propagation(nodes_to_update=set(nodes_to_update), - evidence=evidence) + self._belief_propagation(nodes_to_update=set(nodes_to_update), evidence=evidence) def query(self, evidence={}): """ - Run belief propagation given evidence. + Run belief propagation given 0 or more pieces of evidence. - Input: - evidence: dict - a dict key, value pair as {var: state_of_var observed}, - e.g. {'3': np.array([0,1])} if label '3' is True. + Args + evidence: dict, + a dict key, value pair as {var: state_of_var observed}, + e.g. {'3': np.array([0,1])} if label '3' is True. - Returns: - beliefs: dict - a dict key, value pair as {var: belief} + Returns + a dict key, value pair as {var: belief}, where belief is an np.array of the + marginal probability of each state of the variable given the evidence. Example ------- -- cgit v1.2.3