From 8cdb00cdb10200e824015ece4a94485e93857352 Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Fri, 8 Dec 2017 16:00:14 -0800 Subject: bernoulli AND cpd --- beliefs/factors/bernoulli_and_cpd.py | 42 ++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 beliefs/factors/bernoulli_and_cpd.py diff --git a/beliefs/factors/bernoulli_and_cpd.py b/beliefs/factors/bernoulli_and_cpd.py new file mode 100644 index 0000000..fb86135 --- /dev/null +++ b/beliefs/factors/bernoulli_and_cpd.py @@ -0,0 +1,42 @@ +import numpy as np + +from beliefs.factors.cpd import TabularCPD + + +class BernoulliAndCPD(TabularCPD): + """CPD class for a Bernoulli random variable whose relationship to its + parents (also Bernoulli random variables) is described by AND logic. + + If all of the variable's parents are True, then the variable + is True, and False otherwise. + """ + def __init__(self, variable, parents=[]): + """ + Args: + variable: int or string + parents: optional, list of int and/or strings + """ + super().__init__(variable=variable, + variable_card=2, + parents=parents, + parents_card=[2]*len(parents), + values=[]) + self._values = [] + + @property + def values(self): + if len(self._values) == 0: + self._values = self._build_kwise_values_array(len(self.variables)) + self._values = self._values.reshape(self.cardinality) + return self._values + + @staticmethod + def _build_kwise_values_array(k): + # special case a completely independent factor, and + # return the uniform prior + if k == 1: + return np.array([0.5, 0.5]) + + return np.array( + [1.]*(2**(k-1)-1) + [0.] + [0.,]*(2**(k-1)-1) + [1.] + ) -- cgit v1.2.3 From 4373157138e85d2dbad9672cef5963a27a3d962c Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Fri, 8 Dec 2017 16:00:52 -0800 Subject: BernoulliAndNode with custom msg passing methods --- beliefs/models/belief_update_node_model.py | 39 ++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py index 667e0f1..4747530 100644 --- a/beliefs/models/belief_update_node_model.py +++ b/beliefs/models/belief_update_node_model.py @@ -8,6 +8,7 @@ import networkx as nx from beliefs.models.base_models import BayesianModel from beliefs.factors.bernoulli_or_cpd import BernoulliOrCPD +from beliefs.factors.bernoulli_and_cpd import BernoulliAndCPD class InvalidLambdaMsgToParent(Exception): @@ -313,3 +314,41 @@ class BernoulliOrNode(Node): if not any(lambda_msg): raise InvalidLambdaMsgToParent return self._normalize(lambda_msg) + + +class BernoulliAndNode(Node): + def __init__(self, + label_id, + children, + parents): + super().__init__(label_id=label_id, + children=children, + parents=parents, + cardinality=2, + cpd=BernoulliAndCPD(label_id, parents)) + + def compute_pi_agg(self): + if not self.parents: + self.pi_agg = self.cpd.values + else: + pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) + parents_p1 = [p[1] for p in pi_msg_values] + p_1 = reduce(lambda x, y: x*y, parents_p1) + p_0 = 1 - p_1 + self.pi_agg = np.array([p_0, p_1]) + return self.pi_agg + + def compute_lambda_msg_to_parent(self, parent_k): + if np.array_equal(self.lambda_agg, np.ones([self.cardinality])): + return np.ones([self.cardinality]) + else: + # TODO: cleanup this validation + _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) + p1_excluding_k = [msg[1] for par_id, msg in self.pi_received_msgs.items() if par_id != parent_k] + p1_product = reduce(lambda x, y: x*y, p1_excluding_k, 1) + lambda_0 = self.lambda_agg[0] + lambda_1 = self.lambda_agg[0] + (self.lambda_agg[1] - self.lambda_agg[0])*p1_product + lambda_msg = np.array([lambda_0, lambda_1]) + if not any(lambda_msg): + raise InvalidLambdaMsgToParent + return self._normalize(lambda_msg) -- cgit v1.2.3 From 06626854ca893b44c128ca333fb5623591134746 Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Fri, 8 Dec 2017 16:01:37 -0800 Subject: tests for belief propagation with AND and mixed AND and OR nodes --- beliefs/factors/bernoulli_or_cpd.py | 2 +- tests/test_belief_propagation.py | 122 +++++++++++++++++++++++++++++++++--- 2 files changed, 113 insertions(+), 11 deletions(-) diff --git a/beliefs/factors/bernoulli_or_cpd.py b/beliefs/factors/bernoulli_or_cpd.py index bfb3a95..162e156 100644 --- a/beliefs/factors/bernoulli_or_cpd.py +++ b/beliefs/factors/bernoulli_or_cpd.py @@ -25,7 +25,7 @@ class BernoulliOrCPD(TabularCPD): @property def values(self): - if not any(self._values): + if len(self._values) == 0: self._values = self._build_kwise_values_array(len(self.variables)) self._values = self._values.reshape(self.cardinality) return self._values diff --git a/tests/test_belief_propagation.py b/tests/test_belief_propagation.py index 5c5a612..7a77311 100644 --- a/tests/test_belief_propagation.py +++ b/tests/test_belief_propagation.py @@ -5,13 +5,14 @@ from pytest import approx from beliefs.inference.belief_propagation import BeliefPropagation, ConflictingEvidenceError from beliefs.models.belief_update_node_model import ( BeliefUpdateNodeModel, - BernoulliOrNode + BernoulliOrNode, + BernoulliAndNode ) @pytest.fixture(scope='module') -def edges_four_nodes(): - """Edges define a polytree with 4 nodes (connected in an X-shape with the +def edges_five_nodes(): + """Edges define a polytree with 5 nodes (connected in an X-shape with the node, 'x', at the center of the X.""" edges = [('u', 'x'), ('v', 'x'), ('x', 'y'), ('x', 'z')] return edges @@ -38,8 +39,8 @@ def many_parents_edges(): @pytest.fixture(scope='function') -def four_node_model(edges_four_nodes): - return BeliefUpdateNodeModel.init_from_edges(edges_four_nodes, BernoulliOrNode) +def five_node_model(edges_five_nodes): + return BeliefUpdateNodeModel.init_from_edges(edges_five_nodes, BernoulliOrNode) @pytest.fixture(scope='function') @@ -52,12 +53,42 @@ def many_parents_model(many_parents_edges): return BeliefUpdateNodeModel.init_from_edges(many_parents_edges, BernoulliOrNode) +@pytest.fixture(scope='function') +def many_parents_and_model(many_parents_edges): + return BeliefUpdateNodeModel.init_from_edges(many_parents_edges, BernoulliAndNode) + + @pytest.fixture(scope='function') def one_node_model(): a_node = BernoulliOrNode(label_id='x', children=[], parents=[]) return BeliefUpdateNodeModel(nodes_dict={'x': a_node}) +@pytest.fixture(scope='function') +def five_node_and_model(edges_five_nodes): + return BeliefUpdateNodeModel.init_from_edges(edges_five_nodes, BernoulliAndNode) + + +@pytest.fixture(scope='function') +def mixed_cpd_model(edges_five_nodes): + """ + X-shaped 5 node model plus one more node, 'w', with edge from 'w' to 'z'. + 'z' is an AND node while all other nodes are OR nodes. + """ + u_node = BernoulliOrNode(label_id='u', children=['x'], parents=[]) + v_node = BernoulliOrNode(label_id='v', children=['x'], parents=[]) + x_node = BernoulliOrNode(label_id='x', children=['y', 'z'], parents=['u', 'v']) + y_node = BernoulliOrNode(label_id='y', children=[], parents=['x']) + z_node = BernoulliAndNode(label_id='z', children=[], parents=['x', 'w']) + w_node = BernoulliOrNode(label_id='w', children=['z'], parents=[]) + return BeliefUpdateNodeModel(nodes_dict={'u': u_node, + 'v': v_node, + 'x': x_node, + 'y': y_node, + 'z': z_node, + 'w': w_node}) + + def get_label_mapped_to_positive_belief(query_result): """Return a dictionary mapping each label_id to the probability of the label being True.""" @@ -118,26 +149,89 @@ def test_NO_evidence_one_node_model(one_node_model): #============================================================================================== -# Tests of 4-node, 4-edge model +# Tests of 5-node, 4-edge model -def test_no_evidence_four_node_model(four_node_model): +def test_no_evidence_five_node_model(five_node_model): expected = {'x': 1-0.5**2} - infer = BeliefPropagation(four_node_model) + infer = BeliefPropagation(five_node_model) query_result = infer.query(evidence={}) result = get_label_mapped_to_positive_belief(query_result) compare_dictionaries(expected, result) -def test_virtual_evidence_for_node_x_four_node_model(four_node_model): +def test_virtual_evidence_for_node_x_five_node_model(five_node_model): """Virtual evidence for node x.""" expected = {'x': 0.967741935483871, 'y': 0.967741935483871, 'z': 0.967741935483871, 'u': 0.6451612903225806, 'v': 0.6451612903225806} - infer = BeliefPropagation(four_node_model) + infer = BeliefPropagation(five_node_model) query_result = infer.query(evidence={'x': np.array([1, 10])}) result = get_label_mapped_to_positive_belief(query_result) compare_dictionaries(expected, result) +#============================================================================================== +# Tests of 5-node, 4-edge model with AND cpds + +def test_no_evidence_five_node_and_model(five_node_and_model): + expected = {'x': 0.5**2} + infer = BeliefPropagation(five_node_and_model) + query_result = infer.query(evidence={}) + result = get_label_mapped_to_positive_belief(query_result) + compare_dictionaries(expected, result) + + +def test_one_parent_false_five_node_and_model(five_node_and_model): + expected = {'x': 0} + infer = BeliefPropagation(five_node_and_model) + query_result = infer.query(evidence={'u': np.array([1,0])}) + result = get_label_mapped_to_positive_belief(query_result) + compare_dictionaries(expected, result) + + +def test_one_parent_true_five_node_and_model(five_node_and_model): + expected = {'x': 0.5} + infer = BeliefPropagation(five_node_and_model) + query_result = infer.query(evidence={'u': np.array([0,1])}) + result = get_label_mapped_to_positive_belief(query_result) + compare_dictionaries(expected, result) + + +def test_both_parents_true_five_node_and_model(five_node_and_model): + expected = {'x': 1, 'y': 1, 'z': 1} + infer = BeliefPropagation(five_node_and_model) + query_result = infer.query(evidence={'u': np.array([0,1]), 'v': np.array([0,1])}) + result = get_label_mapped_to_positive_belief(query_result) + compare_dictionaries(expected, result) + + +#============================================================================================== +# Tests of mixed cpd model (all CPDs are OR, except for one AND node with 2 parents) + + +def test_no_evidence_mixed_cpd_model(mixed_cpd_model): + expected = {'x': 1-0.5**2, 'z': 0.5*(1-0.5**2)} + infer = BeliefPropagation(mixed_cpd_model) + query_result = infer.query(evidence={}) + result = get_label_mapped_to_positive_belief(query_result) + compare_dictionaries(expected, result) + + +def test_x_false_w_true_mixed_cpd_model(mixed_cpd_model): + expected = {'u': 0, 'v': 0, 'y': 0, 'z': 0} + infer = BeliefPropagation(mixed_cpd_model) + query_result = infer.query(evidence={'x': np.array([1,0]), 'w': np.array([0,1])}) + result = get_label_mapped_to_positive_belief(query_result) + compare_dictionaries(expected, result) + + +def test_x_true_w_true_mixed_cpd_model(mixed_cpd_model): + expected = {'y': 1, 'z': 1} + infer = BeliefPropagation(mixed_cpd_model) + query_result = infer.query(evidence={'x': np.array([0,1]), 'w': np.array([0,1])}) + result = get_label_mapped_to_positive_belief(query_result) + compare_dictionaries(expected, result) + + #============================================================================================== # Tests of simple BernoulliOr polytree model @@ -253,3 +347,11 @@ def test_negative_evidence_node_62(many_parents_model): query_result = infer.query(evidence={'62': np.array([1, 0])}) result = get_label_mapped_to_positive_belief(query_result) compare_dictionaries(expected, result) + + +def test_conflicting_evidence_and_model(many_parents_and_model): + """If one of the parents of node 62 is False, then node 62 has to be False.""" + infer = BeliefPropagation(many_parents_and_model) + with pytest.raises(ConflictingEvidenceError) as err: + query_result = infer.query(evidence={'62': np.array([0, 1]), '112': np.array([1, 0])}) + assert "Can't run belief propagation with conflicting evidence" in str(err) -- cgit v1.2.3 From 00dfdd7a897b2606ceeabf5323e71d8e80a446fc Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Mon, 11 Dec 2017 11:39:04 -0800 Subject: PR comments --- beliefs/factors/bernoulli_and_cpd.py | 7 +++++-- beliefs/factors/bernoulli_or_cpd.py | 7 +++++-- beliefs/models/belief_update_node_model.py | 14 +++++++++----- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/beliefs/factors/bernoulli_and_cpd.py b/beliefs/factors/bernoulli_and_cpd.py index fb86135..fdb0c25 100644 --- a/beliefs/factors/bernoulli_and_cpd.py +++ b/beliefs/factors/bernoulli_and_cpd.py @@ -21,11 +21,11 @@ class BernoulliAndCPD(TabularCPD): parents=parents, parents_card=[2]*len(parents), values=[]) - self._values = [] + self._values = None @property def values(self): - if len(self._values) == 0: + if self._values is None: self._values = self._build_kwise_values_array(len(self.variables)) self._values = self._values.reshape(self.cardinality) return self._values @@ -37,6 +37,9 @@ class BernoulliAndCPD(TabularCPD): if k == 1: return np.array([0.5, 0.5]) + # values are stored as a row vector using an ordering such that + # the right-most variables as defined in [variable].extend(parents) + # cycle through their values the fastest. return np.array( [1.]*(2**(k-1)-1) + [0.] + [0.,]*(2**(k-1)-1) + [1.] ) diff --git a/beliefs/factors/bernoulli_or_cpd.py b/beliefs/factors/bernoulli_or_cpd.py index 162e156..12ee2f6 100644 --- a/beliefs/factors/bernoulli_or_cpd.py +++ b/beliefs/factors/bernoulli_or_cpd.py @@ -21,11 +21,11 @@ class BernoulliOrCPD(TabularCPD): parents=parents, parents_card=[2]*len(parents), values=[]) - self._values = [] + self._values = None @property def values(self): - if len(self._values) == 0: + if self._values is None: self._values = self._build_kwise_values_array(len(self.variables)) self._values = self._values.reshape(self.cardinality) return self._values @@ -37,6 +37,9 @@ class BernoulliOrCPD(TabularCPD): if k == 1: return np.array([0.5, 0.5]) + # values are stored as a row vector using an ordering such that + # the right-most variables as defined in [variable].extend(parents) + # cycle through their values the fastest. return np.array( [1.,] + [0.]*(2**(k-1)-1) + [0.,] + [1.]*(2**(k-1)-1) ) diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py index 4747530..1c3ba6e 100644 --- a/beliefs/models/belief_update_node_model.py +++ b/beliefs/models/belief_update_node_model.py @@ -213,7 +213,7 @@ class Node: raise NotImplementedError def compute_lambda_agg(self): - if not self.children: + if len(self.children) == 0: return self.lambda_agg else: lambda_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA) @@ -290,11 +290,13 @@ class BernoulliOrNode(Node): cpd=BernoulliOrCPD(label_id, parents)) def compute_pi_agg(self): - if not self.parents: + if len(self.parents) == 0: self.pi_agg = self.cpd.values else: pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) parents_p0 = [p[0] for p in pi_msg_values] + # Parents are Bernoulli variables with pi_msg_values (surrogate prior probabilities) + # of p = [P(False), P(True)] p_0 = reduce(lambda x, y: x*y, parents_p0) p_1 = 1 - p_0 self.pi_agg = np.array([p_0, p_1]) @@ -306,7 +308,7 @@ class BernoulliOrNode(Node): else: # TODO: cleanup this validation _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) - p0_excluding_k = [msg[0] for par_id, msg in self.pi_received_msgs.items() if par_id != parent_k] + p0_excluding_k = [p[0] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] p0_product = reduce(lambda x, y: x*y, p0_excluding_k, 1) lambda_0 = self.lambda_agg[1] + (self.lambda_agg[0] - self.lambda_agg[1])*p0_product lambda_1 = self.lambda_agg[1] @@ -328,11 +330,13 @@ class BernoulliAndNode(Node): cpd=BernoulliAndCPD(label_id, parents)) def compute_pi_agg(self): - if not self.parents: + if len(self.parents) == 0: self.pi_agg = self.cpd.values else: pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) parents_p1 = [p[1] for p in pi_msg_values] + # Parents are Bernoulli variables with pi_msg_values (surrogate prior probabilities) + # of p = [P(False), P(True)] p_1 = reduce(lambda x, y: x*y, parents_p1) p_0 = 1 - p_1 self.pi_agg = np.array([p_0, p_1]) @@ -344,7 +348,7 @@ class BernoulliAndNode(Node): else: # TODO: cleanup this validation _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) - p1_excluding_k = [msg[1] for par_id, msg in self.pi_received_msgs.items() if par_id != parent_k] + p1_excluding_k = [p[1] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] p1_product = reduce(lambda x, y: x*y, p1_excluding_k, 1) lambda_0 = self.lambda_agg[0] lambda_1 = self.lambda_agg[0] + (self.lambda_agg[1] - self.lambda_agg[0])*p1_product -- cgit v1.2.3 From 7b5c17c316481edbbd13815390d0b34fb50a03a6 Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Mon, 11 Dec 2017 18:54:06 -0800 Subject: bump version --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 4e379d2..bcab45a 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.0.2 +0.0.3 -- cgit v1.2.3