diff options
Diffstat (limited to 'beliefs/inference/belief_propagation.py')
-rw-r--r-- | beliefs/inference/belief_propagation.py | 20 |
1 files changed, 8 insertions, 12 deletions
diff --git a/beliefs/inference/belief_propagation.py b/beliefs/inference/belief_propagation.py index 7ec648d..128f645 100644 --- a/beliefs/inference/belief_propagation.py +++ b/beliefs/inference/belief_propagation.py @@ -74,9 +74,9 @@ class BeliefPropagation: if node_to_update_label_id not in evidence: node.compute_pi_agg() - logging.info("belief propagation pi_agg: %s", np.array2string(node.pi_agg)) + logging.info("belief propagation pi_agg: %s", np.array2string(node.pi_agg.values)) node.compute_lambda_agg() - logging.info("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg)) + logging.info("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg.values)) for parent_id in parent_ids: try: @@ -97,7 +97,6 @@ class BeliefPropagation: new_value=new_pi_msg) nodes_to_update.add(MsgPassers(msg_receiver=child_id, msg_sender=node_to_update_label_id)) - self._belief_propagation(nodes_to_update, evidence) def initialize_model(self): @@ -115,8 +114,8 @@ class BeliefPropagation: for node in self.model.nodes_dict.values(): ones_vector = np.ones([node.cardinality]) + node.update_lambda_agg(ones_vector) - node.lambda_agg = ones_vector for child in node.lambda_received_msgs.keys(): node.update_lambda_msg_from_child(child=child, new_value=ones_vector) @@ -131,7 +130,7 @@ class BeliefPropagation: node_sending_msg = self.model.nodes_dict[node_id] child_ids = node_sending_msg.children - if node_sending_msg.pi_agg is None: + if node_sending_msg.pi_agg.values is None: node_sending_msg.compute_pi_agg() for child_id in child_ids: @@ -150,22 +149,19 @@ class BeliefPropagation: a dict key, value pair as {var: state_of_var observed} """ for evidence_id, observed_value in evidence.items(): - nodes_to_update = set() - if evidence_id not in self.model.nodes_dict.keys(): raise KeyError("Evidence supplied for non-existent label_id: {}" .format(evidence_id)) if is_kronecker_delta(observed_value): # specific evidence - self.model.nodes_dict[evidence_id].lambda_agg = observed_value + self.model.nodes_dict[evidence_id].update_lambda_agg(observed_value) else: # virtual evidence - self.model.nodes_dict[evidence_id].lambda_agg = \ - self.model.nodes_dict[evidence_id].lambda_agg * observed_value - + self.model.nodes_dict[evidence_id].update_lambda_agg( + self.model.nodes_dict[evidence_id].lambda_agg.values * observed_value + ) nodes_to_update = [MsgPassers(msg_receiver=evidence_id, msg_sender=None)] - self._belief_propagation(nodes_to_update=set(nodes_to_update), evidence=evidence) |