diff options
Diffstat (limited to 'beliefs/models/belief_update_node_model.py')
-rw-r--r-- | beliefs/models/belief_update_node_model.py | 393 |
1 files changed, 253 insertions, 140 deletions
diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py index 1c3ba6e..ec329ca 100644 --- a/beliefs/models/belief_update_node_model.py +++ b/beliefs/models/belief_update_node_model.py @@ -7,6 +7,7 @@ from functools import reduce import networkx as nx from beliefs.models.base_models import BayesianModel +from beliefs.factors.discrete_factor import DiscreteFactor from beliefs.factors.bernoulli_or_cpd import BernoulliOrCPD from beliefs.factors.bernoulli_and_cpd import BernoulliAndCPD @@ -32,9 +33,9 @@ class BeliefUpdateNodeModel(BayesianModel): """ def __init__(self, nodes_dict): """ - Input: - nodes_dict: dict - a dict key, value pair as {label_id: instance_of_node_class_or_subclass} + Args + nodes_dict: dict + a dict key, value pair as {label_id: instance_of_node_class_or_subclass} """ super().__init__(edges=self._get_edges_from_nodes(nodes_dict.values()), variables=list(nodes_dict.keys()), @@ -44,12 +45,15 @@ class BeliefUpdateNodeModel(BayesianModel): @classmethod def init_from_edges(cls, edges, node_class): - """Create nodes from the same node class. + """ + Create model from edges where all nodes are a from the same node class. - Input: - edges: list of edge tuples of form ('parent', 'child') - node_class: the Node class or subclass from which to - create all the nodes from edges. + Args + edges: list, + list of edge tuples of form [('parent', 'child')] + node_class: Node class or subclass, + class from which to create all the nodes automatically from edges, + e.g. BernoulliAndNode or BernoulliOrNode """ nodes = set() g = nx.DiGraph(edges) @@ -67,10 +71,12 @@ class BeliefUpdateNodeModel(BayesianModel): """ Return list of all directed edges in nodes. - Args: - nodes: an iterable of objects of the Node class or subclass - Returns: - edges: list of edge tuples + Args + nodes: iterable, + iterable of objects of the Node class or subclass + Returns + edges: list, + list of edge tuples """ edges = set() for node in nodes: @@ -81,23 +87,28 @@ class BeliefUpdateNodeModel(BayesianModel): def set_boundary_conditions(self): """ - 1. Root nodes: if x is a node with no parents, set Pi(x) = prior - probability of x. + Set boundary conditions for nodes in the model. + + 1. Root nodes: if x is a node with no parents, set Pi(x) = prior + probability of x. - 2. Leaf nodes: if x is a node with no children, set Lambda(x) - to an (unnormalized) unit vector, of length the cardinality of x. + 2. Leaf nodes: if x is a node with no children, set Lambda(x) + to an (unnormalized) unit vector, of length the cardinality of x. """ for root in self.get_roots(): - self.nodes_dict[root].pi_agg = self.nodes_dict[root].cpd.values + self.nodes_dict[root].update_pi_agg(self.nodes_dict[root].cpd.values) for leaf in self.get_leaves(): - self.nodes_dict[leaf].lambda_agg = np.ones([self.nodes_dict[leaf].cardinality]) + self.nodes_dict[leaf].update_lambda_agg(np.ones([self.nodes_dict[leaf].cardinality])) @property def all_nodes_are_fully_initialized(self): """ - Returns True if, for all nodes in the model, all lambda and pi - messages and lambda_agg and pi_agg are not None, else False. + Check if all nodes in the model are initialized, i.e. lambda and pi messages and + lambda_agg and pi_agg are not None for every node. + + Returns + bool, True if all nodes in the model are initialized, else False. """ for node in self.nodes_dict.values(): if not node.is_fully_initialized: @@ -105,62 +116,53 @@ class BeliefUpdateNodeModel(BayesianModel): return True def copy(self): - """ - Returns a copy of the model. - """ + """Return a copy of the model.""" copy_nodes = copy.deepcopy(self.nodes_dict) copy_model = self.__class__(nodes_dict=copy_nodes) return copy_model class Node: - """A node in a DAG with methods to compute the belief (marginal probability - of the node given evidence) and compute pi/lambda messages to/from its neighbors + """ + A node in a DAG with methods to compute the belief (marginal probability of + the node given evidence) and compute pi/lambda messages to/from its neighbors to update its belief. - Implemented from Pearl's belief propagation algorithm. + Implemented from Pearl's belief propagation algorithm for polytrees. """ - def __init__(self, - label_id, - children, - parents, - cardinality, - cpd): + def __init__(self, children, cpd): """ Args - label_id: str - children: set of strings - parents: set of strings - cardinality: int, cardinality of the random variable the node represents - cpd: an instance of a conditional probability distribution, - e.g. BernoulliOrCPD or TabularCPD - """ - self.label_id = label_id + children: list, + list of strings + cpd: an instance of TabularCPD or one of its subclasses, + e.g. BernoulliOrCPD or BernoulliAndCPD + """ + self.label_id = cpd.variable self.children = children - self.parents = parents - self.cardinality = cardinality + self.parents = cpd.parents + self.cardinality = cpd.cardinality[0] self.cpd = cpd - self.pi_agg = None # np.array dimensions [1, cardinality] - self.lambda_agg = None # np.array dimensions [1, cardinality] + self.pi_agg = self._init_factors_for_variables([self.label_id])[self.label_id] + self.lambda_agg = self._init_factors_for_variables([self.label_id])[self.label_id] - self.pi_received_msgs = self._init_received_msgs(parents) - self.lambda_received_msgs = self._init_received_msgs(children) + self.pi_received_msgs = self._init_factors_for_variables(self.parents) + self.lambda_received_msgs = \ + {child: self._init_factors_for_variables([self.label_id])[self.label_id] + for child in children} - @classmethod - def from_cpd_class(cls, - label_id, - children, - parents, - cardinality, - cpd_class): - cpd = cpd_class(label_id, parents) - return cls(label_id, children, parents, cardinality, cpd) @property def belief(self): - if self.pi_agg.any() and self.lambda_agg.any(): - belief = np.multiply(self.pi_agg, self.lambda_agg) + """ + Calculate the marginal probability of the variable from its aggregate values. + + Returns + belief, an np.array of ndim 1 and shape (self.cardinality,) + """ + if any(self.pi_agg.values) and any(self.lambda_agg.values): + belief = (self.lambda_agg * self.pi_agg).values return self._normalize(belief) else: return None @@ -168,23 +170,48 @@ class Node: def _normalize(self, value): return value/value.sum() - @staticmethod - def _init_received_msgs(keys): - return {k: None for k in keys} + def _init_factors_for_variables(self, variables): + """ + Args + variables: list, + list of ints/strings, e.g. the single node variable or list + of parent ids of the node + Returns + factors: dict, + where the dict has key, value pair as {variable_id: instance of a DiscreteFactor}, + where DiscreteFactor.values is an np.array of ndim 1 and + shape (cardinality of variable_id,) + """ + variables = list(variables) + factors = {} + + for var in variables: + if self.cpd.state_names is not None: + state_names = {var: self.cpd.state_names[var]} + else: + state_names = None + + cardinality = self.cpd.cardinality[self.cpd.variables.index(var)] + factors[var] = DiscreteFactor(variables=[var], + cardinality=[cardinality], + values=None, + state_names=state_names) + return factors def _return_msgs_received_for_msg_type(self, message_type): """ - Input: - message_type: MessageType enum - - Returns: - msg_values: list of message values (each an np.array) + Args + message_type: MessageType enum + Returns + msg_values: list, + list of DiscreteFactors with property `values` containing + the values of the messages (np.arrays) """ if message_type == MessageType.LAMBDA: - msg_values = [msg for msg in self.lambda_received_msgs.values()] + msgs = [msg for msg in self.lambda_received_msgs.values()] elif message_type == MessageType.PI: - msg_values = [msg for msg in self.pi_received_msgs.values()] - return msg_values + msgs = [msg for msg in self.pi_received_msgs.values()] + return msgs def validate_and_return_msgs_received_for_msg_type(self, message_type): """ @@ -192,35 +219,58 @@ class Node: Raise error if all messages have not been received. Called before calculating lambda_agg (pi_agg). - Input: - message_type: MessageType enum - - Returns: - msg_values: list of message values (each an np.array) + Args + message_type: MessageType enum + Returns + msgs: list, + list of DiscreteFactors with property `values` containing + the values of the messages (np.arrays) """ - msg_values = self._return_msgs_received_for_msg_type(message_type) + msgs = self._return_msgs_received_for_msg_type(message_type) - if any(msg is None for msg in msg_values): + if any(msg.values is None for msg in msgs): raise ValueError( "Missing value for {msg_type} msg from child: can't compute {msg_type}_agg." .format(msg_type=message_type.value) ) else: - return msg_values - - def compute_pi_agg(self): - # TODO: implement explict factor product operation - raise NotImplementedError + return msgs - def compute_lambda_agg(self): - if len(self.children) == 0: - return self.lambda_agg + def compute_and_update_pi_agg(self): + """ + Compute and update pi_agg, the prior probability, given the current state + of messages received from parents. + """ + if len(self.parents) == 0: + self.update_pi_agg(self.cpd.values) else: - lambda_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA) - self.lambda_agg = reduce(np.multiply, lambda_msg_values) - return self.lambda_agg + factors_to_multiply = [self.cpd] + pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) + factors_to_multiply.extend(pi_msgs) + + factor_product = reduce(lambda phi1, phi2: phi1*phi2, factors_to_multiply) + self.update_pi_agg(factor_product.marginalize(self.parents).values) + pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) + + def compute_and_update_lambda_agg(self): + """ + Compute and update lambda_agg, the likelihood, given the current state + of messages received from children. + """ + if len(self.children) != 0: + lambda_msg_values = [ + msg.values for msg in + self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA) + ] + self.update_lambda_agg(reduce(np.multiply, lambda_msg_values)) + + def update_pi_agg(self, new_value): + self.pi_agg.update_values(new_value) - def _update_received_msg_by_key(self, received_msg_dict, key, new_value): + def update_lambda_agg(self, new_value): + self.lambda_agg.update_values(new_value) + + def _update_received_msg_by_key(self, received_msg_dict, key, new_value, message_type): if key not in received_msg_dict.keys(): raise ValueError("Label id '{}' to update message isn't in allowed set of keys: {}" .format(key, received_msg_dict.keys())) @@ -229,23 +279,39 @@ class Node: raise TypeError("Expected a new value of type numpy.ndarray, but got type {}" .format(type(new_value))) - if new_value.shape != (self.cardinality,): + if message_type == MessageType.LAMBDA: + expected_shape = (self.cardinality,) + elif message_type == MessageType.PI: + expected_shape = (self.cpd.cardinality[self.cpd.variables.index(key)],) + + if new_value.shape != expected_shape: raise ValueError("Expected new value to be of dimensions ({},) but got {} instead" - .format(self.cardinality, new_value.shape)) - received_msg_dict[key] = new_value + .format(expected_shape, new_value.shape)) + received_msg_dict[key].update_values(new_value) def update_pi_msg_from_parent(self, parent, new_value): self._update_received_msg_by_key(received_msg_dict=self.pi_received_msgs, key=parent, - new_value=new_value) + new_value=new_value, + message_type=MessageType.PI) def update_lambda_msg_from_child(self, child, new_value): self._update_received_msg_by_key(received_msg_dict=self.lambda_received_msgs, key=child, - new_value=new_value) + new_value=new_value, + message_type=MessageType.LAMBDA) def compute_pi_msg_to_child(self, child_k): - lambda_msg_from_child = self.lambda_received_msgs[child_k] + """ + Compute pi_msg to child. + + Args + child_k: string or int, + the label_id of the child receiving the pi_msg + Returns + np.array of ndim 1 and shape (self.cardinality,) + """ + lambda_msg_from_child = self.lambda_received_msgs[child_k].values if lambda_msg_from_child is not None: with np.errstate(divide='ignore', invalid='ignore'): # 0/0 := 0 @@ -255,8 +321,26 @@ class Node: raise ValueError("Can't compute pi message to child_{} without having received a lambda message from that child.") def compute_lambda_msg_to_parent(self, parent_k): - # TODO: implement explict factor product operation - raise NotImplementedError + """ + Compute lambda_msg to parent. + + Args + parent_k: string or int, + the label_id of the parent receiving the lambda_msg + Returns + np.array of ndim 1 and shape (cardinality of parent_k,) + """ + if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])): + return np.ones([self.cardinality]) + else: + factors_to_multiply = [self.cpd] + pi_msgs_excl_k = [msg for par_id, msg in self.pi_received_msgs.items() + if par_id != parent_k] + factors_to_multiply.extend(pi_msgs_excl_k) + factor_product = reduce(lambda phi1, phi2: phi1*phi2, factors_to_multiply) + new_factor = factor_product.marginalize(list(set(self.parents) - set([parent_k]))) + lambda_msg_to_k = (self.lambda_agg * new_factor).marginalize([self.lambda_agg.variables[0]]) + return self._normalize(lambda_msg_to_k.values) @property def is_fully_initialized(self): @@ -272,46 +356,60 @@ class Node: if any(msg is None for msg in pi_msgs): return False - if (self.pi_agg is None) or (self.lambda_agg is None): + if (self.pi_agg.values is None) or (self.lambda_agg.values is None): return False return True class BernoulliOrNode(Node): - def __init__(self, - label_id, - children, - parents): - super().__init__(label_id=label_id, - children=children, - parents=parents, - cardinality=2, - cpd=BernoulliOrCPD(label_id, parents)) - - def compute_pi_agg(self): + """ + A node in a DAG associated with a Bernoulli random variable with state_names ['False', 'True'] + and conditional probability distribution described by 'Or' logic. + """ + def __init__(self, label_id, children, parents): + super().__init__(children=children, cpd=BernoulliOrCPD(label_id, parents)) + + def compute_and_update_pi_agg(self): + """ + Compute and update pi_agg, the prior probability, given the current state + of messages received from parents. Sidestep explicit factor product and + marginalization. + """ if len(self.parents) == 0: - self.pi_agg = self.cpd.values + self.update_pi_agg(self.cpd.values) else: - pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) - parents_p0 = [p[0] for p in pi_msg_values] - # Parents are Bernoulli variables with pi_msg_values (surrogate prior probabilities) - # of p = [P(False), P(True)] + pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) + parents_p0 = [p.get_value_for_state_vector({p.variables[0]: 'False'}) + for p in pi_msgs] p_0 = reduce(lambda x, y: x*y, parents_p0) p_1 = 1 - p_0 - self.pi_agg = np.array([p_0, p_1]) - return self.pi_agg + self.update_pi_agg(np.array([p_0, p_1])) def compute_lambda_msg_to_parent(self, parent_k): - if np.array_equal(self.lambda_agg, np.ones([self.cardinality])): + """ + Compute lambda_msg to parent. Sidestep explicit factor product and + marginalization. + + Args + parent_k: string or int, + the label_id of the parent receiving the lambda_msg + Returns + np.array of ndim 1 and shape (cardinality of parent_k,) + """ + if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])): return np.ones([self.cardinality]) else: # TODO: cleanup this validation _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) - p0_excluding_k = [p[0] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] + p0_excluding_k = [p.get_value_for_state_vector({p.variables[0]: 'False'}) + for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] p0_product = reduce(lambda x, y: x*y, p0_excluding_k, 1) - lambda_0 = self.lambda_agg[1] + (self.lambda_agg[0] - self.lambda_agg[1])*p0_product - lambda_1 = self.lambda_agg[1] + + lambda_agg_0 = self.lambda_agg.get_value_for_state_vector({self.label_id: 'False'}) + lambda_agg_1 = self.lambda_agg.get_value_for_state_vector({self.label_id: 'True'}) + lambda_0 = lambda_agg_1 + (lambda_agg_0 - lambda_agg_1)*p0_product + lambda_1 = lambda_agg_1 lambda_msg = np.array([lambda_0, lambda_1]) if not any(lambda_msg): raise InvalidLambdaMsgToParent @@ -319,39 +417,54 @@ class BernoulliOrNode(Node): class BernoulliAndNode(Node): - def __init__(self, - label_id, - children, - parents): - super().__init__(label_id=label_id, - children=children, - parents=parents, - cardinality=2, - cpd=BernoulliAndCPD(label_id, parents)) - - def compute_pi_agg(self): + """ + A node in a DAG associated with a Bernoulli random variable with state_names ['False', 'True'] + and conditional probability distribution described by 'And' logic. + """ + def __init__(self, label_id, children, parents): + super().__init__(children=children, cpd=BernoulliAndCPD(label_id, parents)) + + def compute_and_update_pi_agg(self): + """ + Compute and update pi_agg, the prior probability, given the current state + of messages received from parents. Sidestep explicit factor product and + marginalization. + """ if len(self.parents) == 0: - self.pi_agg = self.cpd.values + self.update_pi_agg(self.cpd.values) else: - pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) - parents_p1 = [p[1] for p in pi_msg_values] - # Parents are Bernoulli variables with pi_msg_values (surrogate prior probabilities) - # of p = [P(False), P(True)] + pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) + parents_p1 = [p.get_value_for_state_vector({p.variables[0]: 'True'}) + for p in pi_msgs] p_1 = reduce(lambda x, y: x*y, parents_p1) p_0 = 1 - p_1 - self.pi_agg = np.array([p_0, p_1]) - return self.pi_agg + self.update_pi_agg(np.array([p_0, p_1])) def compute_lambda_msg_to_parent(self, parent_k): - if np.array_equal(self.lambda_agg, np.ones([self.cardinality])): + """ + Compute lambda_msg to parent. Sidestep explicit factor product and + marginalization. + + Args + parent_k: string or int, + the label_id of the parent receiving the lambda_msg + Returns + np.array of ndim 1 and shape (cardinality of parent_k,) + """ + if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])): return np.ones([self.cardinality]) else: # TODO: cleanup this validation _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) - p1_excluding_k = [p[1] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] + p1_excluding_k = [p.get_value_for_state_vector({p.variables[0]: 'True'}) + for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] p1_product = reduce(lambda x, y: x*y, p1_excluding_k, 1) - lambda_0 = self.lambda_agg[0] - lambda_1 = self.lambda_agg[0] + (self.lambda_agg[1] - self.lambda_agg[0])*p1_product + + lambda_agg_0 = self.lambda_agg.get_value_for_state_vector({self.label_id: 'False'}) + lambda_agg_1 = self.lambda_agg.get_value_for_state_vector({self.label_id: 'True'}) + + lambda_0 = lambda_agg_0 + lambda_1 = lambda_agg_0 + (lambda_agg_1 - lambda_agg_0)*p1_product lambda_msg = np.array([lambda_0, lambda_1]) if not any(lambda_msg): raise InvalidLambdaMsgToParent |