From 8dc7ae89677fca16ee974a30cff8c4df53c955ce Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Sun, 3 Dec 2017 19:16:32 -0800 Subject: PR comments --- beliefs/inference/belief_propagation.py | 44 +++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 19 deletions(-) (limited to 'beliefs/inference/belief_propagation.py') diff --git a/beliefs/inference/belief_propagation.py b/beliefs/inference/belief_propagation.py index 02f5595..7ec648d 100644 --- a/beliefs/inference/belief_propagation.py +++ b/beliefs/inference/belief_propagation.py @@ -1,11 +1,17 @@ import numpy as np from collections import namedtuple +import logging -from beliefs.models.beliefupdate.Node import InvalidLambdaMsgToParent -from beliefs.models.beliefupdate.BeliefUpdateNodeModel import BeliefUpdateNodeModel +from beliefs.models.belief_update_node_model import ( + InvalidLambdaMsgToParent, + BeliefUpdateNodeModel +) from beliefs.utils.math_helper import is_kronecker_delta +logger = logging.getLogger(__name__) + + MsgPassers = namedtuple('MsgPassers', ['msg_receiver', 'msg_sender']) @@ -51,7 +57,7 @@ class BeliefPropagation: return node_to_update_label_id, msg_sender_label_id = nodes_to_update.pop() - print("Node", node_to_update_label_id) + logging.info("Node: %s", node_to_update_label_id) node = self.model.nodes_dict[node_to_update_label_id] @@ -59,8 +65,8 @@ class BeliefPropagation: # outgoing msg from the node to update parent_ids = set(node.parents) - set([msg_sender_label_id]) child_ids = set(node.children) - set([msg_sender_label_id]) - print("parent_ids:", parent_ids) - print("child_ids:", child_ids) + logging.info("parent_ids: %s", str(parent_ids)) + logging.info("child_ids: %s", str(child_ids)) if msg_sender_label_id is not None: # update triggered by receiving a message, not pinning to evidence @@ -68,9 +74,9 @@ class BeliefPropagation: if node_to_update_label_id not in evidence: node.compute_pi_agg() - print("belief propagation pi_agg", node.pi_agg) + logging.info("belief propagation pi_agg: %s", np.array2string(node.pi_agg)) node.compute_lambda_agg() - print("belief propagation lambda_agg", node.lambda_agg) + logging.info("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg)) for parent_id in parent_ids: try: @@ -114,13 +120,13 @@ class BeliefPropagation: for child in node.lambda_received_msgs.keys(): node.update_lambda_msg_from_child(child=child, new_value=ones_vector) - print("Finished initializing Lambda(x) and lambda_received_msgs per node.") + logging.info("Finished initializing Lambda(x) and lambda_received_msgs per node.") - print("Start downward sweep from nodes. Sending Pi messages only.") + logging.info("Start downward sweep from nodes. Sending Pi messages only.") topdown_order = self.model.get_topologically_sorted_nodes(reverse=False) for node_id in topdown_order: - print('label in iteration through top-down order:', node_id) + logging.info('label in iteration through top-down order: %s', str(node_id)) node_sending_msg = self.model.nodes_dict[node_id] child_ids = node_sending_msg.children @@ -129,9 +135,9 @@ class BeliefPropagation: node_sending_msg.compute_pi_agg() for child_id in child_ids: - print("child", child_id) + logging.info("child: %s", str(child_id)) new_pi_msg = node_sending_msg.compute_pi_msg_to_child(child_k=child_id) - print(new_pi_msg) + logging.info("new_pi_msg: %s", np.array2string(new_pi_msg)) child_node = self.model.nodes_dict[child_id] child_node.update_pi_msg_from_parent(parent=node_id, @@ -158,10 +164,9 @@ class BeliefPropagation: self.model.nodes_dict[evidence_id].lambda_agg = \ self.model.nodes_dict[evidence_id].lambda_agg * observed_value - nodes_to_update.add(MsgPassers(msg_receiver=evidence_id, - msg_sender=None)) + nodes_to_update = [MsgPassers(msg_receiver=evidence_id, msg_sender=None)] - self._belief_propagation(nodes_to_update=nodes_to_update, + self._belief_propagation(nodes_to_update=set(nodes_to_update), evidence=evidence) def query(self, evidence={}): @@ -179,12 +184,13 @@ class BeliefPropagation: Example ------- - >> from label_graph_service.pgm.inference.belief_propagation import BeliefPropagation - >> from label_graph_service.pgm.models.BernoulliOrModel import BernoulliOrModel + >> import numpy as np + >> from beliefs.inference.belief_propagation import BeliefPropagation + >> from beliefs.models.belief_update_node_model import BeliefUpdateNodeModel, BernoulliOrNode >> edges = [('1', '3'), ('2', '3'), ('3', '5')] - >> model = BernoulliOrModel(edges) + >> model = BeliefUpdateNodeModel.init_from_edges(edges, BernoulliOrNode) >> infer = BeliefPropagation(model) - >> result = infer.query({'2': np.array([0, 1])}) + >> result = infer.query(evidence={'2': np.array([0, 1])}) """ if not self.model.all_nodes_are_fully_initialized: self.initialize_model() -- cgit v1.2.3