diff options
author | Cathy Yeh <cathy@driver.xyz> | 2017-12-12 21:28:26 -0800 |
---|---|---|
committer | Cathy Yeh <cathy@driver.xyz> | 2017-12-13 18:45:03 -0800 |
commit | 10f5c49ea6767f54d59f88eb4064bb4959d14c6b (patch) | |
tree | 7ce485c2e678c593fd474537e3a125cd783c79cc /beliefs | |
parent | b3b8bb68d6d590175a07dfc4022b4903d63222e5 (diff) | |
download | beliefs-10f5c49ea6767f54d59f88eb4064bb4959d14c6b.tar.gz beliefs-10f5c49ea6767f54d59f88eb4064bb4959d14c6b.tar.bz2 beliefs-10f5c49ea6767f54d59f88eb4064bb4959d14c6b.zip |
implement explicit factor methods for compute_pi_agg and compute_lambda_msg_to_parent in Node
Diffstat (limited to 'beliefs')
-rw-r--r-- | beliefs/factors/discrete_factor.py | 7 | ||||
-rw-r--r-- | beliefs/models/belief_update_node_model.py | 37 |
2 files changed, 33 insertions, 11 deletions
diff --git a/beliefs/factors/discrete_factor.py b/beliefs/factors/discrete_factor.py index da8e6bf..b75da28 100644 --- a/beliefs/factors/discrete_factor.py +++ b/beliefs/factors/discrete_factor.py @@ -86,9 +86,16 @@ class DiscreteFactor: right = copy.deepcopy(other) left.add_new_variables_from_other_factor(right) right.add_new_variables_from_other_factor(left) + print('var', left.variables) + print(left.cardinality) + print(left.values) + print('var', right.variables) + print(right.cardinality) + print(right.values) # reorder variables in right factor to match order in left source_axes = list(range(right.values.ndim)) + print('source_axes', source_axes) destination_axes = [right.variables.index(var) for var in left.variables] right.variables = [right.variables[idx] for idx in destination_axes] diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py index cd8ba8c..17e98fa 100644 --- a/beliefs/models/belief_update_node_model.py +++ b/beliefs/models/belief_update_node_model.py @@ -205,25 +205,30 @@ class Node: return msgs def compute_pi_agg(self): - # TODO: implement explict factor product operation - raise NotImplementedError + if len(self.parents) == 0: + self.update_pi_agg(self.cpd.values) + else: + factors_to_multiply = [self.cpd] + pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) + factors_to_multiply.extend(pi_msgs) + + factor_product = reduce(lambda phi1, phi2: phi1*phi2, factors_to_multiply) + self.update_pi_agg(factor_product.marginalize(self.parents).values) + pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) def compute_lambda_agg(self): - if len(self.children) == 0: - return self.lambda_agg.values - else: + if len(self.children) != 0: lambda_msg_values = [ msg.values for msg in self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA) ] self.update_lambda_agg(reduce(np.multiply, lambda_msg_values)) - return self.lambda_agg.values def update_pi_agg(self, new_value): - self.pi_agg.update_values(np.array(new_value).reshape(self.cardinality)) + self.pi_agg.update_values(new_value) def update_lambda_agg(self, new_value): - self.lambda_agg.update_values(np.array(new_value).reshape(self.cardinality)) + self.lambda_agg.update_values(new_value) def _update_received_msg_by_key(self, received_msg_dict, key, new_value, message_type): if key not in received_msg_dict.keys(): @@ -242,7 +247,8 @@ class Node: if new_value.shape != expected_shape: raise ValueError("Expected new value to be of dimensions ({},) but got {} instead" .format(expected_shape, new_value.shape)) - received_msg_dict[key]._values = new_value + # received_msg_dict[key]._values = new_value + received_msg_dict[key].update_values(new_value) def update_pi_msg_from_parent(self, parent, new_value): self._update_received_msg_by_key(received_msg_dict=self.pi_received_msgs, @@ -267,8 +273,17 @@ class Node: raise ValueError("Can't compute pi message to child_{} without having received a lambda message from that child.") def compute_lambda_msg_to_parent(self, parent_k): - # TODO: implement explict factor product operation - raise NotImplementedError + if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])): + return np.ones([self.cardinality]) + else: + factors_to_multiply = [self.cpd] + pi_msgs_excl_k = [msg for par_id, msg in self.pi_received_msgs.items() + if par_id != parent_k] + factors_to_multiply.extend(pi_msgs_excl_k) + factor_product = reduce(lambda phi1, phi2: phi1*phi2, factors_to_multiply) + new_factor = factor_product.marginalize(list(set(self.parents) - set([parent_k]))) + lambda_msg_to_k = (self.lambda_agg * new_factor).marginalize([self.lambda_agg.variables[0]]) + return self._normalize(lambda_msg_to_k.values) @property def is_fully_initialized(self): |