aboutsummaryrefslogtreecommitdiff
path: root/beliefs/models/belief_update_node_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'beliefs/models/belief_update_node_model.py')
-rw-r--r--beliefs/models/belief_update_node_model.py238
1 files changed, 160 insertions, 78 deletions
diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py
index 17e98fa..1a9ab19 100644
--- a/beliefs/models/belief_update_node_model.py
+++ b/beliefs/models/belief_update_node_model.py
@@ -33,9 +33,9 @@ class BeliefUpdateNodeModel(BayesianModel):
"""
def __init__(self, nodes_dict):
"""
- Input:
- nodes_dict: dict
- a dict key, value pair as {label_id: instance_of_node_class_or_subclass}
+ Args
+ nodes_dict: dict
+ a dict key, value pair as {label_id: instance_of_node_class_or_subclass}
"""
super().__init__(edges=self._get_edges_from_nodes(nodes_dict.values()),
variables=list(nodes_dict.keys()),
@@ -45,12 +45,15 @@ class BeliefUpdateNodeModel(BayesianModel):
@classmethod
def init_from_edges(cls, edges, node_class):
- """Create nodes from the same node class.
+ """
+ Create model from edges where all nodes are a from the same node class.
- Input:
- edges: list of edge tuples of form ('parent', 'child')
- node_class: the Node class or subclass from which to
- create all the nodes from edges.
+ Args
+ edges: list,
+ list of edge tuples of form [('parent', 'child')]
+ node_class: Node class or subclass,
+ class from which to create all the nodes automatically from edges,
+ e.g. BernoulliAndNode or BernoulliOrNode
"""
nodes = set()
g = nx.DiGraph(edges)
@@ -68,10 +71,12 @@ class BeliefUpdateNodeModel(BayesianModel):
"""
Return list of all directed edges in nodes.
- Args:
- nodes: an iterable of objects of the Node class or subclass
- Returns:
- edges: list of edge tuples
+ Args
+ nodes: iterable,
+ iterable of objects of the Node class or subclass
+ Returns
+ edges: list,
+ list of edge tuples
"""
edges = set()
for node in nodes:
@@ -82,11 +87,13 @@ class BeliefUpdateNodeModel(BayesianModel):
def set_boundary_conditions(self):
"""
- 1. Root nodes: if x is a node with no parents, set Pi(x) = prior
- probability of x.
+ Set boundary conditions for nodes in the model.
+
+ 1. Root nodes: if x is a node with no parents, set Pi(x) = prior
+ probability of x.
- 2. Leaf nodes: if x is a node with no children, set Lambda(x)
- to an (unnormalized) unit vector, of length the cardinality of x.
+ 2. Leaf nodes: if x is a node with no children, set Lambda(x)
+ to an (unnormalized) unit vector, of length the cardinality of x.
"""
for root in self.get_roots():
self.nodes_dict[root].update_pi_agg(self.nodes_dict[root].cpd.values)
@@ -97,8 +104,11 @@ class BeliefUpdateNodeModel(BayesianModel):
@property
def all_nodes_are_fully_initialized(self):
"""
- Returns True if, for all nodes in the model, all lambda and pi
- messages and lambda_agg and pi_agg are not None, else False.
+ Check if all nodes in the model are initialized, i.e. lambda and pi messages and
+ lambda_agg and pi_agg are not None for every node.
+
+ Returns
+ bool, True if all nodes in the model are initialized, else False.
"""
for node in self.nodes_dict.values():
if not node.is_fully_initialized:
@@ -106,27 +116,27 @@ class BeliefUpdateNodeModel(BayesianModel):
return True
def copy(self):
- """
- Returns a copy of the model.
- """
+ """Return a copy of the model."""
copy_nodes = copy.deepcopy(self.nodes_dict)
copy_model = self.__class__(nodes_dict=copy_nodes)
return copy_model
class Node:
- """A node in a DAG with methods to compute the belief (marginal probability
- of the node given evidence) and compute pi/lambda messages to/from its neighbors
+ """
+ A node in a DAG with methods to compute the belief (marginal probability of
+ the node given evidence) and compute pi/lambda messages to/from its neighbors
to update its belief.
- Implemented from Pearl's belief propagation algorithm.
+ Implemented from Pearl's belief propagation algorithm for polytrees.
"""
def __init__(self, children, cpd):
"""
Args
- children: list of strings
- cpd: an instance of a conditional probability distribution,
- e.g. BernoulliOrCPD or TabularCPD
+ children: list,
+ list of strings
+ cpd: an instance of TabularCPD or one of its subclasses,
+ e.g. BernoulliOrCPD or BernoulliAndCPD
"""
self.label_id = cpd.variable
self.children = children
@@ -134,15 +144,20 @@ class Node:
self.cardinality = cpd.cardinality[0]
self.cpd = cpd
- # instances of DiscreteFactor with `values` an np.array of dimensions [1, cardinality]
- self.pi_agg = self._init_aggregate_values()
- self.lambda_agg = self._init_aggregate_values()
+ self.pi_agg = self._init_factor_for_variable()
+ self.lambda_agg = self._init_factor_for_variable()
self.pi_received_msgs = self._init_pi_received_msgs(self.parents)
- self.lambda_received_msgs = {child: self._init_aggregate_values() for child in children}
+ self.lambda_received_msgs = {child: self._init_factor_for_variable() for child in children}
@property
def belief(self):
+ """
+ Calculate the marginal probability of the variable from its aggregate values.
+
+ Returns
+ belief, an np.array of ndim 1 and shape (self.cardinality,)
+ """
if any(self.pi_agg.values) and any(self.lambda_agg.values):
belief = (self.lambda_agg * self.pi_agg).values
return self._normalize(belief)
@@ -152,29 +167,50 @@ class Node:
def _normalize(self, value):
return value/value.sum()
- def _init_aggregate_values(self):
+ def _init_factor_for_variable(self):
+ """
+ Returns
+ instance of a DiscreteFactor, where DiscreteFactor.values is an np.array of
+ ndim 1 and shape (self.cardinality,)
+ """
return DiscreteFactor(variables=[self.cpd.variable],
cardinality=[self.cardinality],
values=None,
state_names=None)
def _init_pi_received_msgs(self, parents):
+ """
+ Args
+ parents: list,
+ list of strings, parent ids of the node
+ Returns
+ msgs: dict,
+ a dict with key, value pair as {parent_id: instance of a DiscreteFactor},
+ where DiscreteFactor.values is an np.array of ndim 1 and
+ shape (cardinality of parent_id,)
+ """
msgs = {}
for k in parents:
+ if self.cpd.state_names is not None:
+ state_names = {k: self.cpd.state_names[k]}
+ else:
+ state_names = None
+
kth_cardinality = self.cpd.cardinality[self.cpd.variables.index(k)]
msgs[k] = DiscreteFactor(variables=[k],
cardinality=[kth_cardinality],
values=None,
- state_names=None)
+ state_names=state_names)
return msgs
def _return_msgs_received_for_msg_type(self, message_type):
"""
- Input:
- message_type: MessageType enum
-
- Returns:
- msg_values: list of DiscreteFactors containing message values (np.arrays)
+ Args
+ message_type: MessageType enum
+ Returns
+ msg_values: list,
+ list of DiscreteFactors with property `values` containing
+ the values of the messages (np.arrays)
"""
if message_type == MessageType.LAMBDA:
msgs = [msg for msg in self.lambda_received_msgs.values()]
@@ -188,11 +224,12 @@ class Node:
Raise error if all messages have not been received. Called
before calculating lambda_agg (pi_agg).
- Input:
- message_type: MessageType enum
-
- Returns:
- msgs: list of DiscreteFactors containing message values (np.array)
+ Args
+ message_type: MessageType enum
+ Returns
+ msgs: list,
+ list of DiscreteFactors with property `values` containing
+ the values of the messages (np.arrays)
"""
msgs = self._return_msgs_received_for_msg_type(message_type)
@@ -205,6 +242,10 @@ class Node:
return msgs
def compute_pi_agg(self):
+ """
+ Compute and update pi_agg, the prior probability, given the current state
+ of messages received from parents.
+ """
if len(self.parents) == 0:
self.update_pi_agg(self.cpd.values)
else:
@@ -217,6 +258,10 @@ class Node:
pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
def compute_lambda_agg(self):
+ """
+ Compute and update lambda_agg, the likelihood, given the current state
+ of messages received from children.
+ """
if len(self.children) != 0:
lambda_msg_values = [
msg.values for msg in
@@ -245,9 +290,8 @@ class Node:
expected_shape = (self.cpd.cardinality[self.cpd.variables.index(key)],)
if new_value.shape != expected_shape:
- raise ValueError("Expected new value to be of dimensions ({},) but got {} instead"
- .format(expected_shape, new_value.shape))
- # received_msg_dict[key]._values = new_value
+ raise ValueError("Expected new value to be of dimensions ({},) but got {} instead"
+ .format(expected_shape, new_value.shape))
received_msg_dict[key].update_values(new_value)
def update_pi_msg_from_parent(self, parent, new_value):
@@ -263,6 +307,15 @@ class Node:
message_type=MessageType.LAMBDA)
def compute_pi_msg_to_child(self, child_k):
+ """
+ Compute pi_msg to child.
+
+ Args
+ child_k: string or int,
+ the label_id of the child receiving the pi_msg
+ Returns
+ np.array of ndim 1 and shape (self.cardinality,)
+ """
lambda_msg_from_child = self.lambda_received_msgs[child_k].values
if lambda_msg_from_child is not None:
with np.errstate(divide='ignore', invalid='ignore'):
@@ -273,6 +326,15 @@ class Node:
raise ValueError("Can't compute pi message to child_{} without having received a lambda message from that child.")
def compute_lambda_msg_to_parent(self, parent_k):
+ """
+ Compute lambda_msg to parent.
+
+ Args
+ parent_k: string or int,
+ the label_id of the parent receiving the lambda_msg
+ Returns
+ np.array of ndim 1 and shape (cardinality of parent_k,)
+ """
if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])):
return np.ones([self.cardinality])
else:
@@ -306,30 +368,31 @@ class Node:
class BernoulliOrNode(Node):
- def __init__(self,
- label_id,
- children,
- parents):
+ """
+ A node in a DAG associated with a Bernoulli random variable with state_names ['False', 'True']
+ and conditional probability distribution described by 'Or' logic.
+ """
+ def __init__(self, label_id, children, parents):
super().__init__(children=children, cpd=BernoulliOrCPD(label_id, parents))
- def _init_aggregate_values(self):
+ def _init_factor_for_variable(self):
+ """
+ Returns
+ instance of a DiscreteFactor, where DiscreteFactor.values is an np.array of
+ ndim 1 and shape (self.cardinality,)
+ """
variable = self.cpd.variable
return DiscreteFactor(variables=[self.cpd.variable],
cardinality=[self.cardinality],
values=None,
state_names={variable: self.cpd.state_names[variable]})
- def _init_pi_received_msgs(self, parents):
- msgs = {}
- for k in parents:
- kth_cardinality = self.cpd.cardinality[self.cpd.variables.index(k)]
- msgs[k] = DiscreteFactor(variables=[k],
- cardinality=[kth_cardinality],
- values=None,
- state_names={k: self.cpd.state_names[k]})
- return msgs
-
def compute_pi_agg(self):
+ """
+ Compute and update pi_agg, the prior probability, given the current state
+ of messages received from parents. Sidestep explicit factor product and
+ marginalization.
+ """
if len(self.parents) == 0:
self.update_pi_agg(self.cpd.values)
else:
@@ -339,9 +402,18 @@ class BernoulliOrNode(Node):
p_0 = reduce(lambda x, y: x*y, parents_p0)
p_1 = 1 - p_0
self.update_pi_agg(np.array([p_0, p_1]))
- return self.pi_agg
def compute_lambda_msg_to_parent(self, parent_k):
+ """
+ Compute lambda_msg to parent. Sidestep explicit factor product and
+ marginalization.
+
+ Args
+ parent_k: string or int,
+ the label_id of the parent receiving the lambda_msg
+ Returns
+ np.array of ndim 1 and shape (cardinality of parent_k,)
+ """
if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])):
return np.ones([self.cardinality])
else:
@@ -362,30 +434,31 @@ class BernoulliOrNode(Node):
class BernoulliAndNode(Node):
- def __init__(self,
- label_id,
- children,
- parents):
+ """
+ A node in a DAG associated with a Bernoulli random variable with state_names ['False', 'True']
+ and conditional probability distribution described by 'And' logic.
+ """
+ def __init__(self, label_id, children, parents):
super().__init__(children=children, cpd=BernoulliAndCPD(label_id, parents))
- def _init_aggregate_values(self):
+ def _init_factor_for_variable(self):
+ """
+ Returns
+ instance of a DiscreteFactor, where DiscreteFactor.values is an np.array of
+ ndim 1 and shape (self.cardinality,)
+ """
variable = self.cpd.variable
return DiscreteFactor(variables=[self.cpd.variable],
cardinality=[self.cardinality],
values=None,
state_names={variable: self.cpd.state_names[variable]})
- def _init_pi_received_msgs(self, parents):
- msgs = {}
- for k in parents:
- kth_cardinality = self.cpd.cardinality[self.cpd.variables.index(k)]
- msgs[k] = DiscreteFactor(variables=[k],
- cardinality=[kth_cardinality],
- values=None,
- state_names={k: self.cpd.state_names[k]})
- return msgs
-
def compute_pi_agg(self):
+ """
+ Compute and update pi_agg, the prior probability, given the current state
+ of messages received from parents. Sidestep explicit factor product and
+ marginalization.
+ """
if len(self.parents) == 0:
self.update_pi_agg(self.cpd.values)
else:
@@ -395,9 +468,18 @@ class BernoulliAndNode(Node):
p_1 = reduce(lambda x, y: x*y, parents_p1)
p_0 = 1 - p_1
self.update_pi_agg(np.array([p_0, p_1]))
- return self.pi_agg
def compute_lambda_msg_to_parent(self, parent_k):
+ """
+ Compute lambda_msg to parent. Sidestep explicit factor product and
+ marginalization.
+
+ Args
+ parent_k: string or int,
+ the label_id of the parent receiving the lambda_msg
+ Returns
+ np.array of ndim 1 and shape (cardinality of parent_k,)
+ """
if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])):
return np.ones([self.cardinality])
else: