aboutsummaryrefslogtreecommitdiff
path: root/beliefs/inference/belief_propagation.py
diff options
context:
space:
mode:
Diffstat (limited to 'beliefs/inference/belief_propagation.py')
-rw-r--r--beliefs/inference/belief_propagation.py44
1 files changed, 25 insertions, 19 deletions
diff --git a/beliefs/inference/belief_propagation.py b/beliefs/inference/belief_propagation.py
index 02f5595..7ec648d 100644
--- a/beliefs/inference/belief_propagation.py
+++ b/beliefs/inference/belief_propagation.py
@@ -1,11 +1,17 @@
import numpy as np
from collections import namedtuple
+import logging
-from beliefs.models.beliefupdate.Node import InvalidLambdaMsgToParent
-from beliefs.models.beliefupdate.BeliefUpdateNodeModel import BeliefUpdateNodeModel
+from beliefs.models.belief_update_node_model import (
+ InvalidLambdaMsgToParent,
+ BeliefUpdateNodeModel
+)
from beliefs.utils.math_helper import is_kronecker_delta
+logger = logging.getLogger(__name__)
+
+
MsgPassers = namedtuple('MsgPassers', ['msg_receiver', 'msg_sender'])
@@ -51,7 +57,7 @@ class BeliefPropagation:
return
node_to_update_label_id, msg_sender_label_id = nodes_to_update.pop()
- print("Node", node_to_update_label_id)
+ logging.info("Node: %s", node_to_update_label_id)
node = self.model.nodes_dict[node_to_update_label_id]
@@ -59,8 +65,8 @@ class BeliefPropagation:
# outgoing msg from the node to update
parent_ids = set(node.parents) - set([msg_sender_label_id])
child_ids = set(node.children) - set([msg_sender_label_id])
- print("parent_ids:", parent_ids)
- print("child_ids:", child_ids)
+ logging.info("parent_ids: %s", str(parent_ids))
+ logging.info("child_ids: %s", str(child_ids))
if msg_sender_label_id is not None:
# update triggered by receiving a message, not pinning to evidence
@@ -68,9 +74,9 @@ class BeliefPropagation:
if node_to_update_label_id not in evidence:
node.compute_pi_agg()
- print("belief propagation pi_agg", node.pi_agg)
+ logging.info("belief propagation pi_agg: %s", np.array2string(node.pi_agg))
node.compute_lambda_agg()
- print("belief propagation lambda_agg", node.lambda_agg)
+ logging.info("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg))
for parent_id in parent_ids:
try:
@@ -114,13 +120,13 @@ class BeliefPropagation:
for child in node.lambda_received_msgs.keys():
node.update_lambda_msg_from_child(child=child,
new_value=ones_vector)
- print("Finished initializing Lambda(x) and lambda_received_msgs per node.")
+ logging.info("Finished initializing Lambda(x) and lambda_received_msgs per node.")
- print("Start downward sweep from nodes. Sending Pi messages only.")
+ logging.info("Start downward sweep from nodes. Sending Pi messages only.")
topdown_order = self.model.get_topologically_sorted_nodes(reverse=False)
for node_id in topdown_order:
- print('label in iteration through top-down order:', node_id)
+ logging.info('label in iteration through top-down order: %s', str(node_id))
node_sending_msg = self.model.nodes_dict[node_id]
child_ids = node_sending_msg.children
@@ -129,9 +135,9 @@ class BeliefPropagation:
node_sending_msg.compute_pi_agg()
for child_id in child_ids:
- print("child", child_id)
+ logging.info("child: %s", str(child_id))
new_pi_msg = node_sending_msg.compute_pi_msg_to_child(child_k=child_id)
- print(new_pi_msg)
+ logging.info("new_pi_msg: %s", np.array2string(new_pi_msg))
child_node = self.model.nodes_dict[child_id]
child_node.update_pi_msg_from_parent(parent=node_id,
@@ -158,10 +164,9 @@ class BeliefPropagation:
self.model.nodes_dict[evidence_id].lambda_agg = \
self.model.nodes_dict[evidence_id].lambda_agg * observed_value
- nodes_to_update.add(MsgPassers(msg_receiver=evidence_id,
- msg_sender=None))
+ nodes_to_update = [MsgPassers(msg_receiver=evidence_id, msg_sender=None)]
- self._belief_propagation(nodes_to_update=nodes_to_update,
+ self._belief_propagation(nodes_to_update=set(nodes_to_update),
evidence=evidence)
def query(self, evidence={}):
@@ -179,12 +184,13 @@ class BeliefPropagation:
Example
-------
- >> from label_graph_service.pgm.inference.belief_propagation import BeliefPropagation
- >> from label_graph_service.pgm.models.BernoulliOrModel import BernoulliOrModel
+ >> import numpy as np
+ >> from beliefs.inference.belief_propagation import BeliefPropagation
+ >> from beliefs.models.belief_update_node_model import BeliefUpdateNodeModel, BernoulliOrNode
>> edges = [('1', '3'), ('2', '3'), ('3', '5')]
- >> model = BernoulliOrModel(edges)
+ >> model = BeliefUpdateNodeModel.init_from_edges(edges, BernoulliOrNode)
>> infer = BeliefPropagation(model)
- >> result = infer.query({'2': np.array([0, 1])})
+ >> result = infer.query(evidence={'2': np.array([0, 1])})
"""
if not self.model.all_nodes_are_fully_initialized:
self.initialize_model()