aboutsummaryrefslogtreecommitdiff
path: root/beliefs/inference/belief_propagation.py
diff options
context:
space:
mode:
Diffstat (limited to 'beliefs/inference/belief_propagation.py')
-rw-r--r--beliefs/inference/belief_propagation.py22
1 files changed, 11 insertions, 11 deletions
diff --git a/beliefs/inference/belief_propagation.py b/beliefs/inference/belief_propagation.py
index 972fd5d..ecd5e9c 100644
--- a/beliefs/inference/belief_propagation.py
+++ b/beliefs/inference/belief_propagation.py
@@ -50,7 +50,7 @@ class BeliefPropagation:
node_to_update_label_id, msg_sender_label_id = nodes_to_update.pop()
print("Node", node_to_update_label_id)
- node = self.model.nodes[node_to_update_label_id]
+ node = self.model.nodes_dict[node_to_update_label_id]
# exclude the message sender (either a parent or child) from getting an
# outgoing msg from the node to update
@@ -75,7 +75,7 @@ class BeliefPropagation:
except InvalidLambdaMsgToParent:
raise ConflictingEvidenceError(evidence=evidence)
- parent_node = self.model.nodes[parent_id]
+ parent_node = self.model.nodes_dict[parent_id]
parent_node.update_lambda_msg_from_child(child=node_to_update_label_id,
new_value=new_lambda_msg)
nodes_to_update.add(MsgPassers(msg_receiver=parent_id,
@@ -83,7 +83,7 @@ class BeliefPropagation:
for child_id in child_ids:
new_pi_msg = node.compute_pi_msg_to_child(child_k=child_id)
- child_node = self.model.nodes[child_id]
+ child_node = self.model.nodes_dict[child_id]
child_node.update_pi_msg_from_parent(parent=node_to_update_label_id,
new_value=new_pi_msg)
nodes_to_update.add(MsgPassers(msg_receiver=child_id,
@@ -104,7 +104,7 @@ class BeliefPropagation:
"""
self.model.set_boundary_conditions()
- for node in self.model.nodes.values():
+ for node in self.model.nodes_dict.values():
ones_vector = np.ones([node.cardinality])
node.lambda_agg = ones_vector
@@ -119,7 +119,7 @@ class BeliefPropagation:
for node_id in topdown_order:
print('label in iteration through top-down order:', node_id)
- node_sending_msg = self.model.nodes[node_id]
+ node_sending_msg = self.model.nodes_dict[node_id]
child_ids = node_sending_msg.children
if node_sending_msg.pi_agg is None:
@@ -130,7 +130,7 @@ class BeliefPropagation:
new_pi_msg = node_sending_msg.compute_pi_msg_to_child(child_k=child_id)
print(new_pi_msg)
- child_node = self.model.nodes[child_id]
+ child_node = self.model.nodes_dict[child_id]
child_node.update_pi_msg_from_parent(parent=node_id,
new_value=new_pi_msg)
@@ -143,17 +143,17 @@ class BeliefPropagation:
for evidence_id, observed_value in evidence.items():
nodes_to_update = set()
- if evidence_id not in self.model.nodes.keys():
+ if evidence_id not in self.model.nodes_dict.keys():
raise KeyError("Evidence supplied for non-existent label_id: {}"
.format(evidence_id))
if is_kronecker_delta(observed_value):
# specific evidence
- self.model.nodes[evidence_id].lambda_agg = observed_value
+ self.model.nodes_dict[evidence_id].lambda_agg = observed_value
else:
# virtual evidence
- self.model.nodes[evidence_id].lambda_agg = \
- self.model.nodes[evidence_id].lambda_agg * observed_value
+ self.model.nodes_dict[evidence_id].lambda_agg = \
+ self.model.nodes_dict[evidence_id].lambda_agg * observed_value
nodes_to_update.add(MsgPassers(msg_receiver=evidence_id,
msg_sender=None))
@@ -189,4 +189,4 @@ class BeliefPropagation:
if evidence:
self._run_belief_propagation(evidence)
- return {label_id: node.belief for label_id, node in self.model.nodes.items()}
+ return {label_id: node.belief for label_id, node in self.model.nodes_dict.items()}