From b3b8bb68d6d590175a07dfc4022b4903d63222e5 Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Tue, 12 Dec 2017 19:58:12 -0800 Subject: Bernoulli Or/And Node access msg values by state names --- beliefs/factors/bernoulli_and_cpd.py | 2 +- beliefs/factors/bernoulli_or_cpd.py | 2 +- beliefs/models/belief_update_node_model.py | 91 ++++++++++++++++++++++-------- 3 files changed, 69 insertions(+), 26 deletions(-) diff --git a/beliefs/factors/bernoulli_and_cpd.py b/beliefs/factors/bernoulli_and_cpd.py index adf5ed5..15802c2 100644 --- a/beliefs/factors/bernoulli_and_cpd.py +++ b/beliefs/factors/bernoulli_and_cpd.py @@ -20,7 +20,7 @@ class BernoulliAndCPD(TabularCPD): variable_card=2, parents=parents, parents_card=[2]*len(parents), - values=[], + values=None, state_names={var: ['False', 'True'] for var in [variable] + parents}) self._values = None diff --git a/beliefs/factors/bernoulli_or_cpd.py b/beliefs/factors/bernoulli_or_cpd.py index 6e01cf9..5b661a1 100644 --- a/beliefs/factors/bernoulli_or_cpd.py +++ b/beliefs/factors/bernoulli_or_cpd.py @@ -20,7 +20,7 @@ class BernoulliOrCPD(TabularCPD): variable_card=2, parents=parents, parents_card=[2]*len(parents), - values=[], + values=None, state_names={var: ['False', 'True'] for var in [variable] + parents}) self._values = None diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py index 820ee0c..cd8ba8c 100644 --- a/beliefs/models/belief_update_node_model.py +++ b/beliefs/models/belief_update_node_model.py @@ -174,13 +174,13 @@ class Node: message_type: MessageType enum Returns: - msg_values: list of message values (each an np.array) + msg_values: list of DiscreteFactors containing message values (np.arrays) """ if message_type == MessageType.LAMBDA: - msg_values = [msg.values for msg in self.lambda_received_msgs.values()] + msgs = [msg for msg in self.lambda_received_msgs.values()] elif message_type == MessageType.PI: - msg_values = [msg.values for msg in self.pi_received_msgs.values()] - return msg_values + msgs = [msg for msg in self.pi_received_msgs.values()] + return msgs def validate_and_return_msgs_received_for_msg_type(self, message_type): """ @@ -192,17 +192,17 @@ class Node: message_type: MessageType enum Returns: - msg_values: list of message values (each an np.array) + msgs: list of DiscreteFactors containing message values (np.array) """ - msg_values = self._return_msgs_received_for_msg_type(message_type) + msgs = self._return_msgs_received_for_msg_type(message_type) - if any(msg is None for msg in msg_values): + if any(msg.values is None for msg in msgs): raise ValueError( "Missing value for {msg_type} msg from child: can't compute {msg_type}_agg." .format(msg_type=message_type.value) ) else: - return msg_values + return msgs def compute_pi_agg(self): # TODO: implement explict factor product operation @@ -212,8 +212,10 @@ class Node: if len(self.children) == 0: return self.lambda_agg.values else: - lambda_msg_values =\ - self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA) + lambda_msg_values = [ + msg.values for msg in + self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA) + ] self.update_lambda_agg(reduce(np.multiply, lambda_msg_values)) return self.lambda_agg.values @@ -295,14 +297,30 @@ class BernoulliOrNode(Node): parents): super().__init__(children=children, cpd=BernoulliOrCPD(label_id, parents)) + def _init_aggregate_values(self): + variable = self.cpd.variable + return DiscreteFactor(variables=[self.cpd.variable], + cardinality=[self.cardinality], + values=None, + state_names={variable: self.cpd.state_names[variable]}) + + def _init_pi_received_msgs(self, parents): + msgs = {} + for k in parents: + kth_cardinality = self.cpd.cardinality[self.cpd.variables.index(k)] + msgs[k] = DiscreteFactor(variables=[k], + cardinality=[kth_cardinality], + values=None, + state_names={k: self.cpd.state_names[k]}) + return msgs + def compute_pi_agg(self): if len(self.parents) == 0: self.update_pi_agg(self.cpd.values) else: - pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) - parents_p0 = [p[0] for p in pi_msg_values] - # Parents are Bernoulli variables with pi_msg_values (surrogate prior probabilities) - # of p = [P(False), P(True)] + pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) + parents_p0 = [p.get_value_for_state_vector({p.variables[0]: 'False'}) + for p in pi_msgs] p_0 = reduce(lambda x, y: x*y, parents_p0) p_1 = 1 - p_0 self.update_pi_agg(np.array([p_0, p_1])) @@ -314,10 +332,14 @@ class BernoulliOrNode(Node): else: # TODO: cleanup this validation _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) - p0_excluding_k = [p.values[0] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] + p0_excluding_k = [p.get_value_for_state_vector({p.variables[0]: 'False'}) + for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] p0_product = reduce(lambda x, y: x*y, p0_excluding_k, 1) - lambda_0 = self.lambda_agg.values[1] + (self.lambda_agg.values[0] - self.lambda_agg.values[1])*p0_product - lambda_1 = self.lambda_agg.values[1] + + lambda_agg_0 = self.lambda_agg.get_value_for_state_vector({self.label_id: 'False'}) + lambda_agg_1 = self.lambda_agg.get_value_for_state_vector({self.label_id: 'True'}) + lambda_0 = lambda_agg_1 + (lambda_agg_0 - lambda_agg_1)*p0_product + lambda_1 = lambda_agg_1 lambda_msg = np.array([lambda_0, lambda_1]) if not any(lambda_msg): raise InvalidLambdaMsgToParent @@ -331,14 +353,30 @@ class BernoulliAndNode(Node): parents): super().__init__(children=children, cpd=BernoulliAndCPD(label_id, parents)) + def _init_aggregate_values(self): + variable = self.cpd.variable + return DiscreteFactor(variables=[self.cpd.variable], + cardinality=[self.cardinality], + values=None, + state_names={variable: self.cpd.state_names[variable]}) + + def _init_pi_received_msgs(self, parents): + msgs = {} + for k in parents: + kth_cardinality = self.cpd.cardinality[self.cpd.variables.index(k)] + msgs[k] = DiscreteFactor(variables=[k], + cardinality=[kth_cardinality], + values=None, + state_names={k: self.cpd.state_names[k]}) + return msgs + def compute_pi_agg(self): if len(self.parents) == 0: self.update_pi_agg(self.cpd.values) else: - pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) - parents_p1 = [p[1] for p in pi_msg_values] - # Parents are Bernoulli variables with pi_msg_values (surrogate prior probabilities) - # of p = [P(False), P(True)] + pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) + parents_p1 = [p.get_value_for_state_vector({p.variables[0]: 'True'}) + for p in pi_msgs] p_1 = reduce(lambda x, y: x*y, parents_p1) p_0 = 1 - p_1 self.update_pi_agg(np.array([p_0, p_1])) @@ -350,10 +388,15 @@ class BernoulliAndNode(Node): else: # TODO: cleanup this validation _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) - p1_excluding_k = [p.values[1] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] + p1_excluding_k = [p.get_value_for_state_vector({p.variables[0]: 'True'}) + for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] p1_product = reduce(lambda x, y: x*y, p1_excluding_k, 1) - lambda_0 = self.lambda_agg.values[0] - lambda_1 = self.lambda_agg.values[0] + (self.lambda_agg.values[1] - self.lambda_agg.values[0])*p1_product + + lambda_agg_0 = self.lambda_agg.get_value_for_state_vector({self.label_id: 'False'}) + lambda_agg_1 = self.lambda_agg.get_value_for_state_vector({self.label_id: 'True'}) + + lambda_0 = lambda_agg_0 + lambda_1 = lambda_agg_0 + (lambda_agg_1 - lambda_agg_0)*p1_product lambda_msg = np.array([lambda_0, lambda_1]) if not any(lambda_msg): raise InvalidLambdaMsgToParent -- cgit v1.2.3