aboutsummaryrefslogtreecommitdiff
path: root/beliefs/inference/belief_propagation.py
diff options
context:
space:
mode:
Diffstat (limited to 'beliefs/inference/belief_propagation.py')
-rw-r--r--beliefs/inference/belief_propagation.py20
1 files changed, 8 insertions, 12 deletions
diff --git a/beliefs/inference/belief_propagation.py b/beliefs/inference/belief_propagation.py
index 7ec648d..128f645 100644
--- a/beliefs/inference/belief_propagation.py
+++ b/beliefs/inference/belief_propagation.py
@@ -74,9 +74,9 @@ class BeliefPropagation:
if node_to_update_label_id not in evidence:
node.compute_pi_agg()
- logging.info("belief propagation pi_agg: %s", np.array2string(node.pi_agg))
+ logging.info("belief propagation pi_agg: %s", np.array2string(node.pi_agg.values))
node.compute_lambda_agg()
- logging.info("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg))
+ logging.info("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg.values))
for parent_id in parent_ids:
try:
@@ -97,7 +97,6 @@ class BeliefPropagation:
new_value=new_pi_msg)
nodes_to_update.add(MsgPassers(msg_receiver=child_id,
msg_sender=node_to_update_label_id))
-
self._belief_propagation(nodes_to_update, evidence)
def initialize_model(self):
@@ -115,8 +114,8 @@ class BeliefPropagation:
for node in self.model.nodes_dict.values():
ones_vector = np.ones([node.cardinality])
+ node.update_lambda_agg(ones_vector)
- node.lambda_agg = ones_vector
for child in node.lambda_received_msgs.keys():
node.update_lambda_msg_from_child(child=child,
new_value=ones_vector)
@@ -131,7 +130,7 @@ class BeliefPropagation:
node_sending_msg = self.model.nodes_dict[node_id]
child_ids = node_sending_msg.children
- if node_sending_msg.pi_agg is None:
+ if node_sending_msg.pi_agg.values is None:
node_sending_msg.compute_pi_agg()
for child_id in child_ids:
@@ -150,22 +149,19 @@ class BeliefPropagation:
a dict key, value pair as {var: state_of_var observed}
"""
for evidence_id, observed_value in evidence.items():
- nodes_to_update = set()
-
if evidence_id not in self.model.nodes_dict.keys():
raise KeyError("Evidence supplied for non-existent label_id: {}"
.format(evidence_id))
if is_kronecker_delta(observed_value):
# specific evidence
- self.model.nodes_dict[evidence_id].lambda_agg = observed_value
+ self.model.nodes_dict[evidence_id].update_lambda_agg(observed_value)
else:
# virtual evidence
- self.model.nodes_dict[evidence_id].lambda_agg = \
- self.model.nodes_dict[evidence_id].lambda_agg * observed_value
-
+ self.model.nodes_dict[evidence_id].update_lambda_agg(
+ self.model.nodes_dict[evidence_id].lambda_agg.values * observed_value
+ )
nodes_to_update = [MsgPassers(msg_receiver=evidence_id, msg_sender=None)]
-
self._belief_propagation(nodes_to_update=set(nodes_to_update),
evidence=evidence)