From 70bdf07d25f41de1a9510b64267bfa29791760c7 Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Tue, 12 Dec 2017 16:11:54 -0800 Subject: change all msg datatypes from np.array -> DiscreteFactor --- beliefs/inference/belief_propagation.py | 20 +++--- beliefs/models/belief_update_node_model.py | 110 ++++++++++++++++++----------- 2 files changed, 77 insertions(+), 53 deletions(-) diff --git a/beliefs/inference/belief_propagation.py b/beliefs/inference/belief_propagation.py index 7ec648d..128f645 100644 --- a/beliefs/inference/belief_propagation.py +++ b/beliefs/inference/belief_propagation.py @@ -74,9 +74,9 @@ class BeliefPropagation: if node_to_update_label_id not in evidence: node.compute_pi_agg() - logging.info("belief propagation pi_agg: %s", np.array2string(node.pi_agg)) + logging.info("belief propagation pi_agg: %s", np.array2string(node.pi_agg.values)) node.compute_lambda_agg() - logging.info("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg)) + logging.info("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg.values)) for parent_id in parent_ids: try: @@ -97,7 +97,6 @@ class BeliefPropagation: new_value=new_pi_msg) nodes_to_update.add(MsgPassers(msg_receiver=child_id, msg_sender=node_to_update_label_id)) - self._belief_propagation(nodes_to_update, evidence) def initialize_model(self): @@ -115,8 +114,8 @@ class BeliefPropagation: for node in self.model.nodes_dict.values(): ones_vector = np.ones([node.cardinality]) + node.update_lambda_agg(ones_vector) - node.lambda_agg = ones_vector for child in node.lambda_received_msgs.keys(): node.update_lambda_msg_from_child(child=child, new_value=ones_vector) @@ -131,7 +130,7 @@ class BeliefPropagation: node_sending_msg = self.model.nodes_dict[node_id] child_ids = node_sending_msg.children - if node_sending_msg.pi_agg is None: + if node_sending_msg.pi_agg.values is None: node_sending_msg.compute_pi_agg() for child_id in child_ids: @@ -150,22 +149,19 @@ class BeliefPropagation: a dict key, value pair as {var: state_of_var observed} """ for evidence_id, observed_value in evidence.items(): - nodes_to_update = set() - if evidence_id not in self.model.nodes_dict.keys(): raise KeyError("Evidence supplied for non-existent label_id: {}" .format(evidence_id)) if is_kronecker_delta(observed_value): # specific evidence - self.model.nodes_dict[evidence_id].lambda_agg = observed_value + self.model.nodes_dict[evidence_id].update_lambda_agg(observed_value) else: # virtual evidence - self.model.nodes_dict[evidence_id].lambda_agg = \ - self.model.nodes_dict[evidence_id].lambda_agg * observed_value - + self.model.nodes_dict[evidence_id].update_lambda_agg( + self.model.nodes_dict[evidence_id].lambda_agg.values * observed_value + ) nodes_to_update = [MsgPassers(msg_receiver=evidence_id, msg_sender=None)] - self._belief_propagation(nodes_to_update=set(nodes_to_update), evidence=evidence) diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py index 1c3ba6e..1765ed9 100644 --- a/beliefs/models/belief_update_node_model.py +++ b/beliefs/models/belief_update_node_model.py @@ -7,6 +7,7 @@ from functools import reduce import networkx as nx from beliefs.models.base_models import BayesianModel +from beliefs.factors.discrete_factor import DiscreteFactor from beliefs.factors.bernoulli_or_cpd import BernoulliOrCPD from beliefs.factors.bernoulli_and_cpd import BernoulliAndCPD @@ -88,10 +89,10 @@ class BeliefUpdateNodeModel(BayesianModel): to an (unnormalized) unit vector, of length the cardinality of x. """ for root in self.get_roots(): - self.nodes_dict[root].pi_agg = self.nodes_dict[root].cpd.values + self.nodes_dict[root].update_pi_agg(self.nodes_dict[root].cpd.values) for leaf in self.get_leaves(): - self.nodes_dict[leaf].lambda_agg = np.ones([self.nodes_dict[leaf].cardinality]) + self.nodes_dict[leaf].update_lambda_agg(np.ones([self.nodes_dict[leaf].cardinality])) @property def all_nodes_are_fully_initialized(self): @@ -135,17 +136,18 @@ class Node: cpd: an instance of a conditional probability distribution, e.g. BernoulliOrCPD or TabularCPD """ - self.label_id = label_id + self.label_id = label_id # this can be obtained from cpd.variable self.children = children - self.parents = parents - self.cardinality = cardinality + self.parents = parents # this can be obtained from cpd.variables[1:] + self.cardinality = cardinality # this can be obtained from cpd.cardinality[0] self.cpd = cpd - self.pi_agg = None # np.array dimensions [1, cardinality] - self.lambda_agg = None # np.array dimensions [1, cardinality] + # instances of DiscreteFactor with `values` an np.array of dimensions [1, cardinality] + self.pi_agg = self._init_aggregate_values() + self.lambda_agg = self._init_aggregate_values() - self.pi_received_msgs = self._init_received_msgs(parents) - self.lambda_received_msgs = self._init_received_msgs(children) + self.pi_received_msgs = self._init_pi_received_msgs(parents) + self.lambda_received_msgs = {child: self._init_aggregate_values() for child in children} @classmethod def from_cpd_class(cls, @@ -159,8 +161,8 @@ class Node: @property def belief(self): - if self.pi_agg.any() and self.lambda_agg.any(): - belief = np.multiply(self.pi_agg, self.lambda_agg) + if any(self.pi_agg.values) and any(self.lambda_agg.values): + belief = (self.lambda_agg * self.pi_agg).values return self._normalize(belief) else: return None @@ -168,9 +170,21 @@ class Node: def _normalize(self, value): return value/value.sum() - @staticmethod - def _init_received_msgs(keys): - return {k: None for k in keys} + def _init_aggregate_values(self): + return DiscreteFactor(variables=[self.cpd.variable], + cardinality=[self.cardinality], + values=None, + state_names=None) + + def _init_pi_received_msgs(self, parents): + msgs = {} + for k in parents: + kth_cardinality = self.cpd.cardinality[self.cpd.variables.index(k)] + msgs[k] = DiscreteFactor(variables=[k], + cardinality=[kth_cardinality], + values=None, + state_names=None) + return msgs def _return_msgs_received_for_msg_type(self, message_type): """ @@ -181,9 +195,9 @@ class Node: msg_values: list of message values (each an np.array) """ if message_type == MessageType.LAMBDA: - msg_values = [msg for msg in self.lambda_received_msgs.values()] + msg_values = [msg.values for msg in self.lambda_received_msgs.values()] elif message_type == MessageType.PI: - msg_values = [msg for msg in self.pi_received_msgs.values()] + msg_values = [msg.values for msg in self.pi_received_msgs.values()] return msg_values def validate_and_return_msgs_received_for_msg_type(self, message_type): @@ -214,13 +228,20 @@ class Node: def compute_lambda_agg(self): if len(self.children) == 0: - return self.lambda_agg + return self.lambda_agg.values else: - lambda_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA) - self.lambda_agg = reduce(np.multiply, lambda_msg_values) - return self.lambda_agg + lambda_msg_values =\ + self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA) + self.update_lambda_agg(reduce(np.multiply, lambda_msg_values)) + return self.lambda_agg.values + + def update_pi_agg(self, new_value): + self.pi_agg.update_values(np.array(new_value).reshape(self.cardinality)) + + def update_lambda_agg(self, new_value): + self.lambda_agg.update_values(np.array(new_value).reshape(self.cardinality)) - def _update_received_msg_by_key(self, received_msg_dict, key, new_value): + def _update_received_msg_by_key(self, received_msg_dict, key, new_value, message_type): if key not in received_msg_dict.keys(): raise ValueError("Label id '{}' to update message isn't in allowed set of keys: {}" .format(key, received_msg_dict.keys())) @@ -229,23 +250,30 @@ class Node: raise TypeError("Expected a new value of type numpy.ndarray, but got type {}" .format(type(new_value))) - if new_value.shape != (self.cardinality,): - raise ValueError("Expected new value to be of dimensions ({},) but got {} instead" - .format(self.cardinality, new_value.shape)) - received_msg_dict[key] = new_value + if message_type == MessageType.LAMBDA: + expected_shape = (self.cardinality,) + elif message_type == MessageType.PI: + expected_shape = (self.cpd.cardinality[self.cpd.variables.index(key)],) + + if new_value.shape != expected_shape: + raise ValueError("Expected new value to be of dimensions ({},) but got {} instead" + .format(expected_shape, new_value.shape)) + received_msg_dict[key]._values = new_value def update_pi_msg_from_parent(self, parent, new_value): self._update_received_msg_by_key(received_msg_dict=self.pi_received_msgs, key=parent, - new_value=new_value) + new_value=new_value, + message_type=MessageType.PI) def update_lambda_msg_from_child(self, child, new_value): self._update_received_msg_by_key(received_msg_dict=self.lambda_received_msgs, key=child, - new_value=new_value) + new_value=new_value, + message_type=MessageType.LAMBDA) def compute_pi_msg_to_child(self, child_k): - lambda_msg_from_child = self.lambda_received_msgs[child_k] + lambda_msg_from_child = self.lambda_received_msgs[child_k].values if lambda_msg_from_child is not None: with np.errstate(divide='ignore', invalid='ignore'): # 0/0 := 0 @@ -272,7 +300,7 @@ class Node: if any(msg is None for msg in pi_msgs): return False - if (self.pi_agg is None) or (self.lambda_agg is None): + if (self.pi_agg.values is None) or (self.lambda_agg.values is None): return False return True @@ -291,7 +319,7 @@ class BernoulliOrNode(Node): def compute_pi_agg(self): if len(self.parents) == 0: - self.pi_agg = self.cpd.values + self.update_pi_agg(self.cpd.values) else: pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) parents_p0 = [p[0] for p in pi_msg_values] @@ -299,19 +327,19 @@ class BernoulliOrNode(Node): # of p = [P(False), P(True)] p_0 = reduce(lambda x, y: x*y, parents_p0) p_1 = 1 - p_0 - self.pi_agg = np.array([p_0, p_1]) + self.update_pi_agg(np.array([p_0, p_1])) return self.pi_agg def compute_lambda_msg_to_parent(self, parent_k): - if np.array_equal(self.lambda_agg, np.ones([self.cardinality])): + if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])): return np.ones([self.cardinality]) else: # TODO: cleanup this validation _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) - p0_excluding_k = [p[0] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] + p0_excluding_k = [p.values[0] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] p0_product = reduce(lambda x, y: x*y, p0_excluding_k, 1) - lambda_0 = self.lambda_agg[1] + (self.lambda_agg[0] - self.lambda_agg[1])*p0_product - lambda_1 = self.lambda_agg[1] + lambda_0 = self.lambda_agg.values[1] + (self.lambda_agg.values[0] - self.lambda_agg.values[1])*p0_product + lambda_1 = self.lambda_agg.values[1] lambda_msg = np.array([lambda_0, lambda_1]) if not any(lambda_msg): raise InvalidLambdaMsgToParent @@ -331,7 +359,7 @@ class BernoulliAndNode(Node): def compute_pi_agg(self): if len(self.parents) == 0: - self.pi_agg = self.cpd.values + self.update_pi_agg(self.cpd.values) else: pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) parents_p1 = [p[1] for p in pi_msg_values] @@ -339,19 +367,19 @@ class BernoulliAndNode(Node): # of p = [P(False), P(True)] p_1 = reduce(lambda x, y: x*y, parents_p1) p_0 = 1 - p_1 - self.pi_agg = np.array([p_0, p_1]) + self.update_pi_agg(np.array([p_0, p_1])) return self.pi_agg def compute_lambda_msg_to_parent(self, parent_k): - if np.array_equal(self.lambda_agg, np.ones([self.cardinality])): + if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])): return np.ones([self.cardinality]) else: # TODO: cleanup this validation _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) - p1_excluding_k = [p[1] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] + p1_excluding_k = [p.values[1] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] p1_product = reduce(lambda x, y: x*y, p1_excluding_k, 1) - lambda_0 = self.lambda_agg[0] - lambda_1 = self.lambda_agg[0] + (self.lambda_agg[1] - self.lambda_agg[0])*p1_product + lambda_0 = self.lambda_agg.values[0] + lambda_1 = self.lambda_agg.values[0] + (self.lambda_agg.values[1] - self.lambda_agg.values[0])*p1_product lambda_msg = np.array([lambda_0, lambda_1]) if not any(lambda_msg): raise InvalidLambdaMsgToParent -- cgit v1.2.3