aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCathy Yeh <cathy@driver.xyz>2018-01-17 15:04:57 -0800
committerCathy Yeh <cathy@driver.xyz>2018-01-17 17:26:38 -0800
commitd92ed9f14baead60fdd6c1d823345cc3ddd1bc04 (patch)
treee171d09610cab395f4eab2354e0c5c09b5d58f25
parent7053fefc6f9e43b1e252d1f551401a7a70b52e93 (diff)
downloadbeliefs-d92ed9f14baead60fdd6c1d823345cc3ddd1bc04.tar.gz
beliefs-d92ed9f14baead60fdd6c1d823345cc3ddd1bc04.tar.bz2
beliefs-d92ed9f14baead60fdd6c1d823345cc3ddd1bc04.zip
consolidate Node methods _init_factor_for_variable, _init_pi_received_msgs, into a single method _init_factors_for_variables
overrides of _init_factor_for_variable for BernoulliOr/AndNode not needed
-rw-r--r--beliefs/models/belief_update_node_model.py79
1 files changed, 25 insertions, 54 deletions
diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py
index 1a9ab19..743bbcb 100644
--- a/beliefs/models/belief_update_node_model.py
+++ b/beliefs/models/belief_update_node_model.py
@@ -144,11 +144,14 @@ class Node:
self.cardinality = cpd.cardinality[0]
self.cpd = cpd
- self.pi_agg = self._init_factor_for_variable()
- self.lambda_agg = self._init_factor_for_variable()
+ self.pi_agg = self._init_factors_for_variables([self.label_id])[self.label_id]
+ self.lambda_agg = self._init_factors_for_variables([self.label_id])[self.label_id]
+
+ self.pi_received_msgs = self._init_factors_for_variables(self.parents)
+ self.lambda_received_msgs = \
+ {child: self._init_factors_for_variables([self.label_id])[self.label_id]
+ for child in children}
- self.pi_received_msgs = self._init_pi_received_msgs(self.parents)
- self.lambda_received_msgs = {child: self._init_factor_for_variable() for child in children}
@property
def belief(self):
@@ -167,41 +170,33 @@ class Node:
def _normalize(self, value):
return value/value.sum()
- def _init_factor_for_variable(self):
- """
- Returns
- instance of a DiscreteFactor, where DiscreteFactor.values is an np.array of
- ndim 1 and shape (self.cardinality,)
- """
- return DiscreteFactor(variables=[self.cpd.variable],
- cardinality=[self.cardinality],
- values=None,
- state_names=None)
-
- def _init_pi_received_msgs(self, parents):
+ def _init_factors_for_variables(self, variables):
"""
Args
- parents: list,
- list of strings, parent ids of the node
+ variables: list,
+ list of ints/strings, e.g. the single node variable or list
+ of parent ids of the node
Returns
- msgs: dict,
- a dict with key, value pair as {parent_id: instance of a DiscreteFactor},
+ factors: dict,
+ where the dict has key, value pair as {variable_id: instance of a DiscreteFactor},
where DiscreteFactor.values is an np.array of ndim 1 and
- shape (cardinality of parent_id,)
+ shape (cardinality of variable_id,)
"""
- msgs = {}
- for k in parents:
+ variables = list(variables)
+ factors = {}
+
+ for var in variables:
if self.cpd.state_names is not None:
- state_names = {k: self.cpd.state_names[k]}
+ state_names = {var: self.cpd.state_names[var]}
else:
state_names = None
- kth_cardinality = self.cpd.cardinality[self.cpd.variables.index(k)]
- msgs[k] = DiscreteFactor(variables=[k],
- cardinality=[kth_cardinality],
- values=None,
- state_names=state_names)
- return msgs
+ cardinality = self.cpd.cardinality[self.cpd.variables.index(var)]
+ factors[var] = DiscreteFactor(variables=[var],
+ cardinality=[cardinality],
+ values=None,
+ state_names=state_names)
+ return factors
def _return_msgs_received_for_msg_type(self, message_type):
"""
@@ -375,18 +370,6 @@ class BernoulliOrNode(Node):
def __init__(self, label_id, children, parents):
super().__init__(children=children, cpd=BernoulliOrCPD(label_id, parents))
- def _init_factor_for_variable(self):
- """
- Returns
- instance of a DiscreteFactor, where DiscreteFactor.values is an np.array of
- ndim 1 and shape (self.cardinality,)
- """
- variable = self.cpd.variable
- return DiscreteFactor(variables=[self.cpd.variable],
- cardinality=[self.cardinality],
- values=None,
- state_names={variable: self.cpd.state_names[variable]})
-
def compute_pi_agg(self):
"""
Compute and update pi_agg, the prior probability, given the current state
@@ -441,18 +424,6 @@ class BernoulliAndNode(Node):
def __init__(self, label_id, children, parents):
super().__init__(children=children, cpd=BernoulliAndCPD(label_id, parents))
- def _init_factor_for_variable(self):
- """
- Returns
- instance of a DiscreteFactor, where DiscreteFactor.values is an np.array of
- ndim 1 and shape (self.cardinality,)
- """
- variable = self.cpd.variable
- return DiscreteFactor(variables=[self.cpd.variable],
- cardinality=[self.cardinality],
- values=None,
- state_names={variable: self.cpd.state_names[variable]})
-
def compute_pi_agg(self):
"""
Compute and update pi_agg, the prior probability, given the current state