aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCathy Yeh <cathy@driver.xyz>2017-12-12 19:06:30 -0800
committerCathy Yeh <cathy@driver.xyz>2017-12-13 18:45:03 -0800
commit76090e3f03c01e208d41203a6286ea432714090a (patch)
treef039aa53429fb74ac39217f181231eb372098c9d
parent70bdf07d25f41de1a9510b64267bfa29791760c7 (diff)
downloadbeliefs-76090e3f03c01e208d41203a6286ea432714090a.tar.gz
beliefs-76090e3f03c01e208d41203a6286ea432714090a.tar.bz2
beliefs-76090e3f03c01e208d41203a6286ea432714090a.zip
clean up node class, simpler initialization
-rw-r--r--beliefs/models/belief_update_node_model.py42
1 files changed, 8 insertions, 34 deletions
diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py
index 1765ed9..820ee0c 100644
--- a/beliefs/models/belief_update_node_model.py
+++ b/beliefs/models/belief_update_node_model.py
@@ -121,44 +121,26 @@ class Node:
Implemented from Pearl's belief propagation algorithm.
"""
- def __init__(self,
- label_id,
- children,
- parents,
- cardinality,
- cpd):
+ def __init__(self, children, cpd):
"""
Args
- label_id: str
- children: set of strings
- parents: set of strings
- cardinality: int, cardinality of the random variable the node represents
+ children: list of strings
cpd: an instance of a conditional probability distribution,
e.g. BernoulliOrCPD or TabularCPD
"""
- self.label_id = label_id # this can be obtained from cpd.variable
+ self.label_id = cpd.variable
self.children = children
- self.parents = parents # this can be obtained from cpd.variables[1:]
- self.cardinality = cardinality # this can be obtained from cpd.cardinality[0]
+ self.parents = cpd.parents
+ self.cardinality = cpd.cardinality[0]
self.cpd = cpd
# instances of DiscreteFactor with `values` an np.array of dimensions [1, cardinality]
self.pi_agg = self._init_aggregate_values()
self.lambda_agg = self._init_aggregate_values()
- self.pi_received_msgs = self._init_pi_received_msgs(parents)
+ self.pi_received_msgs = self._init_pi_received_msgs(self.parents)
self.lambda_received_msgs = {child: self._init_aggregate_values() for child in children}
- @classmethod
- def from_cpd_class(cls,
- label_id,
- children,
- parents,
- cardinality,
- cpd_class):
- cpd = cpd_class(label_id, parents)
- return cls(label_id, children, parents, cardinality, cpd)
-
@property
def belief(self):
if any(self.pi_agg.values) and any(self.lambda_agg.values):
@@ -311,11 +293,7 @@ class BernoulliOrNode(Node):
label_id,
children,
parents):
- super().__init__(label_id=label_id,
- children=children,
- parents=parents,
- cardinality=2,
- cpd=BernoulliOrCPD(label_id, parents))
+ super().__init__(children=children, cpd=BernoulliOrCPD(label_id, parents))
def compute_pi_agg(self):
if len(self.parents) == 0:
@@ -351,11 +329,7 @@ class BernoulliAndNode(Node):
label_id,
children,
parents):
- super().__init__(label_id=label_id,
- children=children,
- parents=parents,
- cardinality=2,
- cpd=BernoulliAndCPD(label_id, parents))
+ super().__init__(children=children, cpd=BernoulliAndCPD(label_id, parents))
def compute_pi_agg(self):
if len(self.parents) == 0: