aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCathy Yeh <cathy@driver.xyz>2018-01-24 14:35:15 -0800
committerCathy Yeh <cathy@driver.xyz>2018-01-24 14:35:15 -0800
commit99fafd1338404e927ee65ae156576a85d1200a01 (patch)
tree20637239af388298b57a0614e2fec551496696d1
parent2366e92bdb9c81bc2bd7132a00ed5c16a5160c5e (diff)
downloadbeliefs-99fafd1338404e927ee65ae156576a85d1200a01.tar.gz
beliefs-99fafd1338404e927ee65ae156576a85d1200a01.tar.bz2
beliefs-99fafd1338404e927ee65ae156576a85d1200a01.zip
add more logging to belief propagation
-rw-r--r--beliefs/inference/belief_propagation.py24
1 files changed, 14 insertions, 10 deletions
diff --git a/beliefs/inference/belief_propagation.py b/beliefs/inference/belief_propagation.py
index e6e7b18..5b063f0 100644
--- a/beliefs/inference/belief_propagation.py
+++ b/beliefs/inference/belief_propagation.py
@@ -56,7 +56,7 @@ class BeliefPropagation:
return
node_to_update_label_id, msg_sender_label_id = nodes_to_update.pop()
- logging.debug("Node: %s", node_to_update_label_id)
+ logger.debug("Node: %s", node_to_update_label_id)
node = self.model.nodes_dict[node_to_update_label_id]
@@ -64,8 +64,8 @@ class BeliefPropagation:
# outgoing msg from the node to update
parent_ids = set(node.parents) - set([msg_sender_label_id])
child_ids = set(node.children) - set([msg_sender_label_id])
- logging.debug("parent_ids: %s", str(parent_ids))
- logging.debug("child_ids: %s", str(child_ids))
+ logger.debug("parent_ids: %s", str(parent_ids))
+ logger.debug("child_ids: %s", str(child_ids))
if msg_sender_label_id is not None:
# update triggered by receiving a message, not pinning to evidence
@@ -73,9 +73,9 @@ class BeliefPropagation:
if node_to_update_label_id not in evidence:
node.compute_and_update_pi_agg()
- logging.debug("belief propagation pi_agg: %s", np.array2string(node.pi_agg.values))
+ logger.debug("belief propagation pi_agg: %s", np.array2string(node.pi_agg.values))
node.compute_and_update_lambda_agg()
- logging.debug("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg.values))
+ logger.debug("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg.values))
for parent_id in parent_ids:
try:
@@ -118,13 +118,13 @@ class BeliefPropagation:
for child in node.lambda_received_msgs.keys():
node.update_lambda_msg_from_child(child=child,
new_value=ones_vector)
- logging.debug("Finished initializing Lambda(x) and lambda_received_msgs per node.")
+ logger.debug("Finished initializing Lambda(x) and lambda_received_msgs per node.")
- logging.debug("Start downward sweep from nodes. Sending Pi messages only.")
+ logger.debug("Start downward sweep from nodes. Sending Pi messages only.")
topdown_order = self.model.get_topologically_sorted_nodes(reverse=False)
for node_id in topdown_order:
- logging.debug('label in iteration through top-down order: %s', str(node_id))
+ logger.debug('label in iteration through top-down order: %s', str(node_id))
node_sending_msg = self.model.nodes_dict[node_id]
child_ids = node_sending_msg.children
@@ -133,9 +133,9 @@ class BeliefPropagation:
node_sending_msg.compute_and_update_pi_agg()
for child_id in child_ids:
- logging.debug("child: %s", str(child_id))
+ logger.debug("child: %s", str(child_id))
new_pi_msg = node_sending_msg.compute_pi_msg_to_child(child_k=child_id)
- logging.debug("new_pi_msg: %s", np.array2string(new_pi_msg))
+ logger.debug("new_pi_msg: %s", np.array2string(new_pi_msg))
child_node = self.model.nodes_dict[child_id]
child_node.update_pi_msg_from_parent(parent=node_id,
@@ -151,7 +151,10 @@ class BeliefPropagation:
a dict key, value pair as {var: state_of_var observed}
"""
for evidence_id, observed_value in evidence.items():
+ logger.info("evidence id: %s", evidence_id)
+ logger.info("evidence observed value: %s", np.array2string(observed_value))
if evidence_id not in self.model.nodes_dict.keys():
+ logger.error("Evidence supplied for non-existent label_id: %s", evidence_id)
raise KeyError("Evidence supplied for non-existent label_id: {}"
.format(evidence_id))
@@ -193,6 +196,7 @@ class BeliefPropagation:
self.initialize_model()
if evidence:
+ logger.info("Run belief propagation for %s pieces of evidence", len(evidence))
self._run_belief_propagation(evidence)
return {label_id: node.belief for label_id, node in self.model.nodes_dict.items()}