diff options
author | Cathy Yeh <cathy@driver.xyz> | 2017-11-13 15:34:33 -0800 |
---|---|---|
committer | Cathy Yeh <cathy@driver.xyz> | 2017-11-17 13:48:32 -0800 |
commit | b16e990b7e4d00e427d4445ba38eef0fb967963a (patch) | |
tree | 5ecb0137fde4779c0158eaf66f89a8df6288d250 /beliefs/inference/belief_propagation.py | |
parent | 77d8b323d4f6e05ca97d9cbef43ac85fd8040d61 (diff) | |
download | beliefs-b16e990b7e4d00e427d4445ba38eef0fb967963a.tar.gz beliefs-b16e990b7e4d00e427d4445ba38eef0fb967963a.tar.bz2 beliefs-b16e990b7e4d00e427d4445ba38eef0fb967963a.zip |
changes to work with bump from networkx 1.11 to 2.0
some nx functions now return iterators
Diffstat (limited to 'beliefs/inference/belief_propagation.py')
-rw-r--r-- | beliefs/inference/belief_propagation.py | 22 |
1 files changed, 11 insertions, 11 deletions
diff --git a/beliefs/inference/belief_propagation.py b/beliefs/inference/belief_propagation.py index 972fd5d..ecd5e9c 100644 --- a/beliefs/inference/belief_propagation.py +++ b/beliefs/inference/belief_propagation.py @@ -50,7 +50,7 @@ class BeliefPropagation: node_to_update_label_id, msg_sender_label_id = nodes_to_update.pop() print("Node", node_to_update_label_id) - node = self.model.nodes[node_to_update_label_id] + node = self.model.nodes_dict[node_to_update_label_id] # exclude the message sender (either a parent or child) from getting an # outgoing msg from the node to update @@ -75,7 +75,7 @@ class BeliefPropagation: except InvalidLambdaMsgToParent: raise ConflictingEvidenceError(evidence=evidence) - parent_node = self.model.nodes[parent_id] + parent_node = self.model.nodes_dict[parent_id] parent_node.update_lambda_msg_from_child(child=node_to_update_label_id, new_value=new_lambda_msg) nodes_to_update.add(MsgPassers(msg_receiver=parent_id, @@ -83,7 +83,7 @@ class BeliefPropagation: for child_id in child_ids: new_pi_msg = node.compute_pi_msg_to_child(child_k=child_id) - child_node = self.model.nodes[child_id] + child_node = self.model.nodes_dict[child_id] child_node.update_pi_msg_from_parent(parent=node_to_update_label_id, new_value=new_pi_msg) nodes_to_update.add(MsgPassers(msg_receiver=child_id, @@ -104,7 +104,7 @@ class BeliefPropagation: """ self.model.set_boundary_conditions() - for node in self.model.nodes.values(): + for node in self.model.nodes_dict.values(): ones_vector = np.ones([node.cardinality]) node.lambda_agg = ones_vector @@ -119,7 +119,7 @@ class BeliefPropagation: for node_id in topdown_order: print('label in iteration through top-down order:', node_id) - node_sending_msg = self.model.nodes[node_id] + node_sending_msg = self.model.nodes_dict[node_id] child_ids = node_sending_msg.children if node_sending_msg.pi_agg is None: @@ -130,7 +130,7 @@ class BeliefPropagation: new_pi_msg = node_sending_msg.compute_pi_msg_to_child(child_k=child_id) print(new_pi_msg) - child_node = self.model.nodes[child_id] + child_node = self.model.nodes_dict[child_id] child_node.update_pi_msg_from_parent(parent=node_id, new_value=new_pi_msg) @@ -143,17 +143,17 @@ class BeliefPropagation: for evidence_id, observed_value in evidence.items(): nodes_to_update = set() - if evidence_id not in self.model.nodes.keys(): + if evidence_id not in self.model.nodes_dict.keys(): raise KeyError("Evidence supplied for non-existent label_id: {}" .format(evidence_id)) if is_kronecker_delta(observed_value): # specific evidence - self.model.nodes[evidence_id].lambda_agg = observed_value + self.model.nodes_dict[evidence_id].lambda_agg = observed_value else: # virtual evidence - self.model.nodes[evidence_id].lambda_agg = \ - self.model.nodes[evidence_id].lambda_agg * observed_value + self.model.nodes_dict[evidence_id].lambda_agg = \ + self.model.nodes_dict[evidence_id].lambda_agg * observed_value nodes_to_update.add(MsgPassers(msg_receiver=evidence_id, msg_sender=None)) @@ -189,4 +189,4 @@ class BeliefPropagation: if evidence: self._run_belief_propagation(evidence) - return {label_id: node.belief for label_id, node in self.model.nodes.items()} + return {label_id: node.belief for label_id, node in self.model.nodes_dict.items()} |