diff options
Diffstat (limited to 'beliefs/models/belief_update_node_model.py')
-rw-r--r-- | beliefs/models/belief_update_node_model.py | 37 |
1 files changed, 26 insertions, 11 deletions
diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py index cd8ba8c..17e98fa 100644 --- a/beliefs/models/belief_update_node_model.py +++ b/beliefs/models/belief_update_node_model.py @@ -205,25 +205,30 @@ class Node: return msgs def compute_pi_agg(self): - # TODO: implement explict factor product operation - raise NotImplementedError + if len(self.parents) == 0: + self.update_pi_agg(self.cpd.values) + else: + factors_to_multiply = [self.cpd] + pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) + factors_to_multiply.extend(pi_msgs) + + factor_product = reduce(lambda phi1, phi2: phi1*phi2, factors_to_multiply) + self.update_pi_agg(factor_product.marginalize(self.parents).values) + pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) def compute_lambda_agg(self): - if len(self.children) == 0: - return self.lambda_agg.values - else: + if len(self.children) != 0: lambda_msg_values = [ msg.values for msg in self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA) ] self.update_lambda_agg(reduce(np.multiply, lambda_msg_values)) - return self.lambda_agg.values def update_pi_agg(self, new_value): - self.pi_agg.update_values(np.array(new_value).reshape(self.cardinality)) + self.pi_agg.update_values(new_value) def update_lambda_agg(self, new_value): - self.lambda_agg.update_values(np.array(new_value).reshape(self.cardinality)) + self.lambda_agg.update_values(new_value) def _update_received_msg_by_key(self, received_msg_dict, key, new_value, message_type): if key not in received_msg_dict.keys(): @@ -242,7 +247,8 @@ class Node: if new_value.shape != expected_shape: raise ValueError("Expected new value to be of dimensions ({},) but got {} instead" .format(expected_shape, new_value.shape)) - received_msg_dict[key]._values = new_value + # received_msg_dict[key]._values = new_value + received_msg_dict[key].update_values(new_value) def update_pi_msg_from_parent(self, parent, new_value): self._update_received_msg_by_key(received_msg_dict=self.pi_received_msgs, @@ -267,8 +273,17 @@ class Node: raise ValueError("Can't compute pi message to child_{} without having received a lambda message from that child.") def compute_lambda_msg_to_parent(self, parent_k): - # TODO: implement explict factor product operation - raise NotImplementedError + if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])): + return np.ones([self.cardinality]) + else: + factors_to_multiply = [self.cpd] + pi_msgs_excl_k = [msg for par_id, msg in self.pi_received_msgs.items() + if par_id != parent_k] + factors_to_multiply.extend(pi_msgs_excl_k) + factor_product = reduce(lambda phi1, phi2: phi1*phi2, factors_to_multiply) + new_factor = factor_product.marginalize(list(set(self.parents) - set([parent_k]))) + lambda_msg_to_k = (self.lambda_agg * new_factor).marginalize([self.lambda_agg.variables[0]]) + return self._normalize(lambda_msg_to_k.values) @property def is_fully_initialized(self): |