From 4373157138e85d2dbad9672cef5963a27a3d962c Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Fri, 8 Dec 2017 16:00:52 -0800 Subject: BernoulliAndNode with custom msg passing methods --- beliefs/models/belief_update_node_model.py | 39 ++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py index 667e0f1..4747530 100644 --- a/beliefs/models/belief_update_node_model.py +++ b/beliefs/models/belief_update_node_model.py @@ -8,6 +8,7 @@ import networkx as nx from beliefs.models.base_models import BayesianModel from beliefs.factors.bernoulli_or_cpd import BernoulliOrCPD +from beliefs.factors.bernoulli_and_cpd import BernoulliAndCPD class InvalidLambdaMsgToParent(Exception): @@ -313,3 +314,41 @@ class BernoulliOrNode(Node): if not any(lambda_msg): raise InvalidLambdaMsgToParent return self._normalize(lambda_msg) + + +class BernoulliAndNode(Node): + def __init__(self, + label_id, + children, + parents): + super().__init__(label_id=label_id, + children=children, + parents=parents, + cardinality=2, + cpd=BernoulliAndCPD(label_id, parents)) + + def compute_pi_agg(self): + if not self.parents: + self.pi_agg = self.cpd.values + else: + pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) + parents_p1 = [p[1] for p in pi_msg_values] + p_1 = reduce(lambda x, y: x*y, parents_p1) + p_0 = 1 - p_1 + self.pi_agg = np.array([p_0, p_1]) + return self.pi_agg + + def compute_lambda_msg_to_parent(self, parent_k): + if np.array_equal(self.lambda_agg, np.ones([self.cardinality])): + return np.ones([self.cardinality]) + else: + # TODO: cleanup this validation + _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) + p1_excluding_k = [msg[1] for par_id, msg in self.pi_received_msgs.items() if par_id != parent_k] + p1_product = reduce(lambda x, y: x*y, p1_excluding_k, 1) + lambda_0 = self.lambda_agg[0] + lambda_1 = self.lambda_agg[0] + (self.lambda_agg[1] - self.lambda_agg[0])*p1_product + lambda_msg = np.array([lambda_0, lambda_1]) + if not any(lambda_msg): + raise InvalidLambdaMsgToParent + return self._normalize(lambda_msg) -- cgit v1.2.3