From 1826222ed133b33f05fe0290fade25c2bde20729 Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Mon, 11 Dec 2017 19:13:32 -0800 Subject: discrete factor with minimal methods --- beliefs/factors/discrete_factor.py | 121 +++++++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 beliefs/factors/discrete_factor.py (limited to 'beliefs') diff --git a/beliefs/factors/discrete_factor.py b/beliefs/factors/discrete_factor.py new file mode 100644 index 0000000..da8e6bf --- /dev/null +++ b/beliefs/factors/discrete_factor.py @@ -0,0 +1,121 @@ +import copy +import numpy as np + + +class DiscreteFactor: + + def __init__(self, variables, cardinality, values=None, state_names=None): + """ + Args + variables: list, + variables in the scope of the factor + cardinality: list, + cardinalities of each variable, where len(cardinality)=len(variables) + values: list, + row vector of values of variables with ordering such that right-most variables + defined in `variables` cycle through their values the fastest + state_names: dictionary, + mapping variables to their states, of format {label_name: ['state1', 'state2']} + """ + self.variables = list(variables) + self.cardinality = cardinality + if values is None: + self._values = None + else: + self._values = np.array(values).reshape(self.cardinality) + self.state_names = state_names + + def __mul__(self, other): + return self.product(other) + + @property + def values(self): + return self._values + + def update_values(self, new_values): + """We make this available because _values is allowed to be None on init""" + self._values = np.array(new_values).reshape(self.cardinality) + + def get_value_for_state_vector(self, dict_of_states): + """ + Return the value for a dictionary of variable states. + + Args + dict_of_states: dictionary, + of format {label_name1: 'state1', label_name2: 'True'} + Returns + probability, a float, the factor value for a specific combination of variable states + """ + assert sorted(dict_of_states.keys()) == sorted(self.variables), \ + "The keys for the dictionary of states must match the variables in factor scope." + state_coordinates = [] + for var in self.variables: + var_state = dict_of_states[var] + idx_in_var_axis = self.state_names[var].index(var_state) + state_coordinates.append(idx_in_var_axis) + return self.values[tuple(state_coordinates)] + + def add_new_variables_from_other_factor(self, other): + """Add new variables to the factor.""" + extra_vars = set(other.variables) - set(self.variables) + # if all of these variables already exist there is nothing to do + if len(extra_vars) == 0: + return + # otherwise, extend the values array + slice_ = [slice(None)] * len(self.variables) + slice_.extend([np.newaxis] * len(extra_vars)) + self._values = self._values[slice_] + self.variables.extend(extra_vars) + + new_card_var = other.get_cardinality(extra_vars) + self.cardinality.extend([new_card_var[var] for var in extra_vars]) + return + + def get_cardinality(self, variables): + return {var: self.cardinality[self.variables.index(var)] for var in variables} + + def product(self, other): + left = copy.deepcopy(self) + + if isinstance(other, (int, float)): + # TODO: handle case of multiplication by constant + pass + else: + # assert right is a class or subclass of DiscreteFactor + # that has attributes: variables, values; method: get_cardinality + right = copy.deepcopy(other) + left.add_new_variables_from_other_factor(right) + right.add_new_variables_from_other_factor(left) + + # reorder variables in right factor to match order in left + source_axes = list(range(right.values.ndim)) + destination_axes = [right.variables.index(var) for var in left.variables] + right.variables = [right.variables[idx] for idx in destination_axes] + + # rearrange values in right factor to correspond to the reordered variables + right._values = np.moveaxis(right.values, source_axes, destination_axes) + left._values = left.values * right.values + return left + + def marginalize(self, vars): + """ + Args + vars: list, + variables over which to marginalize the factor + Returns + DiscreteFactor + """ + phi = copy.deepcopy(self) + + var_indexes = [] + for var in vars: + if var not in phi.variables: + raise ValueError('{} not in scope'.format(var)) + else: + var_indexes.append(self.variables.index(var)) + + index_to_keep = sorted(set(range(len(self.variables))) - set(var_indexes)) + phi.variables = [self.variables[index] for index in index_to_keep] + phi.cardinality = [self.cardinality[index] for index in index_to_keep] + phi._values = np.sum(phi.values, axis=tuple(var_indexes)) + return phi -- cgit v1.2.3 From f6ab3e7b918396dee70dc4ff2dc3a1341aaeb97b Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Tue, 12 Dec 2017 14:12:51 -0800 Subject: TabularCPD inherits from DiscreteFactor --- beliefs/factors/cpd.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) (limited to 'beliefs') diff --git a/beliefs/factors/cpd.py b/beliefs/factors/cpd.py index a286aaa..9e7191f 100644 --- a/beliefs/factors/cpd.py +++ b/beliefs/factors/cpd.py @@ -1,15 +1,14 @@ import numpy as np +from beliefs.factors.discrete_factor import DiscreteFactor -class TabularCPD: +class TabularCPD(DiscreteFactor): """ Defines the conditional probability table for a discrete variable whose parents are also discrete. - - TODO: have this inherit from DiscreteFactor implementing explicit factor methods """ def __init__(self, variable, variable_card, - parents=[], parents_card=[], values=[]): + parents=[], parents_card=[], values=[], state_names=None): """ Args: variable: int or string @@ -17,16 +16,15 @@ class TabularCPD: parents: optional, list of int and/or strings parents_card: optional, list of int values: optional, 2d list or array + state_names: dictionary (optional), + mapping variables to their states, of format {label_name: ['state1', 'state2']} """ + super().__init__(variables=[variable] + parents, + cardinality=[variable_card] + parents_card, + values=values, + state_names=state_names) self.variable = variable self.parents = parents - self.variables = [variable] + parents - self.cardinality = [variable_card] + parents_card - self._values = np.array(values) - - @property - def values(self): - return self._values def get_values(self): """ -- cgit v1.2.3 From 70bdf07d25f41de1a9510b64267bfa29791760c7 Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Tue, 12 Dec 2017 16:11:54 -0800 Subject: change all msg datatypes from np.array -> DiscreteFactor --- beliefs/inference/belief_propagation.py | 20 +++--- beliefs/models/belief_update_node_model.py | 110 ++++++++++++++++++----------- 2 files changed, 77 insertions(+), 53 deletions(-) (limited to 'beliefs') diff --git a/beliefs/inference/belief_propagation.py b/beliefs/inference/belief_propagation.py index 7ec648d..128f645 100644 --- a/beliefs/inference/belief_propagation.py +++ b/beliefs/inference/belief_propagation.py @@ -74,9 +74,9 @@ class BeliefPropagation: if node_to_update_label_id not in evidence: node.compute_pi_agg() - logging.info("belief propagation pi_agg: %s", np.array2string(node.pi_agg)) + logging.info("belief propagation pi_agg: %s", np.array2string(node.pi_agg.values)) node.compute_lambda_agg() - logging.info("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg)) + logging.info("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg.values)) for parent_id in parent_ids: try: @@ -97,7 +97,6 @@ class BeliefPropagation: new_value=new_pi_msg) nodes_to_update.add(MsgPassers(msg_receiver=child_id, msg_sender=node_to_update_label_id)) - self._belief_propagation(nodes_to_update, evidence) def initialize_model(self): @@ -115,8 +114,8 @@ class BeliefPropagation: for node in self.model.nodes_dict.values(): ones_vector = np.ones([node.cardinality]) + node.update_lambda_agg(ones_vector) - node.lambda_agg = ones_vector for child in node.lambda_received_msgs.keys(): node.update_lambda_msg_from_child(child=child, new_value=ones_vector) @@ -131,7 +130,7 @@ class BeliefPropagation: node_sending_msg = self.model.nodes_dict[node_id] child_ids = node_sending_msg.children - if node_sending_msg.pi_agg is None: + if node_sending_msg.pi_agg.values is None: node_sending_msg.compute_pi_agg() for child_id in child_ids: @@ -150,22 +149,19 @@ class BeliefPropagation: a dict key, value pair as {var: state_of_var observed} """ for evidence_id, observed_value in evidence.items(): - nodes_to_update = set() - if evidence_id not in self.model.nodes_dict.keys(): raise KeyError("Evidence supplied for non-existent label_id: {}" .format(evidence_id)) if is_kronecker_delta(observed_value): # specific evidence - self.model.nodes_dict[evidence_id].lambda_agg = observed_value + self.model.nodes_dict[evidence_id].update_lambda_agg(observed_value) else: # virtual evidence - self.model.nodes_dict[evidence_id].lambda_agg = \ - self.model.nodes_dict[evidence_id].lambda_agg * observed_value - + self.model.nodes_dict[evidence_id].update_lambda_agg( + self.model.nodes_dict[evidence_id].lambda_agg.values * observed_value + ) nodes_to_update = [MsgPassers(msg_receiver=evidence_id, msg_sender=None)] - self._belief_propagation(nodes_to_update=set(nodes_to_update), evidence=evidence) diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py index 1c3ba6e..1765ed9 100644 --- a/beliefs/models/belief_update_node_model.py +++ b/beliefs/models/belief_update_node_model.py @@ -7,6 +7,7 @@ from functools import reduce import networkx as nx from beliefs.models.base_models import BayesianModel +from beliefs.factors.discrete_factor import DiscreteFactor from beliefs.factors.bernoulli_or_cpd import BernoulliOrCPD from beliefs.factors.bernoulli_and_cpd import BernoulliAndCPD @@ -88,10 +89,10 @@ class BeliefUpdateNodeModel(BayesianModel): to an (unnormalized) unit vector, of length the cardinality of x. """ for root in self.get_roots(): - self.nodes_dict[root].pi_agg = self.nodes_dict[root].cpd.values + self.nodes_dict[root].update_pi_agg(self.nodes_dict[root].cpd.values) for leaf in self.get_leaves(): - self.nodes_dict[leaf].lambda_agg = np.ones([self.nodes_dict[leaf].cardinality]) + self.nodes_dict[leaf].update_lambda_agg(np.ones([self.nodes_dict[leaf].cardinality])) @property def all_nodes_are_fully_initialized(self): @@ -135,17 +136,18 @@ class Node: cpd: an instance of a conditional probability distribution, e.g. BernoulliOrCPD or TabularCPD """ - self.label_id = label_id + self.label_id = label_id # this can be obtained from cpd.variable self.children = children - self.parents = parents - self.cardinality = cardinality + self.parents = parents # this can be obtained from cpd.variables[1:] + self.cardinality = cardinality # this can be obtained from cpd.cardinality[0] self.cpd = cpd - self.pi_agg = None # np.array dimensions [1, cardinality] - self.lambda_agg = None # np.array dimensions [1, cardinality] + # instances of DiscreteFactor with `values` an np.array of dimensions [1, cardinality] + self.pi_agg = self._init_aggregate_values() + self.lambda_agg = self._init_aggregate_values() - self.pi_received_msgs = self._init_received_msgs(parents) - self.lambda_received_msgs = self._init_received_msgs(children) + self.pi_received_msgs = self._init_pi_received_msgs(parents) + self.lambda_received_msgs = {child: self._init_aggregate_values() for child in children} @classmethod def from_cpd_class(cls, @@ -159,8 +161,8 @@ class Node: @property def belief(self): - if self.pi_agg.any() and self.lambda_agg.any(): - belief = np.multiply(self.pi_agg, self.lambda_agg) + if any(self.pi_agg.values) and any(self.lambda_agg.values): + belief = (self.lambda_agg * self.pi_agg).values return self._normalize(belief) else: return None @@ -168,9 +170,21 @@ class Node: def _normalize(self, value): return value/value.sum() - @staticmethod - def _init_received_msgs(keys): - return {k: None for k in keys} + def _init_aggregate_values(self): + return DiscreteFactor(variables=[self.cpd.variable], + cardinality=[self.cardinality], + values=None, + state_names=None) + + def _init_pi_received_msgs(self, parents): + msgs = {} + for k in parents: + kth_cardinality = self.cpd.cardinality[self.cpd.variables.index(k)] + msgs[k] = DiscreteFactor(variables=[k], + cardinality=[kth_cardinality], + values=None, + state_names=None) + return msgs def _return_msgs_received_for_msg_type(self, message_type): """ @@ -181,9 +195,9 @@ class Node: msg_values: list of message values (each an np.array) """ if message_type == MessageType.LAMBDA: - msg_values = [msg for msg in self.lambda_received_msgs.values()] + msg_values = [msg.values for msg in self.lambda_received_msgs.values()] elif message_type == MessageType.PI: - msg_values = [msg for msg in self.pi_received_msgs.values()] + msg_values = [msg.values for msg in self.pi_received_msgs.values()] return msg_values def validate_and_return_msgs_received_for_msg_type(self, message_type): @@ -214,13 +228,20 @@ class Node: def compute_lambda_agg(self): if len(self.children) == 0: - return self.lambda_agg + return self.lambda_agg.values else: - lambda_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA) - self.lambda_agg = reduce(np.multiply, lambda_msg_values) - return self.lambda_agg + lambda_msg_values =\ + self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA) + self.update_lambda_agg(reduce(np.multiply, lambda_msg_values)) + return self.lambda_agg.values + + def update_pi_agg(self, new_value): + self.pi_agg.update_values(np.array(new_value).reshape(self.cardinality)) + + def update_lambda_agg(self, new_value): + self.lambda_agg.update_values(np.array(new_value).reshape(self.cardinality)) - def _update_received_msg_by_key(self, received_msg_dict, key, new_value): + def _update_received_msg_by_key(self, received_msg_dict, key, new_value, message_type): if key not in received_msg_dict.keys(): raise ValueError("Label id '{}' to update message isn't in allowed set of keys: {}" .format(key, received_msg_dict.keys())) @@ -229,23 +250,30 @@ class Node: raise TypeError("Expected a new value of type numpy.ndarray, but got type {}" .format(type(new_value))) - if new_value.shape != (self.cardinality,): - raise ValueError("Expected new value to be of dimensions ({},) but got {} instead" - .format(self.cardinality, new_value.shape)) - received_msg_dict[key] = new_value + if message_type == MessageType.LAMBDA: + expected_shape = (self.cardinality,) + elif message_type == MessageType.PI: + expected_shape = (self.cpd.cardinality[self.cpd.variables.index(key)],) + + if new_value.shape != expected_shape: + raise ValueError("Expected new value to be of dimensions ({},) but got {} instead" + .format(expected_shape, new_value.shape)) + received_msg_dict[key]._values = new_value def update_pi_msg_from_parent(self, parent, new_value): self._update_received_msg_by_key(received_msg_dict=self.pi_received_msgs, key=parent, - new_value=new_value) + new_value=new_value, + message_type=MessageType.PI) def update_lambda_msg_from_child(self, child, new_value): self._update_received_msg_by_key(received_msg_dict=self.lambda_received_msgs, key=child, - new_value=new_value) + new_value=new_value, + message_type=MessageType.LAMBDA) def compute_pi_msg_to_child(self, child_k): - lambda_msg_from_child = self.lambda_received_msgs[child_k] + lambda_msg_from_child = self.lambda_received_msgs[child_k].values if lambda_msg_from_child is not None: with np.errstate(divide='ignore', invalid='ignore'): # 0/0 := 0 @@ -272,7 +300,7 @@ class Node: if any(msg is None for msg in pi_msgs): return False - if (self.pi_agg is None) or (self.lambda_agg is None): + if (self.pi_agg.values is None) or (self.lambda_agg.values is None): return False return True @@ -291,7 +319,7 @@ class BernoulliOrNode(Node): def compute_pi_agg(self): if len(self.parents) == 0: - self.pi_agg = self.cpd.values + self.update_pi_agg(self.cpd.values) else: pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) parents_p0 = [p[0] for p in pi_msg_values] @@ -299,19 +327,19 @@ class BernoulliOrNode(Node): # of p = [P(False), P(True)] p_0 = reduce(lambda x, y: x*y, parents_p0) p_1 = 1 - p_0 - self.pi_agg = np.array([p_0, p_1]) + self.update_pi_agg(np.array([p_0, p_1])) return self.pi_agg def compute_lambda_msg_to_parent(self, parent_k): - if np.array_equal(self.lambda_agg, np.ones([self.cardinality])): + if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])): return np.ones([self.cardinality]) else: # TODO: cleanup this validation _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) - p0_excluding_k = [p[0] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] + p0_excluding_k = [p.values[0] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] p0_product = reduce(lambda x, y: x*y, p0_excluding_k, 1) - lambda_0 = self.lambda_agg[1] + (self.lambda_agg[0] - self.lambda_agg[1])*p0_product - lambda_1 = self.lambda_agg[1] + lambda_0 = self.lambda_agg.values[1] + (self.lambda_agg.values[0] - self.lambda_agg.values[1])*p0_product + lambda_1 = self.lambda_agg.values[1] lambda_msg = np.array([lambda_0, lambda_1]) if not any(lambda_msg): raise InvalidLambdaMsgToParent @@ -331,7 +359,7 @@ class BernoulliAndNode(Node): def compute_pi_agg(self): if len(self.parents) == 0: - self.pi_agg = self.cpd.values + self.update_pi_agg(self.cpd.values) else: pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) parents_p1 = [p[1] for p in pi_msg_values] @@ -339,19 +367,19 @@ class BernoulliAndNode(Node): # of p = [P(False), P(True)] p_1 = reduce(lambda x, y: x*y, parents_p1) p_0 = 1 - p_1 - self.pi_agg = np.array([p_0, p_1]) + self.update_pi_agg(np.array([p_0, p_1])) return self.pi_agg def compute_lambda_msg_to_parent(self, parent_k): - if np.array_equal(self.lambda_agg, np.ones([self.cardinality])): + if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])): return np.ones([self.cardinality]) else: # TODO: cleanup this validation _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) - p1_excluding_k = [p[1] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] + p1_excluding_k = [p.values[1] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] p1_product = reduce(lambda x, y: x*y, p1_excluding_k, 1) - lambda_0 = self.lambda_agg[0] - lambda_1 = self.lambda_agg[0] + (self.lambda_agg[1] - self.lambda_agg[0])*p1_product + lambda_0 = self.lambda_agg.values[0] + lambda_1 = self.lambda_agg.values[0] + (self.lambda_agg.values[1] - self.lambda_agg.values[0])*p1_product lambda_msg = np.array([lambda_0, lambda_1]) if not any(lambda_msg): raise InvalidLambdaMsgToParent -- cgit v1.2.3 From 76090e3f03c01e208d41203a6286ea432714090a Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Tue, 12 Dec 2017 19:06:30 -0800 Subject: clean up node class, simpler initialization --- beliefs/models/belief_update_node_model.py | 42 ++++++------------------------ 1 file changed, 8 insertions(+), 34 deletions(-) (limited to 'beliefs') diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py index 1765ed9..820ee0c 100644 --- a/beliefs/models/belief_update_node_model.py +++ b/beliefs/models/belief_update_node_model.py @@ -121,44 +121,26 @@ class Node: Implemented from Pearl's belief propagation algorithm. """ - def __init__(self, - label_id, - children, - parents, - cardinality, - cpd): + def __init__(self, children, cpd): """ Args - label_id: str - children: set of strings - parents: set of strings - cardinality: int, cardinality of the random variable the node represents + children: list of strings cpd: an instance of a conditional probability distribution, e.g. BernoulliOrCPD or TabularCPD """ - self.label_id = label_id # this can be obtained from cpd.variable + self.label_id = cpd.variable self.children = children - self.parents = parents # this can be obtained from cpd.variables[1:] - self.cardinality = cardinality # this can be obtained from cpd.cardinality[0] + self.parents = cpd.parents + self.cardinality = cpd.cardinality[0] self.cpd = cpd # instances of DiscreteFactor with `values` an np.array of dimensions [1, cardinality] self.pi_agg = self._init_aggregate_values() self.lambda_agg = self._init_aggregate_values() - self.pi_received_msgs = self._init_pi_received_msgs(parents) + self.pi_received_msgs = self._init_pi_received_msgs(self.parents) self.lambda_received_msgs = {child: self._init_aggregate_values() for child in children} - @classmethod - def from_cpd_class(cls, - label_id, - children, - parents, - cardinality, - cpd_class): - cpd = cpd_class(label_id, parents) - return cls(label_id, children, parents, cardinality, cpd) - @property def belief(self): if any(self.pi_agg.values) and any(self.lambda_agg.values): @@ -311,11 +293,7 @@ class BernoulliOrNode(Node): label_id, children, parents): - super().__init__(label_id=label_id, - children=children, - parents=parents, - cardinality=2, - cpd=BernoulliOrCPD(label_id, parents)) + super().__init__(children=children, cpd=BernoulliOrCPD(label_id, parents)) def compute_pi_agg(self): if len(self.parents) == 0: @@ -351,11 +329,7 @@ class BernoulliAndNode(Node): label_id, children, parents): - super().__init__(label_id=label_id, - children=children, - parents=parents, - cardinality=2, - cpd=BernoulliAndCPD(label_id, parents)) + super().__init__(children=children, cpd=BernoulliAndCPD(label_id, parents)) def compute_pi_agg(self): if len(self.parents) == 0: -- cgit v1.2.3 From 2f4de4ae0b28e0e5ee2a5be6955366267fbc2404 Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Tue, 12 Dec 2017 14:27:11 -0800 Subject: init Bernoulli And,Or CPDs w/ default state names 'False','True' --- beliefs/factors/bernoulli_and_cpd.py | 3 ++- beliefs/factors/bernoulli_or_cpd.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) (limited to 'beliefs') diff --git a/beliefs/factors/bernoulli_and_cpd.py b/beliefs/factors/bernoulli_and_cpd.py index fdb0c25..adf5ed5 100644 --- a/beliefs/factors/bernoulli_and_cpd.py +++ b/beliefs/factors/bernoulli_and_cpd.py @@ -20,7 +20,8 @@ class BernoulliAndCPD(TabularCPD): variable_card=2, parents=parents, parents_card=[2]*len(parents), - values=[]) + values=[], + state_names={var: ['False', 'True'] for var in [variable] + parents}) self._values = None @property diff --git a/beliefs/factors/bernoulli_or_cpd.py b/beliefs/factors/bernoulli_or_cpd.py index 12ee2f6..6e01cf9 100644 --- a/beliefs/factors/bernoulli_or_cpd.py +++ b/beliefs/factors/bernoulli_or_cpd.py @@ -20,7 +20,8 @@ class BernoulliOrCPD(TabularCPD): variable_card=2, parents=parents, parents_card=[2]*len(parents), - values=[]) + values=[], + state_names={var: ['False', 'True'] for var in [variable] + parents}) self._values = None @property -- cgit v1.2.3 From b3b8bb68d6d590175a07dfc4022b4903d63222e5 Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Tue, 12 Dec 2017 19:58:12 -0800 Subject: Bernoulli Or/And Node access msg values by state names --- beliefs/factors/bernoulli_and_cpd.py | 2 +- beliefs/factors/bernoulli_or_cpd.py | 2 +- beliefs/models/belief_update_node_model.py | 91 ++++++++++++++++++++++-------- 3 files changed, 69 insertions(+), 26 deletions(-) (limited to 'beliefs') diff --git a/beliefs/factors/bernoulli_and_cpd.py b/beliefs/factors/bernoulli_and_cpd.py index adf5ed5..15802c2 100644 --- a/beliefs/factors/bernoulli_and_cpd.py +++ b/beliefs/factors/bernoulli_and_cpd.py @@ -20,7 +20,7 @@ class BernoulliAndCPD(TabularCPD): variable_card=2, parents=parents, parents_card=[2]*len(parents), - values=[], + values=None, state_names={var: ['False', 'True'] for var in [variable] + parents}) self._values = None diff --git a/beliefs/factors/bernoulli_or_cpd.py b/beliefs/factors/bernoulli_or_cpd.py index 6e01cf9..5b661a1 100644 --- a/beliefs/factors/bernoulli_or_cpd.py +++ b/beliefs/factors/bernoulli_or_cpd.py @@ -20,7 +20,7 @@ class BernoulliOrCPD(TabularCPD): variable_card=2, parents=parents, parents_card=[2]*len(parents), - values=[], + values=None, state_names={var: ['False', 'True'] for var in [variable] + parents}) self._values = None diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py index 820ee0c..cd8ba8c 100644 --- a/beliefs/models/belief_update_node_model.py +++ b/beliefs/models/belief_update_node_model.py @@ -174,13 +174,13 @@ class Node: message_type: MessageType enum Returns: - msg_values: list of message values (each an np.array) + msg_values: list of DiscreteFactors containing message values (np.arrays) """ if message_type == MessageType.LAMBDA: - msg_values = [msg.values for msg in self.lambda_received_msgs.values()] + msgs = [msg for msg in self.lambda_received_msgs.values()] elif message_type == MessageType.PI: - msg_values = [msg.values for msg in self.pi_received_msgs.values()] - return msg_values + msgs = [msg for msg in self.pi_received_msgs.values()] + return msgs def validate_and_return_msgs_received_for_msg_type(self, message_type): """ @@ -192,17 +192,17 @@ class Node: message_type: MessageType enum Returns: - msg_values: list of message values (each an np.array) + msgs: list of DiscreteFactors containing message values (np.array) """ - msg_values = self._return_msgs_received_for_msg_type(message_type) + msgs = self._return_msgs_received_for_msg_type(message_type) - if any(msg is None for msg in msg_values): + if any(msg.values is None for msg in msgs): raise ValueError( "Missing value for {msg_type} msg from child: can't compute {msg_type}_agg." .format(msg_type=message_type.value) ) else: - return msg_values + return msgs def compute_pi_agg(self): # TODO: implement explict factor product operation @@ -212,8 +212,10 @@ class Node: if len(self.children) == 0: return self.lambda_agg.values else: - lambda_msg_values =\ - self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA) + lambda_msg_values = [ + msg.values for msg in + self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA) + ] self.update_lambda_agg(reduce(np.multiply, lambda_msg_values)) return self.lambda_agg.values @@ -295,14 +297,30 @@ class BernoulliOrNode(Node): parents): super().__init__(children=children, cpd=BernoulliOrCPD(label_id, parents)) + def _init_aggregate_values(self): + variable = self.cpd.variable + return DiscreteFactor(variables=[self.cpd.variable], + cardinality=[self.cardinality], + values=None, + state_names={variable: self.cpd.state_names[variable]}) + + def _init_pi_received_msgs(self, parents): + msgs = {} + for k in parents: + kth_cardinality = self.cpd.cardinality[self.cpd.variables.index(k)] + msgs[k] = DiscreteFactor(variables=[k], + cardinality=[kth_cardinality], + values=None, + state_names={k: self.cpd.state_names[k]}) + return msgs + def compute_pi_agg(self): if len(self.parents) == 0: self.update_pi_agg(self.cpd.values) else: - pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) - parents_p0 = [p[0] for p in pi_msg_values] - # Parents are Bernoulli variables with pi_msg_values (surrogate prior probabilities) - # of p = [P(False), P(True)] + pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) + parents_p0 = [p.get_value_for_state_vector({p.variables[0]: 'False'}) + for p in pi_msgs] p_0 = reduce(lambda x, y: x*y, parents_p0) p_1 = 1 - p_0 self.update_pi_agg(np.array([p_0, p_1])) @@ -314,10 +332,14 @@ class BernoulliOrNode(Node): else: # TODO: cleanup this validation _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) - p0_excluding_k = [p.values[0] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] + p0_excluding_k = [p.get_value_for_state_vector({p.variables[0]: 'False'}) + for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] p0_product = reduce(lambda x, y: x*y, p0_excluding_k, 1) - lambda_0 = self.lambda_agg.values[1] + (self.lambda_agg.values[0] - self.lambda_agg.values[1])*p0_product - lambda_1 = self.lambda_agg.values[1] + + lambda_agg_0 = self.lambda_agg.get_value_for_state_vector({self.label_id: 'False'}) + lambda_agg_1 = self.lambda_agg.get_value_for_state_vector({self.label_id: 'True'}) + lambda_0 = lambda_agg_1 + (lambda_agg_0 - lambda_agg_1)*p0_product + lambda_1 = lambda_agg_1 lambda_msg = np.array([lambda_0, lambda_1]) if not any(lambda_msg): raise InvalidLambdaMsgToParent @@ -331,14 +353,30 @@ class BernoulliAndNode(Node): parents): super().__init__(children=children, cpd=BernoulliAndCPD(label_id, parents)) + def _init_aggregate_values(self): + variable = self.cpd.variable + return DiscreteFactor(variables=[self.cpd.variable], + cardinality=[self.cardinality], + values=None, + state_names={variable: self.cpd.state_names[variable]}) + + def _init_pi_received_msgs(self, parents): + msgs = {} + for k in parents: + kth_cardinality = self.cpd.cardinality[self.cpd.variables.index(k)] + msgs[k] = DiscreteFactor(variables=[k], + cardinality=[kth_cardinality], + values=None, + state_names={k: self.cpd.state_names[k]}) + return msgs + def compute_pi_agg(self): if len(self.parents) == 0: self.update_pi_agg(self.cpd.values) else: - pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) - parents_p1 = [p[1] for p in pi_msg_values] - # Parents are Bernoulli variables with pi_msg_values (surrogate prior probabilities) - # of p = [P(False), P(True)] + pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) + parents_p1 = [p.get_value_for_state_vector({p.variables[0]: 'True'}) + for p in pi_msgs] p_1 = reduce(lambda x, y: x*y, parents_p1) p_0 = 1 - p_1 self.update_pi_agg(np.array([p_0, p_1])) @@ -350,10 +388,15 @@ class BernoulliAndNode(Node): else: # TODO: cleanup this validation _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) - p1_excluding_k = [p.values[1] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] + p1_excluding_k = [p.get_value_for_state_vector({p.variables[0]: 'True'}) + for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] p1_product = reduce(lambda x, y: x*y, p1_excluding_k, 1) - lambda_0 = self.lambda_agg.values[0] - lambda_1 = self.lambda_agg.values[0] + (self.lambda_agg.values[1] - self.lambda_agg.values[0])*p1_product + + lambda_agg_0 = self.lambda_agg.get_value_for_state_vector({self.label_id: 'False'}) + lambda_agg_1 = self.lambda_agg.get_value_for_state_vector({self.label_id: 'True'}) + + lambda_0 = lambda_agg_0 + lambda_1 = lambda_agg_0 + (lambda_agg_1 - lambda_agg_0)*p1_product lambda_msg = np.array([lambda_0, lambda_1]) if not any(lambda_msg): raise InvalidLambdaMsgToParent -- cgit v1.2.3 From 10f5c49ea6767f54d59f88eb4064bb4959d14c6b Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Tue, 12 Dec 2017 21:28:26 -0800 Subject: implement explicit factor methods for compute_pi_agg and compute_lambda_msg_to_parent in Node --- beliefs/factors/discrete_factor.py | 7 ++++ beliefs/models/belief_update_node_model.py | 37 ++++++++++++----- tests/test_belief_propagation.py | 64 +++++++++++++++++++++++++++++- 3 files changed, 96 insertions(+), 12 deletions(-) (limited to 'beliefs') diff --git a/beliefs/factors/discrete_factor.py b/beliefs/factors/discrete_factor.py index da8e6bf..b75da28 100644 --- a/beliefs/factors/discrete_factor.py +++ b/beliefs/factors/discrete_factor.py @@ -86,9 +86,16 @@ class DiscreteFactor: right = copy.deepcopy(other) left.add_new_variables_from_other_factor(right) right.add_new_variables_from_other_factor(left) + print('var', left.variables) + print(left.cardinality) + print(left.values) + print('var', right.variables) + print(right.cardinality) + print(right.values) # reorder variables in right factor to match order in left source_axes = list(range(right.values.ndim)) + print('source_axes', source_axes) destination_axes = [right.variables.index(var) for var in left.variables] right.variables = [right.variables[idx] for idx in destination_axes] diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py index cd8ba8c..17e98fa 100644 --- a/beliefs/models/belief_update_node_model.py +++ b/beliefs/models/belief_update_node_model.py @@ -205,25 +205,30 @@ class Node: return msgs def compute_pi_agg(self): - # TODO: implement explict factor product operation - raise NotImplementedError + if len(self.parents) == 0: + self.update_pi_agg(self.cpd.values) + else: + factors_to_multiply = [self.cpd] + pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) + factors_to_multiply.extend(pi_msgs) + + factor_product = reduce(lambda phi1, phi2: phi1*phi2, factors_to_multiply) + self.update_pi_agg(factor_product.marginalize(self.parents).values) + pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) def compute_lambda_agg(self): - if len(self.children) == 0: - return self.lambda_agg.values - else: + if len(self.children) != 0: lambda_msg_values = [ msg.values for msg in self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA) ] self.update_lambda_agg(reduce(np.multiply, lambda_msg_values)) - return self.lambda_agg.values def update_pi_agg(self, new_value): - self.pi_agg.update_values(np.array(new_value).reshape(self.cardinality)) + self.pi_agg.update_values(new_value) def update_lambda_agg(self, new_value): - self.lambda_agg.update_values(np.array(new_value).reshape(self.cardinality)) + self.lambda_agg.update_values(new_value) def _update_received_msg_by_key(self, received_msg_dict, key, new_value, message_type): if key not in received_msg_dict.keys(): @@ -242,7 +247,8 @@ class Node: if new_value.shape != expected_shape: raise ValueError("Expected new value to be of dimensions ({},) but got {} instead" .format(expected_shape, new_value.shape)) - received_msg_dict[key]._values = new_value + # received_msg_dict[key]._values = new_value + received_msg_dict[key].update_values(new_value) def update_pi_msg_from_parent(self, parent, new_value): self._update_received_msg_by_key(received_msg_dict=self.pi_received_msgs, @@ -267,8 +273,17 @@ class Node: raise ValueError("Can't compute pi message to child_{} without having received a lambda message from that child.") def compute_lambda_msg_to_parent(self, parent_k): - # TODO: implement explict factor product operation - raise NotImplementedError + if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])): + return np.ones([self.cardinality]) + else: + factors_to_multiply = [self.cpd] + pi_msgs_excl_k = [msg for par_id, msg in self.pi_received_msgs.items() + if par_id != parent_k] + factors_to_multiply.extend(pi_msgs_excl_k) + factor_product = reduce(lambda phi1, phi2: phi1*phi2, factors_to_multiply) + new_factor = factor_product.marginalize(list(set(self.parents) - set([parent_k]))) + lambda_msg_to_k = (self.lambda_agg * new_factor).marginalize([self.lambda_agg.variables[0]]) + return self._normalize(lambda_msg_to_k.values) @property def is_fully_initialized(self): diff --git a/tests/test_belief_propagation.py b/tests/test_belief_propagation.py index 7a77311..1b8c0ac 100644 --- a/tests/test_belief_propagation.py +++ b/tests/test_belief_propagation.py @@ -3,10 +3,12 @@ import pytest from pytest import approx from beliefs.inference.belief_propagation import BeliefPropagation, ConflictingEvidenceError +from beliefs.factors.cpd import TabularCPD from beliefs.models.belief_update_node_model import ( BeliefUpdateNodeModel, BernoulliOrNode, - BernoulliAndNode + BernoulliAndNode, + Node ) @@ -89,6 +91,41 @@ def mixed_cpd_model(edges_five_nodes): 'w': w_node}) +@pytest.fixture(scope='function') +def custom_cpd_model(): + """ + Y-shaped model, with parents ,'u' and 'v' as Or-nodes, 'x' a node with + cardinality 3 and custom CPD, 'y' a node with cardinality 2 and custom CPD. + """ + custom_cpd_x = TabularCPD(variable='x', + variable_card=3, + parents=['u', 'v'], + parents_card=[2, 2], + values=[[0.2, 0, 0.3, 0.1], + [0.4, 1, 0.7, 0.2], + [0.4, 0, 0, 0.7]], + state_names={'x': ['lo', 'med', 'hi'], + 'u': ['False', 'True'], + 'v': ['False', 'True']}) + custom_cpd_y = TabularCPD(variable='y', + variable_card=2, + parents=['x'], + parents_card=[3], + values=[[0.3, 0.1, 0], + [0.7, 0.9, 1]], + state_names={'x': ['lo', 'med', 'hi'], + 'y': ['False', 'True']}) + + u_node = BernoulliOrNode(label_id='u', children=['x'], parents=[]) + v_node = BernoulliOrNode(label_id='v', children=['x'], parents=[]) + x_node = Node(children=['y'], cpd=custom_cpd_x) + y_node = Node(children=[], cpd=custom_cpd_y) + return BeliefUpdateNodeModel(nodes_dict={'u': u_node, + 'v': v_node, + 'x': x_node, + 'y': y_node}) + + def get_label_mapped_to_positive_belief(query_result): """Return a dictionary mapping each label_id to the probability of the label being True.""" @@ -355,3 +392,28 @@ def test_conflicting_evidence_and_model(many_parents_and_model): with pytest.raises(ConflictingEvidenceError) as err: query_result = infer.query(evidence={'62': np.array([0, 1]), '112': np.array([1, 0])}) assert "Can't run belief propagation with conflicting evidence" in str(err) + + +#============================================================================================== +# Model with two custom cpds + + +def test_no_evidence_custom_cpd_model(custom_cpd_model): + expected = {'x': np.array([0.15, 0.575, 0.275]), + 'v': np.array([0.5, 0.5]), + 'u': np.array([0.5, 0.5]), + 'y': np.array([0.1025, 0.8975])} + infer = BeliefPropagation(custom_cpd_model) + query_result = infer.query(evidence={}) + compare_dictionaries(expected, query_result) + + +def test_evidence_custom_cpd_model(custom_cpd_model): + """Custom node is observed to be in 'med' state.""" + expected = {'x': np.array([0., 1., 0.]), + 'u': np.array([0.60869565, 0.39130435]), + 'v': np.array([0.47826087, 0.52173913]), + 'y': np.array([0.1, 0.9])} + infer = BeliefPropagation(custom_cpd_model) + query_result = infer.query(evidence={'x': np.array([0, 1, 0])}) + compare_dictionaries(expected, query_result) -- cgit v1.2.3 From 7053fefc6f9e43b1e252d1f551401a7a70b52e93 Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Wed, 13 Dec 2017 18:47:32 -0800 Subject: cleanup print statements, stale comments, minor TODOs --- beliefs/factors/bernoulli_and_cpd.py | 7 +- beliefs/factors/bernoulli_or_cpd.py | 7 +- beliefs/factors/cpd.py | 29 ++-- beliefs/factors/discrete_factor.py | 32 ++-- beliefs/inference/belief_propagation.py | 77 +++++----- beliefs/models/base_models.py | 90 ++++++----- beliefs/models/belief_update_node_model.py | 238 +++++++++++++++++++---------- beliefs/utils/math_helper.py | 14 +- beliefs/utils/random_variables.py | 17 ++- 9 files changed, 306 insertions(+), 205 deletions(-) (limited to 'beliefs') diff --git a/beliefs/factors/bernoulli_and_cpd.py b/beliefs/factors/bernoulli_and_cpd.py index 15802c2..291398f 100644 --- a/beliefs/factors/bernoulli_and_cpd.py +++ b/beliefs/factors/bernoulli_and_cpd.py @@ -12,9 +12,10 @@ class BernoulliAndCPD(TabularCPD): """ def __init__(self, variable, parents=[]): """ - Args: - variable: int or string - parents: optional, list of int and/or strings + Args + variable: int or string + parents: list, + (optional) list of int and/or strings """ super().__init__(variable=variable, variable_card=2, diff --git a/beliefs/factors/bernoulli_or_cpd.py b/beliefs/factors/bernoulli_or_cpd.py index 5b661a1..b5e6ae5 100644 --- a/beliefs/factors/bernoulli_or_cpd.py +++ b/beliefs/factors/bernoulli_or_cpd.py @@ -12,9 +12,10 @@ class BernoulliOrCPD(TabularCPD): """ def __init__(self, variable, parents=[]): """ - Args: - variable: int or string - parents: optional, list of int and/or strings + Args + variable: int or string + parents: list, + (optional) list of int and/or strings """ super().__init__(variable=variable, variable_card=2, diff --git a/beliefs/factors/cpd.py b/beliefs/factors/cpd.py index 9e7191f..c7883c9 100644 --- a/beliefs/factors/cpd.py +++ b/beliefs/factors/cpd.py @@ -1,3 +1,4 @@ +import copy import numpy as np from beliefs.factors.discrete_factor import DiscreteFactor @@ -7,16 +8,18 @@ class TabularCPD(DiscreteFactor): Defines the conditional probability table for a discrete variable whose parents are also discrete. """ - def __init__(self, variable, variable_card, - parents=[], parents_card=[], values=[], state_names=None): + def __init__(self, variable, variable_card, parents=[], parents_card=[], + values=[], state_names=None): """ - Args: - variable: int or string - variable_card: int - parents: optional, list of int and/or strings - parents_card: optional, list of int - values: optional, 2d list or array - state_names: dictionary (optional), + Args + variable: int or string + variable_card: int + parents: list, + (optional) list of int and/or strings + parents_card: list, + (optional) list of int + values: 2-d list or array (optional) + state_names: dictionary (optional), mapping variables to their states, of format {label_name: ['state1', 'state2']} """ super().__init__(variables=[variable] + parents, @@ -24,7 +27,7 @@ class TabularCPD(DiscreteFactor): values=values, state_names=state_names) self.variable = variable - self.parents = parents + self.parents = list(parents) def get_values(self): """ @@ -36,8 +39,4 @@ class TabularCPD(DiscreteFactor): return self.values.reshape(self.cardinality[0], np.prod(self.cardinality[1:])) def copy(self): - return self.__class__(self.variable, - self.cardinality[0], - self.parents, - self.cardinality[1:], - self._values) + return copy.deepcopy(self) diff --git a/beliefs/factors/discrete_factor.py b/beliefs/factors/discrete_factor.py index b75da28..708f00c 100644 --- a/beliefs/factors/discrete_factor.py +++ b/beliefs/factors/discrete_factor.py @@ -18,7 +18,7 @@ class DiscreteFactor: mapping variables to their states, of format {label_name: ['state1', 'state2']} """ self.variables = list(variables) - self.cardinality = cardinality + self.cardinality = list(cardinality) if values is None: self._values = None else: @@ -28,6 +28,13 @@ class DiscreteFactor: def __mul__(self, other): return self.product(other) + def copy(self): + """Return a copy of the factor""" + return self.__class__(self.variables, + self.cardinality, + self._values, + copy.deepcopy(self.state_names)) + @property def values(self): return self._values @@ -56,7 +63,7 @@ class DiscreteFactor: return self.values[tuple(state_coordinates)] def add_new_variables_from_other_factor(self, other): - """Add new variables to the factor.""" + """Add new variables from `other` factor to the factor.""" extra_vars = set(other.variables) - set(self.variables) # if all of these variables already exist there is nothing to do if len(extra_vars) == 0: @@ -69,33 +76,24 @@ class DiscreteFactor: new_card_var = other.get_cardinality(extra_vars) self.cardinality.extend([new_card_var[var] for var in extra_vars]) - return def get_cardinality(self, variables): return {var: self.cardinality[self.variables.index(var)] for var in variables} def product(self, other): - left = copy.deepcopy(self) + left = self.copy() if isinstance(other, (int, float)): - # TODO: handle case of multiplication by constant - pass + return self.values * other else: - # assert right is a class or subclass of DiscreteFactor - # that has attributes: variables, values; method: get_cardinality - right = copy.deepcopy(other) + assert isinstance(other, DiscreteFactor), \ + "__mul__ is only defined between subclasses of DiscreteFactor" + right = other.copy() left.add_new_variables_from_other_factor(right) right.add_new_variables_from_other_factor(left) - print('var', left.variables) - print(left.cardinality) - print(left.values) - print('var', right.variables) - print(right.cardinality) - print(right.values) # reorder variables in right factor to match order in left source_axes = list(range(right.values.ndim)) - print('source_axes', source_axes) destination_axes = [right.variables.index(var) for var in left.variables] right.variables = [right.variables[idx] for idx in destination_axes] @@ -110,7 +108,7 @@ class DiscreteFactor: vars: list, variables over which to marginalize the factor Returns - DiscreteFactor + DiscreteFactor, whose scope is set(self.variables) - set(vars) """ phi = copy.deepcopy(self) diff --git a/beliefs/inference/belief_propagation.py b/beliefs/inference/belief_propagation.py index 128f645..acd93d4 100644 --- a/beliefs/inference/belief_propagation.py +++ b/beliefs/inference/belief_propagation.py @@ -28,10 +28,10 @@ class ConflictingEvidenceError(Exception): class BeliefPropagation: def __init__(self, model, inplace=True): """ - Input: - model: an instance of BeliefUpdateNodeModel - inplace: bool - modify in-place the nodes in the model during belief propagation + Args + model: an instance of BeliefUpdateNodeModel + inplace: bool, + modify in-place the nodes in the model during belief propagation """ if not isinstance(model, BeliefUpdateNodeModel): raise TypeError("Model must be an instance of BeliefUpdateNodeModel") @@ -43,21 +43,20 @@ class BeliefPropagation: def _belief_propagation(self, nodes_to_update, evidence): """ Implementation of Pearl's belief propagation algorithm for polytrees. - ref: "Fusion, Propagation, and Structuring in Belief Networks" Artificial Intelligence 29 (1986) 241-288 - Input: - nodes_to_update: list - list of MsgPasser namedtuples. - evidence: dict, - a dict key, value pair as {var: state_of_var observed} + Args + nodes_to_update: list, + list of MsgPasser namedtuples. + evidence: dict, + a dict key, value pair as {var: state_of_var observed} """ if len(nodes_to_update) == 0: return node_to_update_label_id, msg_sender_label_id = nodes_to_update.pop() - logging.info("Node: %s", node_to_update_label_id) + logging.debug("Node: %s", node_to_update_label_id) node = self.model.nodes_dict[node_to_update_label_id] @@ -65,8 +64,8 @@ class BeliefPropagation: # outgoing msg from the node to update parent_ids = set(node.parents) - set([msg_sender_label_id]) child_ids = set(node.children) - set([msg_sender_label_id]) - logging.info("parent_ids: %s", str(parent_ids)) - logging.info("child_ids: %s", str(child_ids)) + logging.debug("parent_ids: %s", str(parent_ids)) + logging.debug("child_ids: %s", str(child_ids)) if msg_sender_label_id is not None: # update triggered by receiving a message, not pinning to evidence @@ -74,9 +73,9 @@ class BeliefPropagation: if node_to_update_label_id not in evidence: node.compute_pi_agg() - logging.info("belief propagation pi_agg: %s", np.array2string(node.pi_agg.values)) + logging.debug("belief propagation pi_agg: %s", np.array2string(node.pi_agg.values)) node.compute_lambda_agg() - logging.info("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg.values)) + logging.debug("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg.values)) for parent_id in parent_ids: try: @@ -101,14 +100,14 @@ class BeliefPropagation: def initialize_model(self): """ - Apply boundary conditions: + 1. Apply boundary conditions: - Set pi_agg equal to prior probabilities for root nodes. - Set lambda_agg equal to vector of ones for leaf nodes. - - Set lambda_agg, lambda_received_msgs to vectors of ones (same effect as - actually passing lambda messages up from leaf nodes to root nodes). - - Calculate pi_agg and pi_received_msgs for all nodes without evidence. - (Without evidence, belief equals pi_agg.) + 2. Set lambda_agg, lambda_received_msgs to vectors of ones (same effect as + actually passing lambda messages up from leaf nodes to root nodes). + 3. Calculate pi_agg and pi_received_msgs for all nodes without evidence. + (Without evidence, belief equals pi_agg.) """ self.model.set_boundary_conditions() @@ -119,13 +118,13 @@ class BeliefPropagation: for child in node.lambda_received_msgs.keys(): node.update_lambda_msg_from_child(child=child, new_value=ones_vector) - logging.info("Finished initializing Lambda(x) and lambda_received_msgs per node.") + logging.debug("Finished initializing Lambda(x) and lambda_received_msgs per node.") - logging.info("Start downward sweep from nodes. Sending Pi messages only.") + logging.debug("Start downward sweep from nodes. Sending Pi messages only.") topdown_order = self.model.get_topologically_sorted_nodes(reverse=False) for node_id in topdown_order: - logging.info('label in iteration through top-down order: %s', str(node_id)) + logging.debug('label in iteration through top-down order: %s', str(node_id)) node_sending_msg = self.model.nodes_dict[node_id] child_ids = node_sending_msg.children @@ -134,9 +133,9 @@ class BeliefPropagation: node_sending_msg.compute_pi_agg() for child_id in child_ids: - logging.info("child: %s", str(child_id)) + logging.debug("child: %s", str(child_id)) new_pi_msg = node_sending_msg.compute_pi_msg_to_child(child_k=child_id) - logging.info("new_pi_msg: %s", np.array2string(new_pi_msg)) + logging.debug("new_pi_msg: %s", np.array2string(new_pi_msg)) child_node = self.model.nodes_dict[child_id] child_node.update_pi_msg_from_parent(parent=node_id, @@ -144,9 +143,12 @@ class BeliefPropagation: def _run_belief_propagation(self, evidence): """ - Input: - evidence: dict - a dict key, value pair as {var: state_of_var observed} + Sequentially perturb nodes with observed values, running belief propagation + after each perturbation. + + Args + evidence: dict, + a dict key, value pair as {var: state_of_var observed} """ for evidence_id, observed_value in evidence.items(): if evidence_id not in self.model.nodes_dict.keys(): @@ -162,21 +164,20 @@ class BeliefPropagation: self.model.nodes_dict[evidence_id].lambda_agg.values * observed_value ) nodes_to_update = [MsgPassers(msg_receiver=evidence_id, msg_sender=None)] - self._belief_propagation(nodes_to_update=set(nodes_to_update), - evidence=evidence) + self._belief_propagation(nodes_to_update=set(nodes_to_update), evidence=evidence) def query(self, evidence={}): """ - Run belief propagation given evidence. + Run belief propagation given 0 or more pieces of evidence. - Input: - evidence: dict - a dict key, value pair as {var: state_of_var observed}, - e.g. {'3': np.array([0,1])} if label '3' is True. + Args + evidence: dict, + a dict key, value pair as {var: state_of_var observed}, + e.g. {'3': np.array([0,1])} if label '3' is True. - Returns: - beliefs: dict - a dict key, value pair as {var: belief} + Returns + a dict key, value pair as {var: belief}, where belief is an np.array of the + marginal probability of each state of the variable given the evidence. Example ------- diff --git a/beliefs/models/base_models.py b/beliefs/models/base_models.py index cb91566..71af0cb 100644 --- a/beliefs/models/base_models.py +++ b/beliefs/models/base_models.py @@ -9,9 +9,11 @@ class DirectedGraph(nx.DiGraph): """ def __init__(self, edges=None, node_labels=None): """ - Input: - edges: an edge list, e.g. [(parent1, child1), (parent1, child2)] - node_labels: a list of strings of node labels + Args + edges: list, + a list of edge tuples, e.g. [(parent1, child1), (parent1, child2)] + node_labels: list, + a list of strings or integers representing node label ids """ super().__init__() if edges is not None: @@ -20,18 +22,15 @@ class DirectedGraph(nx.DiGraph): self.add_nodes_from(node_labels) def get_leaves(self): - """ - Returns a list of leaves of the graph. - """ + """Return a list of leaves of the graph""" return [node for node, out_degree in self.out_degree() if out_degree == 0] def get_roots(self): - """ - Returns a list of roots of the graph. - """ + """Return a list of roots of the graph""" return [node for node, in_degree in self.in_degree() if in_degree == 0] def get_topologically_sorted_nodes(self, reverse=False): + """Return a list of nodes in topological sort order""" if reverse: return list(reversed(list(nx.topological_sort(self)))) else: @@ -47,12 +46,12 @@ class BayesianModel(DirectedGraph): """ Base class for Bayesian model. - Input: - edges: (optional) list of edges, + Args + edges: (optional) list of edges, tuples of form ('parent', 'child') - variables: (optional) list of str or int + variables: (optional) list of str or int labels for variables - cpds: (optional) list of CPDs + cpds: (optional) list of CPDs TabularCPD class or subclass """ super().__init__() @@ -61,20 +60,17 @@ class BayesianModel(DirectedGraph): self.cpds = cpds def copy(self): - """ - Returns a copy of the model. - """ - copy_model = self.__class__(edges=list(self.edges()).copy(), - variables=list(self.nodes()).copy(), - cpds=[cpd.copy() for cpd in self.cpds]) - return copy_model + """Return a copy of the model""" + return self.__class__(edges=list(self.edges()).copy(), + variables=list(self.nodes()).copy(), + cpds=[cpd.copy() for cpd in self.cpds]) def get_variables_in_definite_state(self): """ - Returns a set of labels of all nodes in a definite state, i.e. with - label values that are kronecker deltas. + Get labels of all nodes in a definite state, i.e. with label values + that are kronecker deltas. - RETURNS + Returns set of strings (labels) """ return {label for label, node in self.nodes_dict.items() if is_kronecker_delta(node.belief)} @@ -84,14 +80,14 @@ class BayesianModel(DirectedGraph): Returns a set of labels that are inferred to be in definite state, given list of labels that were directly observed (e.g. YES/NOs, but not MAYBEs). - INPUT - observed: set of strings, directly observed labels - RETURNS - set of strings, labels inferred to be in a definite state + Args + observed: set, + set of strings, directly observed labels + Returns + set of strings, the labels inferred to be in a definite state """ - - # Assert that beliefs of directly observed vars are kronecker deltas for label in observed: + # beliefs of directly observed vars should be kronecker deltas assert is_kronecker_delta(self.nodes_dict[label].belief), \ ("Observed label has belief {} but should be kronecker delta" .format(self.nodes_dict[label].belief)) @@ -101,28 +97,40 @@ class BayesianModel(DirectedGraph): "Expected set of observed labels to be a subset of labels in definite state." return vars_in_definite_state - observed - def _get_ancestors_of(self, observed): - """Return list of ancestors of observed labels""" + def _get_ancestors_of(self, labels): + """ + Get set of ancestors of an iterable of labels. + + Args + observed: iterable, + label ids for which ancestors should be retrieved + + Returns + ancestors: set, + set of label ids of ancestors of the input labels + """ ancestors = set() - for label in observed: + for label in labels: ancestors.update(nx.ancestors(self, label)) return ancestors def reachable_observed_variables(self, source, observed=set()): """ - Returns list of observed labels (labels with direct evidence to be in a definite + Get list of directly observed labels (labels with evidence in a definite state) that are reachable from the source. - INPUT - source: string, label of node for which to evaluate reachable observed labels - observed: set of strings, directly observed labels - RETURNS - reachable_observed_vars: set of strings, observed labels (variables with direct - evidence) that are reachable from the source label. + Args + source: string, + label of node for which to evaluate reachable observed labels + observed: set, + set of strings, directly observed labels + Returns + reachable_observed_vars: set, + set of strings, observed labels (variables with direct evidence) + that are reachable from the source label """ - # ancestors of observed labels, including observed labels ancestors_of_observed = self._get_ancestors_of(observed) - ancestors_of_observed.update(observed) + ancestors_of_observed.update(observed) # include observed labels visit_list = set() visit_list.add((source, 'up')) diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py index 17e98fa..1a9ab19 100644 --- a/beliefs/models/belief_update_node_model.py +++ b/beliefs/models/belief_update_node_model.py @@ -33,9 +33,9 @@ class BeliefUpdateNodeModel(BayesianModel): """ def __init__(self, nodes_dict): """ - Input: - nodes_dict: dict - a dict key, value pair as {label_id: instance_of_node_class_or_subclass} + Args + nodes_dict: dict + a dict key, value pair as {label_id: instance_of_node_class_or_subclass} """ super().__init__(edges=self._get_edges_from_nodes(nodes_dict.values()), variables=list(nodes_dict.keys()), @@ -45,12 +45,15 @@ class BeliefUpdateNodeModel(BayesianModel): @classmethod def init_from_edges(cls, edges, node_class): - """Create nodes from the same node class. + """ + Create model from edges where all nodes are a from the same node class. - Input: - edges: list of edge tuples of form ('parent', 'child') - node_class: the Node class or subclass from which to - create all the nodes from edges. + Args + edges: list, + list of edge tuples of form [('parent', 'child')] + node_class: Node class or subclass, + class from which to create all the nodes automatically from edges, + e.g. BernoulliAndNode or BernoulliOrNode """ nodes = set() g = nx.DiGraph(edges) @@ -68,10 +71,12 @@ class BeliefUpdateNodeModel(BayesianModel): """ Return list of all directed edges in nodes. - Args: - nodes: an iterable of objects of the Node class or subclass - Returns: - edges: list of edge tuples + Args + nodes: iterable, + iterable of objects of the Node class or subclass + Returns + edges: list, + list of edge tuples """ edges = set() for node in nodes: @@ -82,11 +87,13 @@ class BeliefUpdateNodeModel(BayesianModel): def set_boundary_conditions(self): """ - 1. Root nodes: if x is a node with no parents, set Pi(x) = prior - probability of x. + Set boundary conditions for nodes in the model. + + 1. Root nodes: if x is a node with no parents, set Pi(x) = prior + probability of x. - 2. Leaf nodes: if x is a node with no children, set Lambda(x) - to an (unnormalized) unit vector, of length the cardinality of x. + 2. Leaf nodes: if x is a node with no children, set Lambda(x) + to an (unnormalized) unit vector, of length the cardinality of x. """ for root in self.get_roots(): self.nodes_dict[root].update_pi_agg(self.nodes_dict[root].cpd.values) @@ -97,8 +104,11 @@ class BeliefUpdateNodeModel(BayesianModel): @property def all_nodes_are_fully_initialized(self): """ - Returns True if, for all nodes in the model, all lambda and pi - messages and lambda_agg and pi_agg are not None, else False. + Check if all nodes in the model are initialized, i.e. lambda and pi messages and + lambda_agg and pi_agg are not None for every node. + + Returns + bool, True if all nodes in the model are initialized, else False. """ for node in self.nodes_dict.values(): if not node.is_fully_initialized: @@ -106,27 +116,27 @@ class BeliefUpdateNodeModel(BayesianModel): return True def copy(self): - """ - Returns a copy of the model. - """ + """Return a copy of the model.""" copy_nodes = copy.deepcopy(self.nodes_dict) copy_model = self.__class__(nodes_dict=copy_nodes) return copy_model class Node: - """A node in a DAG with methods to compute the belief (marginal probability - of the node given evidence) and compute pi/lambda messages to/from its neighbors + """ + A node in a DAG with methods to compute the belief (marginal probability of + the node given evidence) and compute pi/lambda messages to/from its neighbors to update its belief. - Implemented from Pearl's belief propagation algorithm. + Implemented from Pearl's belief propagation algorithm for polytrees. """ def __init__(self, children, cpd): """ Args - children: list of strings - cpd: an instance of a conditional probability distribution, - e.g. BernoulliOrCPD or TabularCPD + children: list, + list of strings + cpd: an instance of TabularCPD or one of its subclasses, + e.g. BernoulliOrCPD or BernoulliAndCPD """ self.label_id = cpd.variable self.children = children @@ -134,15 +144,20 @@ class Node: self.cardinality = cpd.cardinality[0] self.cpd = cpd - # instances of DiscreteFactor with `values` an np.array of dimensions [1, cardinality] - self.pi_agg = self._init_aggregate_values() - self.lambda_agg = self._init_aggregate_values() + self.pi_agg = self._init_factor_for_variable() + self.lambda_agg = self._init_factor_for_variable() self.pi_received_msgs = self._init_pi_received_msgs(self.parents) - self.lambda_received_msgs = {child: self._init_aggregate_values() for child in children} + self.lambda_received_msgs = {child: self._init_factor_for_variable() for child in children} @property def belief(self): + """ + Calculate the marginal probability of the variable from its aggregate values. + + Returns + belief, an np.array of ndim 1 and shape (self.cardinality,) + """ if any(self.pi_agg.values) and any(self.lambda_agg.values): belief = (self.lambda_agg * self.pi_agg).values return self._normalize(belief) @@ -152,29 +167,50 @@ class Node: def _normalize(self, value): return value/value.sum() - def _init_aggregate_values(self): + def _init_factor_for_variable(self): + """ + Returns + instance of a DiscreteFactor, where DiscreteFactor.values is an np.array of + ndim 1 and shape (self.cardinality,) + """ return DiscreteFactor(variables=[self.cpd.variable], cardinality=[self.cardinality], values=None, state_names=None) def _init_pi_received_msgs(self, parents): + """ + Args + parents: list, + list of strings, parent ids of the node + Returns + msgs: dict, + a dict with key, value pair as {parent_id: instance of a DiscreteFactor}, + where DiscreteFactor.values is an np.array of ndim 1 and + shape (cardinality of parent_id,) + """ msgs = {} for k in parents: + if self.cpd.state_names is not None: + state_names = {k: self.cpd.state_names[k]} + else: + state_names = None + kth_cardinality = self.cpd.cardinality[self.cpd.variables.index(k)] msgs[k] = DiscreteFactor(variables=[k], cardinality=[kth_cardinality], values=None, - state_names=None) + state_names=state_names) return msgs def _return_msgs_received_for_msg_type(self, message_type): """ - Input: - message_type: MessageType enum - - Returns: - msg_values: list of DiscreteFactors containing message values (np.arrays) + Args + message_type: MessageType enum + Returns + msg_values: list, + list of DiscreteFactors with property `values` containing + the values of the messages (np.arrays) """ if message_type == MessageType.LAMBDA: msgs = [msg for msg in self.lambda_received_msgs.values()] @@ -188,11 +224,12 @@ class Node: Raise error if all messages have not been received. Called before calculating lambda_agg (pi_agg). - Input: - message_type: MessageType enum - - Returns: - msgs: list of DiscreteFactors containing message values (np.array) + Args + message_type: MessageType enum + Returns + msgs: list, + list of DiscreteFactors with property `values` containing + the values of the messages (np.arrays) """ msgs = self._return_msgs_received_for_msg_type(message_type) @@ -205,6 +242,10 @@ class Node: return msgs def compute_pi_agg(self): + """ + Compute and update pi_agg, the prior probability, given the current state + of messages received from parents. + """ if len(self.parents) == 0: self.update_pi_agg(self.cpd.values) else: @@ -217,6 +258,10 @@ class Node: pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) def compute_lambda_agg(self): + """ + Compute and update lambda_agg, the likelihood, given the current state + of messages received from children. + """ if len(self.children) != 0: lambda_msg_values = [ msg.values for msg in @@ -245,9 +290,8 @@ class Node: expected_shape = (self.cpd.cardinality[self.cpd.variables.index(key)],) if new_value.shape != expected_shape: - raise ValueError("Expected new value to be of dimensions ({},) but got {} instead" - .format(expected_shape, new_value.shape)) - # received_msg_dict[key]._values = new_value + raise ValueError("Expected new value to be of dimensions ({},) but got {} instead" + .format(expected_shape, new_value.shape)) received_msg_dict[key].update_values(new_value) def update_pi_msg_from_parent(self, parent, new_value): @@ -263,6 +307,15 @@ class Node: message_type=MessageType.LAMBDA) def compute_pi_msg_to_child(self, child_k): + """ + Compute pi_msg to child. + + Args + child_k: string or int, + the label_id of the child receiving the pi_msg + Returns + np.array of ndim 1 and shape (self.cardinality,) + """ lambda_msg_from_child = self.lambda_received_msgs[child_k].values if lambda_msg_from_child is not None: with np.errstate(divide='ignore', invalid='ignore'): @@ -273,6 +326,15 @@ class Node: raise ValueError("Can't compute pi message to child_{} without having received a lambda message from that child.") def compute_lambda_msg_to_parent(self, parent_k): + """ + Compute lambda_msg to parent. + + Args + parent_k: string or int, + the label_id of the parent receiving the lambda_msg + Returns + np.array of ndim 1 and shape (cardinality of parent_k,) + """ if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])): return np.ones([self.cardinality]) else: @@ -306,30 +368,31 @@ class Node: class BernoulliOrNode(Node): - def __init__(self, - label_id, - children, - parents): + """ + A node in a DAG associated with a Bernoulli random variable with state_names ['False', 'True'] + and conditional probability distribution described by 'Or' logic. + """ + def __init__(self, label_id, children, parents): super().__init__(children=children, cpd=BernoulliOrCPD(label_id, parents)) - def _init_aggregate_values(self): + def _init_factor_for_variable(self): + """ + Returns + instance of a DiscreteFactor, where DiscreteFactor.values is an np.array of + ndim 1 and shape (self.cardinality,) + """ variable = self.cpd.variable return DiscreteFactor(variables=[self.cpd.variable], cardinality=[self.cardinality], values=None, state_names={variable: self.cpd.state_names[variable]}) - def _init_pi_received_msgs(self, parents): - msgs = {} - for k in parents: - kth_cardinality = self.cpd.cardinality[self.cpd.variables.index(k)] - msgs[k] = DiscreteFactor(variables=[k], - cardinality=[kth_cardinality], - values=None, - state_names={k: self.cpd.state_names[k]}) - return msgs - def compute_pi_agg(self): + """ + Compute and update pi_agg, the prior probability, given the current state + of messages received from parents. Sidestep explicit factor product and + marginalization. + """ if len(self.parents) == 0: self.update_pi_agg(self.cpd.values) else: @@ -339,9 +402,18 @@ class BernoulliOrNode(Node): p_0 = reduce(lambda x, y: x*y, parents_p0) p_1 = 1 - p_0 self.update_pi_agg(np.array([p_0, p_1])) - return self.pi_agg def compute_lambda_msg_to_parent(self, parent_k): + """ + Compute lambda_msg to parent. Sidestep explicit factor product and + marginalization. + + Args + parent_k: string or int, + the label_id of the parent receiving the lambda_msg + Returns + np.array of ndim 1 and shape (cardinality of parent_k,) + """ if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])): return np.ones([self.cardinality]) else: @@ -362,30 +434,31 @@ class BernoulliOrNode(Node): class BernoulliAndNode(Node): - def __init__(self, - label_id, - children, - parents): + """ + A node in a DAG associated with a Bernoulli random variable with state_names ['False', 'True'] + and conditional probability distribution described by 'And' logic. + """ + def __init__(self, label_id, children, parents): super().__init__(children=children, cpd=BernoulliAndCPD(label_id, parents)) - def _init_aggregate_values(self): + def _init_factor_for_variable(self): + """ + Returns + instance of a DiscreteFactor, where DiscreteFactor.values is an np.array of + ndim 1 and shape (self.cardinality,) + """ variable = self.cpd.variable return DiscreteFactor(variables=[self.cpd.variable], cardinality=[self.cardinality], values=None, state_names={variable: self.cpd.state_names[variable]}) - def _init_pi_received_msgs(self, parents): - msgs = {} - for k in parents: - kth_cardinality = self.cpd.cardinality[self.cpd.variables.index(k)] - msgs[k] = DiscreteFactor(variables=[k], - cardinality=[kth_cardinality], - values=None, - state_names={k: self.cpd.state_names[k]}) - return msgs - def compute_pi_agg(self): + """ + Compute and update pi_agg, the prior probability, given the current state + of messages received from parents. Sidestep explicit factor product and + marginalization. + """ if len(self.parents) == 0: self.update_pi_agg(self.cpd.values) else: @@ -395,9 +468,18 @@ class BernoulliAndNode(Node): p_1 = reduce(lambda x, y: x*y, parents_p1) p_0 = 1 - p_1 self.update_pi_agg(np.array([p_0, p_1])) - return self.pi_agg def compute_lambda_msg_to_parent(self, parent_k): + """ + Compute lambda_msg to parent. Sidestep explicit factor product and + marginalization. + + Args + parent_k: string or int, + the label_id of the parent receiving the lambda_msg + Returns + np.array of ndim 1 and shape (cardinality of parent_k,) + """ if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])): return np.ones([self.cardinality]) else: diff --git a/beliefs/utils/math_helper.py b/beliefs/utils/math_helper.py index a25ea68..12325e1 100644 --- a/beliefs/utils/math_helper.py +++ b/beliefs/utils/math_helper.py @@ -1,10 +1,16 @@ -"""Random math utils.""" +"""Math utils""" def is_kronecker_delta(vector): - """Returns True if vector is a kronecker delta vector, False otherwise. - Specific evidence ('YES' or 'NO') is a kronecker delta vector, whereas - virtual evidence ('MAYBE') is not. + """ + Check if vector is a kronecker delta. + + Args: + vector: iterable of numbers + Returns: + bool, True if vector is a kronecker delta vector, False otherwise. + In belief propagation, specific evidence (variable is directly observed) + is a kronecker delta vector, but virtual evidence is not. """ count = 0 for x in vector: diff --git a/beliefs/utils/random_variables.py b/beliefs/utils/random_variables.py index 1a0b0f7..cad07aa 100644 --- a/beliefs/utils/random_variables.py +++ b/beliefs/utils/random_variables.py @@ -1,3 +1,4 @@ +"""Utilities for working with models and random variables.""" def get_reachable_observed_variables_for_inferred_variables(model, observed=set()): @@ -6,12 +7,16 @@ def get_reachable_observed_variables_for_inferred_variables(model, observed=set( ("reachable observed variables") that influenced the beliefs of variables inferred to be in a definite state. - INPUT - model: instance of BayesianModel class or subclass - observed: set of labels (strings) corresponding to vars pinned to definite - state during inference. - RETURNS - dict, of form key - source label (a string), value - a list of strings + Args + model: instance of BayesianModel class or subclass + observed: set, + set of labels (strings) corresponding to variables pinned to a definite + state during inference. + Returns + dict, + key, value pairs {source_label_id: reachable_observed_vars}, where + source_label_id is an int or string, and reachable_observed_vars is a list + of label_ids """ if not observed: return {} -- cgit v1.2.3 From d92ed9f14baead60fdd6c1d823345cc3ddd1bc04 Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Wed, 17 Jan 2018 15:04:57 -0800 Subject: consolidate Node methods _init_factor_for_variable, _init_pi_received_msgs, into a single method _init_factors_for_variables overrides of _init_factor_for_variable for BernoulliOr/AndNode not needed --- beliefs/models/belief_update_node_model.py | 79 ++++++++++-------------------- 1 file changed, 25 insertions(+), 54 deletions(-) (limited to 'beliefs') diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py index 1a9ab19..743bbcb 100644 --- a/beliefs/models/belief_update_node_model.py +++ b/beliefs/models/belief_update_node_model.py @@ -144,11 +144,14 @@ class Node: self.cardinality = cpd.cardinality[0] self.cpd = cpd - self.pi_agg = self._init_factor_for_variable() - self.lambda_agg = self._init_factor_for_variable() + self.pi_agg = self._init_factors_for_variables([self.label_id])[self.label_id] + self.lambda_agg = self._init_factors_for_variables([self.label_id])[self.label_id] + + self.pi_received_msgs = self._init_factors_for_variables(self.parents) + self.lambda_received_msgs = \ + {child: self._init_factors_for_variables([self.label_id])[self.label_id] + for child in children} - self.pi_received_msgs = self._init_pi_received_msgs(self.parents) - self.lambda_received_msgs = {child: self._init_factor_for_variable() for child in children} @property def belief(self): @@ -167,41 +170,33 @@ class Node: def _normalize(self, value): return value/value.sum() - def _init_factor_for_variable(self): - """ - Returns - instance of a DiscreteFactor, where DiscreteFactor.values is an np.array of - ndim 1 and shape (self.cardinality,) - """ - return DiscreteFactor(variables=[self.cpd.variable], - cardinality=[self.cardinality], - values=None, - state_names=None) - - def _init_pi_received_msgs(self, parents): + def _init_factors_for_variables(self, variables): """ Args - parents: list, - list of strings, parent ids of the node + variables: list, + list of ints/strings, e.g. the single node variable or list + of parent ids of the node Returns - msgs: dict, - a dict with key, value pair as {parent_id: instance of a DiscreteFactor}, + factors: dict, + where the dict has key, value pair as {variable_id: instance of a DiscreteFactor}, where DiscreteFactor.values is an np.array of ndim 1 and - shape (cardinality of parent_id,) + shape (cardinality of variable_id,) """ - msgs = {} - for k in parents: + variables = list(variables) + factors = {} + + for var in variables: if self.cpd.state_names is not None: - state_names = {k: self.cpd.state_names[k]} + state_names = {var: self.cpd.state_names[var]} else: state_names = None - kth_cardinality = self.cpd.cardinality[self.cpd.variables.index(k)] - msgs[k] = DiscreteFactor(variables=[k], - cardinality=[kth_cardinality], - values=None, - state_names=state_names) - return msgs + cardinality = self.cpd.cardinality[self.cpd.variables.index(var)] + factors[var] = DiscreteFactor(variables=[var], + cardinality=[cardinality], + values=None, + state_names=state_names) + return factors def _return_msgs_received_for_msg_type(self, message_type): """ @@ -375,18 +370,6 @@ class BernoulliOrNode(Node): def __init__(self, label_id, children, parents): super().__init__(children=children, cpd=BernoulliOrCPD(label_id, parents)) - def _init_factor_for_variable(self): - """ - Returns - instance of a DiscreteFactor, where DiscreteFactor.values is an np.array of - ndim 1 and shape (self.cardinality,) - """ - variable = self.cpd.variable - return DiscreteFactor(variables=[self.cpd.variable], - cardinality=[self.cardinality], - values=None, - state_names={variable: self.cpd.state_names[variable]}) - def compute_pi_agg(self): """ Compute and update pi_agg, the prior probability, given the current state @@ -441,18 +424,6 @@ class BernoulliAndNode(Node): def __init__(self, label_id, children, parents): super().__init__(children=children, cpd=BernoulliAndCPD(label_id, parents)) - def _init_factor_for_variable(self): - """ - Returns - instance of a DiscreteFactor, where DiscreteFactor.values is an np.array of - ndim 1 and shape (self.cardinality,) - """ - variable = self.cpd.variable - return DiscreteFactor(variables=[self.cpd.variable], - cardinality=[self.cardinality], - values=None, - state_names={variable: self.cpd.state_names[variable]}) - def compute_pi_agg(self): """ Compute and update pi_agg, the prior probability, given the current state -- cgit v1.2.3 From 1a9286b0c1698fe5329a8e0b2a886f0a98286d2b Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Wed, 17 Jan 2018 17:32:07 -0800 Subject: compute_pi_agg -> compute_and_update_pi_agg, compute_lambda_agg -> compute_and_update_lambda_agg --- beliefs/inference/belief_propagation.py | 6 +++--- beliefs/models/belief_update_node_model.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) (limited to 'beliefs') diff --git a/beliefs/inference/belief_propagation.py b/beliefs/inference/belief_propagation.py index acd93d4..e6e7b18 100644 --- a/beliefs/inference/belief_propagation.py +++ b/beliefs/inference/belief_propagation.py @@ -72,9 +72,9 @@ class BeliefPropagation: assert len(node.parents) + len(node.children) - 1 == len(parent_ids) + len(child_ids) if node_to_update_label_id not in evidence: - node.compute_pi_agg() + node.compute_and_update_pi_agg() logging.debug("belief propagation pi_agg: %s", np.array2string(node.pi_agg.values)) - node.compute_lambda_agg() + node.compute_and_update_lambda_agg() logging.debug("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg.values)) for parent_id in parent_ids: @@ -130,7 +130,7 @@ class BeliefPropagation: child_ids = node_sending_msg.children if node_sending_msg.pi_agg.values is None: - node_sending_msg.compute_pi_agg() + node_sending_msg.compute_and_update_pi_agg() for child_id in child_ids: logging.debug("child: %s", str(child_id)) diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py index 743bbcb..ec329ca 100644 --- a/beliefs/models/belief_update_node_model.py +++ b/beliefs/models/belief_update_node_model.py @@ -236,7 +236,7 @@ class Node: else: return msgs - def compute_pi_agg(self): + def compute_and_update_pi_agg(self): """ Compute and update pi_agg, the prior probability, given the current state of messages received from parents. @@ -252,7 +252,7 @@ class Node: self.update_pi_agg(factor_product.marginalize(self.parents).values) pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) - def compute_lambda_agg(self): + def compute_and_update_lambda_agg(self): """ Compute and update lambda_agg, the likelihood, given the current state of messages received from children. @@ -370,7 +370,7 @@ class BernoulliOrNode(Node): def __init__(self, label_id, children, parents): super().__init__(children=children, cpd=BernoulliOrCPD(label_id, parents)) - def compute_pi_agg(self): + def compute_and_update_pi_agg(self): """ Compute and update pi_agg, the prior probability, given the current state of messages received from parents. Sidestep explicit factor product and @@ -424,7 +424,7 @@ class BernoulliAndNode(Node): def __init__(self, label_id, children, parents): super().__init__(children=children, cpd=BernoulliAndCPD(label_id, parents)) - def compute_pi_agg(self): + def compute_and_update_pi_agg(self): """ Compute and update pi_agg, the prior probability, given the current state of messages received from parents. Sidestep explicit factor product and -- cgit v1.2.3