diff options
Diffstat (limited to 'beliefs/models/belief_update_node_model.py')
-rw-r--r-- | beliefs/models/belief_update_node_model.py | 238 |
1 files changed, 160 insertions, 78 deletions
diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py index 17e98fa..1a9ab19 100644 --- a/beliefs/models/belief_update_node_model.py +++ b/beliefs/models/belief_update_node_model.py @@ -33,9 +33,9 @@ class BeliefUpdateNodeModel(BayesianModel): """ def __init__(self, nodes_dict): """ - Input: - nodes_dict: dict - a dict key, value pair as {label_id: instance_of_node_class_or_subclass} + Args + nodes_dict: dict + a dict key, value pair as {label_id: instance_of_node_class_or_subclass} """ super().__init__(edges=self._get_edges_from_nodes(nodes_dict.values()), variables=list(nodes_dict.keys()), @@ -45,12 +45,15 @@ class BeliefUpdateNodeModel(BayesianModel): @classmethod def init_from_edges(cls, edges, node_class): - """Create nodes from the same node class. + """ + Create model from edges where all nodes are a from the same node class. - Input: - edges: list of edge tuples of form ('parent', 'child') - node_class: the Node class or subclass from which to - create all the nodes from edges. + Args + edges: list, + list of edge tuples of form [('parent', 'child')] + node_class: Node class or subclass, + class from which to create all the nodes automatically from edges, + e.g. BernoulliAndNode or BernoulliOrNode """ nodes = set() g = nx.DiGraph(edges) @@ -68,10 +71,12 @@ class BeliefUpdateNodeModel(BayesianModel): """ Return list of all directed edges in nodes. - Args: - nodes: an iterable of objects of the Node class or subclass - Returns: - edges: list of edge tuples + Args + nodes: iterable, + iterable of objects of the Node class or subclass + Returns + edges: list, + list of edge tuples """ edges = set() for node in nodes: @@ -82,11 +87,13 @@ class BeliefUpdateNodeModel(BayesianModel): def set_boundary_conditions(self): """ - 1. Root nodes: if x is a node with no parents, set Pi(x) = prior - probability of x. + Set boundary conditions for nodes in the model. + + 1. Root nodes: if x is a node with no parents, set Pi(x) = prior + probability of x. - 2. Leaf nodes: if x is a node with no children, set Lambda(x) - to an (unnormalized) unit vector, of length the cardinality of x. + 2. Leaf nodes: if x is a node with no children, set Lambda(x) + to an (unnormalized) unit vector, of length the cardinality of x. """ for root in self.get_roots(): self.nodes_dict[root].update_pi_agg(self.nodes_dict[root].cpd.values) @@ -97,8 +104,11 @@ class BeliefUpdateNodeModel(BayesianModel): @property def all_nodes_are_fully_initialized(self): """ - Returns True if, for all nodes in the model, all lambda and pi - messages and lambda_agg and pi_agg are not None, else False. + Check if all nodes in the model are initialized, i.e. lambda and pi messages and + lambda_agg and pi_agg are not None for every node. + + Returns + bool, True if all nodes in the model are initialized, else False. """ for node in self.nodes_dict.values(): if not node.is_fully_initialized: @@ -106,27 +116,27 @@ class BeliefUpdateNodeModel(BayesianModel): return True def copy(self): - """ - Returns a copy of the model. - """ + """Return a copy of the model.""" copy_nodes = copy.deepcopy(self.nodes_dict) copy_model = self.__class__(nodes_dict=copy_nodes) return copy_model class Node: - """A node in a DAG with methods to compute the belief (marginal probability - of the node given evidence) and compute pi/lambda messages to/from its neighbors + """ + A node in a DAG with methods to compute the belief (marginal probability of + the node given evidence) and compute pi/lambda messages to/from its neighbors to update its belief. - Implemented from Pearl's belief propagation algorithm. + Implemented from Pearl's belief propagation algorithm for polytrees. """ def __init__(self, children, cpd): """ Args - children: list of strings - cpd: an instance of a conditional probability distribution, - e.g. BernoulliOrCPD or TabularCPD + children: list, + list of strings + cpd: an instance of TabularCPD or one of its subclasses, + e.g. BernoulliOrCPD or BernoulliAndCPD """ self.label_id = cpd.variable self.children = children @@ -134,15 +144,20 @@ class Node: self.cardinality = cpd.cardinality[0] self.cpd = cpd - # instances of DiscreteFactor with `values` an np.array of dimensions [1, cardinality] - self.pi_agg = self._init_aggregate_values() - self.lambda_agg = self._init_aggregate_values() + self.pi_agg = self._init_factor_for_variable() + self.lambda_agg = self._init_factor_for_variable() self.pi_received_msgs = self._init_pi_received_msgs(self.parents) - self.lambda_received_msgs = {child: self._init_aggregate_values() for child in children} + self.lambda_received_msgs = {child: self._init_factor_for_variable() for child in children} @property def belief(self): + """ + Calculate the marginal probability of the variable from its aggregate values. + + Returns + belief, an np.array of ndim 1 and shape (self.cardinality,) + """ if any(self.pi_agg.values) and any(self.lambda_agg.values): belief = (self.lambda_agg * self.pi_agg).values return self._normalize(belief) @@ -152,29 +167,50 @@ class Node: def _normalize(self, value): return value/value.sum() - def _init_aggregate_values(self): + def _init_factor_for_variable(self): + """ + Returns + instance of a DiscreteFactor, where DiscreteFactor.values is an np.array of + ndim 1 and shape (self.cardinality,) + """ return DiscreteFactor(variables=[self.cpd.variable], cardinality=[self.cardinality], values=None, state_names=None) def _init_pi_received_msgs(self, parents): + """ + Args + parents: list, + list of strings, parent ids of the node + Returns + msgs: dict, + a dict with key, value pair as {parent_id: instance of a DiscreteFactor}, + where DiscreteFactor.values is an np.array of ndim 1 and + shape (cardinality of parent_id,) + """ msgs = {} for k in parents: + if self.cpd.state_names is not None: + state_names = {k: self.cpd.state_names[k]} + else: + state_names = None + kth_cardinality = self.cpd.cardinality[self.cpd.variables.index(k)] msgs[k] = DiscreteFactor(variables=[k], cardinality=[kth_cardinality], values=None, - state_names=None) + state_names=state_names) return msgs def _return_msgs_received_for_msg_type(self, message_type): """ - Input: - message_type: MessageType enum - - Returns: - msg_values: list of DiscreteFactors containing message values (np.arrays) + Args + message_type: MessageType enum + Returns + msg_values: list, + list of DiscreteFactors with property `values` containing + the values of the messages (np.arrays) """ if message_type == MessageType.LAMBDA: msgs = [msg for msg in self.lambda_received_msgs.values()] @@ -188,11 +224,12 @@ class Node: Raise error if all messages have not been received. Called before calculating lambda_agg (pi_agg). - Input: - message_type: MessageType enum - - Returns: - msgs: list of DiscreteFactors containing message values (np.array) + Args + message_type: MessageType enum + Returns + msgs: list, + list of DiscreteFactors with property `values` containing + the values of the messages (np.arrays) """ msgs = self._return_msgs_received_for_msg_type(message_type) @@ -205,6 +242,10 @@ class Node: return msgs def compute_pi_agg(self): + """ + Compute and update pi_agg, the prior probability, given the current state + of messages received from parents. + """ if len(self.parents) == 0: self.update_pi_agg(self.cpd.values) else: @@ -217,6 +258,10 @@ class Node: pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) def compute_lambda_agg(self): + """ + Compute and update lambda_agg, the likelihood, given the current state + of messages received from children. + """ if len(self.children) != 0: lambda_msg_values = [ msg.values for msg in @@ -245,9 +290,8 @@ class Node: expected_shape = (self.cpd.cardinality[self.cpd.variables.index(key)],) if new_value.shape != expected_shape: - raise ValueError("Expected new value to be of dimensions ({},) but got {} instead" - .format(expected_shape, new_value.shape)) - # received_msg_dict[key]._values = new_value + raise ValueError("Expected new value to be of dimensions ({},) but got {} instead" + .format(expected_shape, new_value.shape)) received_msg_dict[key].update_values(new_value) def update_pi_msg_from_parent(self, parent, new_value): @@ -263,6 +307,15 @@ class Node: message_type=MessageType.LAMBDA) def compute_pi_msg_to_child(self, child_k): + """ + Compute pi_msg to child. + + Args + child_k: string or int, + the label_id of the child receiving the pi_msg + Returns + np.array of ndim 1 and shape (self.cardinality,) + """ lambda_msg_from_child = self.lambda_received_msgs[child_k].values if lambda_msg_from_child is not None: with np.errstate(divide='ignore', invalid='ignore'): @@ -273,6 +326,15 @@ class Node: raise ValueError("Can't compute pi message to child_{} without having received a lambda message from that child.") def compute_lambda_msg_to_parent(self, parent_k): + """ + Compute lambda_msg to parent. + + Args + parent_k: string or int, + the label_id of the parent receiving the lambda_msg + Returns + np.array of ndim 1 and shape (cardinality of parent_k,) + """ if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])): return np.ones([self.cardinality]) else: @@ -306,30 +368,31 @@ class Node: class BernoulliOrNode(Node): - def __init__(self, - label_id, - children, - parents): + """ + A node in a DAG associated with a Bernoulli random variable with state_names ['False', 'True'] + and conditional probability distribution described by 'Or' logic. + """ + def __init__(self, label_id, children, parents): super().__init__(children=children, cpd=BernoulliOrCPD(label_id, parents)) - def _init_aggregate_values(self): + def _init_factor_for_variable(self): + """ + Returns + instance of a DiscreteFactor, where DiscreteFactor.values is an np.array of + ndim 1 and shape (self.cardinality,) + """ variable = self.cpd.variable return DiscreteFactor(variables=[self.cpd.variable], cardinality=[self.cardinality], values=None, state_names={variable: self.cpd.state_names[variable]}) - def _init_pi_received_msgs(self, parents): - msgs = {} - for k in parents: - kth_cardinality = self.cpd.cardinality[self.cpd.variables.index(k)] - msgs[k] = DiscreteFactor(variables=[k], - cardinality=[kth_cardinality], - values=None, - state_names={k: self.cpd.state_names[k]}) - return msgs - def compute_pi_agg(self): + """ + Compute and update pi_agg, the prior probability, given the current state + of messages received from parents. Sidestep explicit factor product and + marginalization. + """ if len(self.parents) == 0: self.update_pi_agg(self.cpd.values) else: @@ -339,9 +402,18 @@ class BernoulliOrNode(Node): p_0 = reduce(lambda x, y: x*y, parents_p0) p_1 = 1 - p_0 self.update_pi_agg(np.array([p_0, p_1])) - return self.pi_agg def compute_lambda_msg_to_parent(self, parent_k): + """ + Compute lambda_msg to parent. Sidestep explicit factor product and + marginalization. + + Args + parent_k: string or int, + the label_id of the parent receiving the lambda_msg + Returns + np.array of ndim 1 and shape (cardinality of parent_k,) + """ if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])): return np.ones([self.cardinality]) else: @@ -362,30 +434,31 @@ class BernoulliOrNode(Node): class BernoulliAndNode(Node): - def __init__(self, - label_id, - children, - parents): + """ + A node in a DAG associated with a Bernoulli random variable with state_names ['False', 'True'] + and conditional probability distribution described by 'And' logic. + """ + def __init__(self, label_id, children, parents): super().__init__(children=children, cpd=BernoulliAndCPD(label_id, parents)) - def _init_aggregate_values(self): + def _init_factor_for_variable(self): + """ + Returns + instance of a DiscreteFactor, where DiscreteFactor.values is an np.array of + ndim 1 and shape (self.cardinality,) + """ variable = self.cpd.variable return DiscreteFactor(variables=[self.cpd.variable], cardinality=[self.cardinality], values=None, state_names={variable: self.cpd.state_names[variable]}) - def _init_pi_received_msgs(self, parents): - msgs = {} - for k in parents: - kth_cardinality = self.cpd.cardinality[self.cpd.variables.index(k)] - msgs[k] = DiscreteFactor(variables=[k], - cardinality=[kth_cardinality], - values=None, - state_names={k: self.cpd.state_names[k]}) - return msgs - def compute_pi_agg(self): + """ + Compute and update pi_agg, the prior probability, given the current state + of messages received from parents. Sidestep explicit factor product and + marginalization. + """ if len(self.parents) == 0: self.update_pi_agg(self.cpd.values) else: @@ -395,9 +468,18 @@ class BernoulliAndNode(Node): p_1 = reduce(lambda x, y: x*y, parents_p1) p_0 = 1 - p_1 self.update_pi_agg(np.array([p_0, p_1])) - return self.pi_agg def compute_lambda_msg_to_parent(self, parent_k): + """ + Compute lambda_msg to parent. Sidestep explicit factor product and + marginalization. + + Args + parent_k: string or int, + the label_id of the parent receiving the lambda_msg + Returns + np.array of ndim 1 and shape (cardinality of parent_k,) + """ if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])): return np.ones([self.cardinality]) else: |