From d92ed9f14baead60fdd6c1d823345cc3ddd1bc04 Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Wed, 17 Jan 2018 15:04:57 -0800 Subject: consolidate Node methods _init_factor_for_variable, _init_pi_received_msgs, into a single method _init_factors_for_variables overrides of _init_factor_for_variable for BernoulliOr/AndNode not needed --- beliefs/models/belief_update_node_model.py | 79 ++++++++++-------------------- 1 file changed, 25 insertions(+), 54 deletions(-) diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py index 1a9ab19..743bbcb 100644 --- a/beliefs/models/belief_update_node_model.py +++ b/beliefs/models/belief_update_node_model.py @@ -144,11 +144,14 @@ class Node: self.cardinality = cpd.cardinality[0] self.cpd = cpd - self.pi_agg = self._init_factor_for_variable() - self.lambda_agg = self._init_factor_for_variable() + self.pi_agg = self._init_factors_for_variables([self.label_id])[self.label_id] + self.lambda_agg = self._init_factors_for_variables([self.label_id])[self.label_id] + + self.pi_received_msgs = self._init_factors_for_variables(self.parents) + self.lambda_received_msgs = \ + {child: self._init_factors_for_variables([self.label_id])[self.label_id] + for child in children} - self.pi_received_msgs = self._init_pi_received_msgs(self.parents) - self.lambda_received_msgs = {child: self._init_factor_for_variable() for child in children} @property def belief(self): @@ -167,41 +170,33 @@ class Node: def _normalize(self, value): return value/value.sum() - def _init_factor_for_variable(self): - """ - Returns - instance of a DiscreteFactor, where DiscreteFactor.values is an np.array of - ndim 1 and shape (self.cardinality,) - """ - return DiscreteFactor(variables=[self.cpd.variable], - cardinality=[self.cardinality], - values=None, - state_names=None) - - def _init_pi_received_msgs(self, parents): + def _init_factors_for_variables(self, variables): """ Args - parents: list, - list of strings, parent ids of the node + variables: list, + list of ints/strings, e.g. the single node variable or list + of parent ids of the node Returns - msgs: dict, - a dict with key, value pair as {parent_id: instance of a DiscreteFactor}, + factors: dict, + where the dict has key, value pair as {variable_id: instance of a DiscreteFactor}, where DiscreteFactor.values is an np.array of ndim 1 and - shape (cardinality of parent_id,) + shape (cardinality of variable_id,) """ - msgs = {} - for k in parents: + variables = list(variables) + factors = {} + + for var in variables: if self.cpd.state_names is not None: - state_names = {k: self.cpd.state_names[k]} + state_names = {var: self.cpd.state_names[var]} else: state_names = None - kth_cardinality = self.cpd.cardinality[self.cpd.variables.index(k)] - msgs[k] = DiscreteFactor(variables=[k], - cardinality=[kth_cardinality], - values=None, - state_names=state_names) - return msgs + cardinality = self.cpd.cardinality[self.cpd.variables.index(var)] + factors[var] = DiscreteFactor(variables=[var], + cardinality=[cardinality], + values=None, + state_names=state_names) + return factors def _return_msgs_received_for_msg_type(self, message_type): """ @@ -375,18 +370,6 @@ class BernoulliOrNode(Node): def __init__(self, label_id, children, parents): super().__init__(children=children, cpd=BernoulliOrCPD(label_id, parents)) - def _init_factor_for_variable(self): - """ - Returns - instance of a DiscreteFactor, where DiscreteFactor.values is an np.array of - ndim 1 and shape (self.cardinality,) - """ - variable = self.cpd.variable - return DiscreteFactor(variables=[self.cpd.variable], - cardinality=[self.cardinality], - values=None, - state_names={variable: self.cpd.state_names[variable]}) - def compute_pi_agg(self): """ Compute and update pi_agg, the prior probability, given the current state @@ -441,18 +424,6 @@ class BernoulliAndNode(Node): def __init__(self, label_id, children, parents): super().__init__(children=children, cpd=BernoulliAndCPD(label_id, parents)) - def _init_factor_for_variable(self): - """ - Returns - instance of a DiscreteFactor, where DiscreteFactor.values is an np.array of - ndim 1 and shape (self.cardinality,) - """ - variable = self.cpd.variable - return DiscreteFactor(variables=[self.cpd.variable], - cardinality=[self.cardinality], - values=None, - state_names={variable: self.cpd.state_names[variable]}) - def compute_pi_agg(self): """ Compute and update pi_agg, the prior probability, given the current state -- cgit v1.2.3