From 76090e3f03c01e208d41203a6286ea432714090a Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Tue, 12 Dec 2017 19:06:30 -0800 Subject: clean up node class, simpler initialization --- beliefs/models/belief_update_node_model.py | 42 ++++++------------------------ 1 file changed, 8 insertions(+), 34 deletions(-) diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py index 1765ed9..820ee0c 100644 --- a/beliefs/models/belief_update_node_model.py +++ b/beliefs/models/belief_update_node_model.py @@ -121,44 +121,26 @@ class Node: Implemented from Pearl's belief propagation algorithm. """ - def __init__(self, - label_id, - children, - parents, - cardinality, - cpd): + def __init__(self, children, cpd): """ Args - label_id: str - children: set of strings - parents: set of strings - cardinality: int, cardinality of the random variable the node represents + children: list of strings cpd: an instance of a conditional probability distribution, e.g. BernoulliOrCPD or TabularCPD """ - self.label_id = label_id # this can be obtained from cpd.variable + self.label_id = cpd.variable self.children = children - self.parents = parents # this can be obtained from cpd.variables[1:] - self.cardinality = cardinality # this can be obtained from cpd.cardinality[0] + self.parents = cpd.parents + self.cardinality = cpd.cardinality[0] self.cpd = cpd # instances of DiscreteFactor with `values` an np.array of dimensions [1, cardinality] self.pi_agg = self._init_aggregate_values() self.lambda_agg = self._init_aggregate_values() - self.pi_received_msgs = self._init_pi_received_msgs(parents) + self.pi_received_msgs = self._init_pi_received_msgs(self.parents) self.lambda_received_msgs = {child: self._init_aggregate_values() for child in children} - @classmethod - def from_cpd_class(cls, - label_id, - children, - parents, - cardinality, - cpd_class): - cpd = cpd_class(label_id, parents) - return cls(label_id, children, parents, cardinality, cpd) - @property def belief(self): if any(self.pi_agg.values) and any(self.lambda_agg.values): @@ -311,11 +293,7 @@ class BernoulliOrNode(Node): label_id, children, parents): - super().__init__(label_id=label_id, - children=children, - parents=parents, - cardinality=2, - cpd=BernoulliOrCPD(label_id, parents)) + super().__init__(children=children, cpd=BernoulliOrCPD(label_id, parents)) def compute_pi_agg(self): if len(self.parents) == 0: @@ -351,11 +329,7 @@ class BernoulliAndNode(Node): label_id, children, parents): - super().__init__(label_id=label_id, - children=children, - parents=parents, - cardinality=2, - cpd=BernoulliAndCPD(label_id, parents)) + super().__init__(children=children, cpd=BernoulliAndCPD(label_id, parents)) def compute_pi_agg(self): if len(self.parents) == 0: -- cgit v1.2.3