From 10f5c49ea6767f54d59f88eb4064bb4959d14c6b Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Tue, 12 Dec 2017 21:28:26 -0800 Subject: implement explicit factor methods for compute_pi_agg and compute_lambda_msg_to_parent in Node --- beliefs/factors/discrete_factor.py | 7 ++++ beliefs/models/belief_update_node_model.py | 37 ++++++++++++----- tests/test_belief_propagation.py | 64 +++++++++++++++++++++++++++++- 3 files changed, 96 insertions(+), 12 deletions(-) diff --git a/beliefs/factors/discrete_factor.py b/beliefs/factors/discrete_factor.py index da8e6bf..b75da28 100644 --- a/beliefs/factors/discrete_factor.py +++ b/beliefs/factors/discrete_factor.py @@ -86,9 +86,16 @@ class DiscreteFactor: right = copy.deepcopy(other) left.add_new_variables_from_other_factor(right) right.add_new_variables_from_other_factor(left) + print('var', left.variables) + print(left.cardinality) + print(left.values) + print('var', right.variables) + print(right.cardinality) + print(right.values) # reorder variables in right factor to match order in left source_axes = list(range(right.values.ndim)) + print('source_axes', source_axes) destination_axes = [right.variables.index(var) for var in left.variables] right.variables = [right.variables[idx] for idx in destination_axes] diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py index cd8ba8c..17e98fa 100644 --- a/beliefs/models/belief_update_node_model.py +++ b/beliefs/models/belief_update_node_model.py @@ -205,25 +205,30 @@ class Node: return msgs def compute_pi_agg(self): - # TODO: implement explict factor product operation - raise NotImplementedError + if len(self.parents) == 0: + self.update_pi_agg(self.cpd.values) + else: + factors_to_multiply = [self.cpd] + pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) + factors_to_multiply.extend(pi_msgs) + + factor_product = reduce(lambda phi1, phi2: phi1*phi2, factors_to_multiply) + self.update_pi_agg(factor_product.marginalize(self.parents).values) + pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) def compute_lambda_agg(self): - if len(self.children) == 0: - return self.lambda_agg.values - else: + if len(self.children) != 0: lambda_msg_values = [ msg.values for msg in self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA) ] self.update_lambda_agg(reduce(np.multiply, lambda_msg_values)) - return self.lambda_agg.values def update_pi_agg(self, new_value): - self.pi_agg.update_values(np.array(new_value).reshape(self.cardinality)) + self.pi_agg.update_values(new_value) def update_lambda_agg(self, new_value): - self.lambda_agg.update_values(np.array(new_value).reshape(self.cardinality)) + self.lambda_agg.update_values(new_value) def _update_received_msg_by_key(self, received_msg_dict, key, new_value, message_type): if key not in received_msg_dict.keys(): @@ -242,7 +247,8 @@ class Node: if new_value.shape != expected_shape: raise ValueError("Expected new value to be of dimensions ({},) but got {} instead" .format(expected_shape, new_value.shape)) - received_msg_dict[key]._values = new_value + # received_msg_dict[key]._values = new_value + received_msg_dict[key].update_values(new_value) def update_pi_msg_from_parent(self, parent, new_value): self._update_received_msg_by_key(received_msg_dict=self.pi_received_msgs, @@ -267,8 +273,17 @@ class Node: raise ValueError("Can't compute pi message to child_{} without having received a lambda message from that child.") def compute_lambda_msg_to_parent(self, parent_k): - # TODO: implement explict factor product operation - raise NotImplementedError + if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])): + return np.ones([self.cardinality]) + else: + factors_to_multiply = [self.cpd] + pi_msgs_excl_k = [msg for par_id, msg in self.pi_received_msgs.items() + if par_id != parent_k] + factors_to_multiply.extend(pi_msgs_excl_k) + factor_product = reduce(lambda phi1, phi2: phi1*phi2, factors_to_multiply) + new_factor = factor_product.marginalize(list(set(self.parents) - set([parent_k]))) + lambda_msg_to_k = (self.lambda_agg * new_factor).marginalize([self.lambda_agg.variables[0]]) + return self._normalize(lambda_msg_to_k.values) @property def is_fully_initialized(self): diff --git a/tests/test_belief_propagation.py b/tests/test_belief_propagation.py index 7a77311..1b8c0ac 100644 --- a/tests/test_belief_propagation.py +++ b/tests/test_belief_propagation.py @@ -3,10 +3,12 @@ import pytest from pytest import approx from beliefs.inference.belief_propagation import BeliefPropagation, ConflictingEvidenceError +from beliefs.factors.cpd import TabularCPD from beliefs.models.belief_update_node_model import ( BeliefUpdateNodeModel, BernoulliOrNode, - BernoulliAndNode + BernoulliAndNode, + Node ) @@ -89,6 +91,41 @@ def mixed_cpd_model(edges_five_nodes): 'w': w_node}) +@pytest.fixture(scope='function') +def custom_cpd_model(): + """ + Y-shaped model, with parents ,'u' and 'v' as Or-nodes, 'x' a node with + cardinality 3 and custom CPD, 'y' a node with cardinality 2 and custom CPD. + """ + custom_cpd_x = TabularCPD(variable='x', + variable_card=3, + parents=['u', 'v'], + parents_card=[2, 2], + values=[[0.2, 0, 0.3, 0.1], + [0.4, 1, 0.7, 0.2], + [0.4, 0, 0, 0.7]], + state_names={'x': ['lo', 'med', 'hi'], + 'u': ['False', 'True'], + 'v': ['False', 'True']}) + custom_cpd_y = TabularCPD(variable='y', + variable_card=2, + parents=['x'], + parents_card=[3], + values=[[0.3, 0.1, 0], + [0.7, 0.9, 1]], + state_names={'x': ['lo', 'med', 'hi'], + 'y': ['False', 'True']}) + + u_node = BernoulliOrNode(label_id='u', children=['x'], parents=[]) + v_node = BernoulliOrNode(label_id='v', children=['x'], parents=[]) + x_node = Node(children=['y'], cpd=custom_cpd_x) + y_node = Node(children=[], cpd=custom_cpd_y) + return BeliefUpdateNodeModel(nodes_dict={'u': u_node, + 'v': v_node, + 'x': x_node, + 'y': y_node}) + + def get_label_mapped_to_positive_belief(query_result): """Return a dictionary mapping each label_id to the probability of the label being True.""" @@ -355,3 +392,28 @@ def test_conflicting_evidence_and_model(many_parents_and_model): with pytest.raises(ConflictingEvidenceError) as err: query_result = infer.query(evidence={'62': np.array([0, 1]), '112': np.array([1, 0])}) assert "Can't run belief propagation with conflicting evidence" in str(err) + + +#============================================================================================== +# Model with two custom cpds + + +def test_no_evidence_custom_cpd_model(custom_cpd_model): + expected = {'x': np.array([0.15, 0.575, 0.275]), + 'v': np.array([0.5, 0.5]), + 'u': np.array([0.5, 0.5]), + 'y': np.array([0.1025, 0.8975])} + infer = BeliefPropagation(custom_cpd_model) + query_result = infer.query(evidence={}) + compare_dictionaries(expected, query_result) + + +def test_evidence_custom_cpd_model(custom_cpd_model): + """Custom node is observed to be in 'med' state.""" + expected = {'x': np.array([0., 1., 0.]), + 'u': np.array([0.60869565, 0.39130435]), + 'v': np.array([0.47826087, 0.52173913]), + 'y': np.array([0.1, 0.9])} + infer = BeliefPropagation(custom_cpd_model) + query_result = infer.query(evidence={'x': np.array([0, 1, 0])}) + compare_dictionaries(expected, query_result) -- cgit v1.2.3