import copy
from enum import Enum
import numpy as np
import itertools
from functools import reduce
import networkx as nx
from beliefs.models.base_models import BayesianModel
from beliefs.factors.discrete_factor import DiscreteFactor
from beliefs.factors.bernoulli_or_cpd import BernoulliOrCPD
from beliefs.factors.bernoulli_and_cpd import BernoulliAndCPD
class InvalidLambdaMsgToParent(Exception):
    """
    Signal that a lambda message computed for a parent is invalid.

    Within this module it is raised by the Bernoulli node classes when every
    entry of the computed lambda message is zero, so the message cannot be
    normalized.
    """
class MessageType(Enum):
    """Kinds of messages exchanged between nodes during belief update."""

    # Likelihood message, stored per child in a node's `lambda_received_msgs`.
    LAMBDA = "lambda"
    # Prior message, stored per parent in a node's `pi_received_msgs`.
    PI = "pi"
class BeliefUpdateNodeModel(BayesianModel):
    """
    Bayesian model whose variables are backed by node objects (e.g. Node or
    BernoulliOrNode) that carry the state needed for Pearl's belief update
    algorithm.

    ref: "Fusion, Propagation, and Structuring in Belief Networks"
    Artificial Intelligence 29 (1986) 241-288
    """

    def __init__(self, nodes_dict):
        """
        Build the underlying model from the edges, variables and cpds of the nodes.

        Args
            nodes_dict: dict
                mapping of the form {label_id: instance_of_node_class_or_subclass}
        """
        node_objects = nodes_dict.values()
        super().__init__(
            edges=self._get_edges_from_nodes(node_objects),
            variables=list(nodes_dict.keys()),
            cpds=[node.cpd for node in node_objects],
        )
        self.nodes_dict = nodes_dict

    @classmethod
    def init_from_edges(cls, edges, node_class):
        """
        Alternate constructor: create a model in which every node is an
        instance of the same node class.

        Args
            edges: list,
                list of edge tuples of form [('parent', 'child')]
            node_class: Node class or subclass,
                class used to instantiate each node found in `edges`,
                e.g. BernoulliAndNode or BernoulliOrNode; it is called with
                keyword arguments `label_id`, `children` and `parents`
        Returns
            an instance of `cls`
        """
        graph = nx.DiGraph(edges)
        # Every label that appears at either end of an edge becomes one node.
        created_nodes = {
            node_class(label_id=label,
                       children=list(graph.successors(label)),
                       parents=list(graph.predecessors(label)))
            for label in set(itertools.chain(*edges))
        }
        return cls({node.label_id: node for node in created_nodes})

    @staticmethod
    def _get_edges_from_nodes(nodes):
        """
        Collect the distinct directed (parent, child) edges implied by the
        `parents` attribute of each node.

        Args
            nodes: iterable,
                iterable of objects of the Node class or subclass
        Returns
            edges: list,
                list of edge tuples (order is not defined, since the edges are
                de-duplicated through a set)
        """
        collected = set()
        for node in nodes:
            if not node.parents:
                continue
            collected.update((parent, node.label_id) for parent in node.parents)
        return list(collected)

    def set_boundary_conditions(self):
        """
        Apply boundary conditions to the nodes of the model.

        1. Root nodes (as returned by `self.get_roots()`): pi_agg is set to
           the values of the node's cpd, i.e. its prior.
        2. Leaf nodes (as returned by `self.get_leaves()`): lambda_agg is set
           to an (unnormalized) vector of ones whose length is the node's
           cardinality.
        """
        for root_id in self.get_roots():
            root_node = self.nodes_dict[root_id]
            root_node.update_pi_agg(root_node.cpd.values)

        for leaf_id in self.get_leaves():
            leaf_node = self.nodes_dict[leaf_id]
            leaf_node.update_lambda_agg(np.ones([leaf_node.cardinality]))

    @property
    def all_nodes_are_fully_initialized(self):
        """
        Report whether every node in the model says it is fully initialized
        (see the node's `is_fully_initialized` property).

        Returns
            bool, True if all nodes in the model are initialized, else False.
        """
        return all(node.is_fully_initialized for node in self.nodes_dict.values())

    def copy(self):
        """Return a new model of the same class built from a deep copy of the nodes."""
        duplicated_nodes = copy.deepcopy(self.nodes_dict)
        return self.__class__(nodes_dict=duplicated_nodes)
class Node:
    """
    A node in a DAG with methods to compute the belief (marginal probability of
    the node given evidence) and compute pi/lambda messages to/from its neighbors
    to update its belief.

    Implemented from Pearl's belief propagation algorithm for polytrees.
    """

    def __init__(self, children, cpd):
        """
        Args
            children: list,
                list of strings
            cpd: an instance of TabularCPD or one of its subclasses,
                e.g. BernoulliOrCPD or BernoulliAndCPD
        """
        self.label_id = cpd.variable
        self.children = children
        self.parents = cpd.parents
        self.cardinality = cpd.cardinality[0]
        self.cpd = cpd

        # Aggregates and received messages start as factors without values;
        # they are filled in through the update_* methods.
        self.pi_agg = self._init_factors_for_variables([self.label_id])[self.label_id]
        self.lambda_agg = self._init_factors_for_variables([self.label_id])[self.label_id]

        # One pi message per parent (over the parent's states) and one lambda
        # message per child (over this node's states).
        self.pi_received_msgs = self._init_factors_for_variables(self.parents)
        self.lambda_received_msgs = {
            child: self._init_factors_for_variables([self.label_id])[self.label_id]
            for child in children
        }

    @property
    def belief(self):
        """
        Calculate the marginal probability of the variable from its aggregate values.

        Returns
            belief, an np.array of ndim 1 and shape (self.cardinality,), or
            None if either aggregate has no values yet or is all zeros.
        """
        pi_values = self.pi_agg.values
        lambda_values = self.lambda_agg.values
        # Guard against aggregates that have not been set yet: any(None)
        # would raise TypeError instead of reporting "no belief available".
        if pi_values is None or lambda_values is None:
            return None
        if any(pi_values) and any(lambda_values):
            belief = (self.lambda_agg * self.pi_agg).values
            return self._normalize(belief)
        return None

    def _normalize(self, value):
        """Return `value` divided by the sum of its entries."""
        return value/value.sum()

    def _init_factors_for_variables(self, variables):
        """
        Create one value-less factor per variable.

        Args
            variables: list,
                list of ints/strings, e.g. the single node variable or list
                of parent ids of the node
        Returns
            factors: dict,
                where the dict has key, value pair as {variable_id: instance of a DiscreteFactor},
                where DiscreteFactor.values is an np.array of ndim 1 and
                shape (cardinality of variable_id,)
        """
        variables = list(variables)

        factors = {}
        for var in variables:
            if self.cpd.state_names is not None:
                state_names = {var: self.cpd.state_names[var]}
            else:
                state_names = None

            # Cardinality of `var` is looked up by its position in the cpd.
            cardinality = self.cpd.cardinality[self.cpd.variables.index(var)]
            factors[var] = DiscreteFactor(variables=[var],
                                          cardinality=[cardinality],
                                          values=None,
                                          state_names=state_names)
        return factors

    def _return_msgs_received_for_msg_type(self, message_type):
        """
        Args
            message_type: MessageType enum
        Returns
            msg_values: list,
                list of DiscreteFactors with property `values` containing
                the values of the messages (np.arrays)
        Raises
            ValueError: if `message_type` is neither LAMBDA nor PI.
        """
        if message_type == MessageType.LAMBDA:
            return list(self.lambda_received_msgs.values())
        if message_type == MessageType.PI:
            return list(self.pi_received_msgs.values())
        raise ValueError("Unsupported message type: {}".format(message_type))

    def validate_and_return_msgs_received_for_msg_type(self, message_type):
        """
        Check that all messages have been received from children (parents).
        Raise error if all messages have not been received.  Called
        before calculating lambda_agg (pi_agg).

        Args
            message_type: MessageType enum
        Returns
            msgs: list,
                list of DiscreteFactors with property `values` containing
                the values of the messages (np.arrays)
        Raises
            ValueError: if any message of the requested type has no values.
        """
        msgs = self._return_msgs_received_for_msg_type(message_type)

        if any(msg.values is None for msg in msgs):
            raise ValueError(
                "Missing value for {msg_type} msg from child: can't compute {msg_type}_agg."
                .format(msg_type=message_type.value)
            )
        return msgs

    def compute_and_update_pi_agg(self):
        """
        Compute and update pi_agg, the prior probability, given the current state
        of messages received from parents.
        """
        if len(self.parents) == 0:
            # No parents: the cpd values are used directly.
            self.update_pi_agg(self.cpd.values)
        else:
            factors_to_multiply = [self.cpd]
            pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
            factors_to_multiply.extend(pi_msgs)

            # Multiply the cpd with every parent's pi message, then sum the
            # parents out to leave a factor over this node only.
            factor_product = reduce(lambda phi1, phi2: phi1*phi2, factors_to_multiply)
            self.update_pi_agg(factor_product.marginalize(self.parents).values)

    def compute_and_update_lambda_agg(self):
        """
        Compute and update lambda_agg, the likelihood, given the current state
        of messages received from children.

        Does nothing when the node has no children.
        """
        if len(self.children) != 0:
            lambda_msg_values = [
                msg.values for msg in
                self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA)
            ]
            # Element-wise product of all children's lambda messages.
            self.update_lambda_agg(reduce(np.multiply, lambda_msg_values))

    def update_pi_agg(self, new_value):
        """Overwrite the values of pi_agg with `new_value`."""
        self.pi_agg.update_values(new_value)

    def update_lambda_agg(self, new_value):
        """Overwrite the values of lambda_agg with `new_value`."""
        self.lambda_agg.update_values(new_value)

    def _update_received_msg_by_key(self, received_msg_dict, key, new_value, message_type):
        """
        Validate `new_value` and store it as the message received from `key`.

        Args
            received_msg_dict: dict,
                either self.lambda_received_msgs or self.pi_received_msgs
            key: string or int,
                label_id of the neighbor that sent the message
            new_value: np.ndarray,
                values of the message
            message_type: MessageType enum
        Raises
            ValueError: if `key` is not a key of `received_msg_dict`, if
                `message_type` is unsupported, or if `new_value` has the
                wrong shape.
            TypeError: if `new_value` is not a numpy array.
        """
        if key not in received_msg_dict:
            raise ValueError("Label id '{}' to update message isn't in allowed set of keys: {}"
                             .format(key, received_msg_dict.keys()))

        if not isinstance(new_value, np.ndarray):
            raise TypeError("Expected a new value of type numpy.ndarray, but got type {}"
                            .format(type(new_value)))

        if message_type == MessageType.LAMBDA:
            # Lambda messages are indexed by this node's own states.
            expected_shape = (self.cardinality,)
        elif message_type == MessageType.PI:
            # Pi messages are indexed by the sending parent's states.
            expected_shape = (self.cpd.cardinality[self.cpd.variables.index(key)],)
        else:
            raise ValueError("Unsupported message type: {}".format(message_type))

        if new_value.shape != expected_shape:
            raise ValueError("Expected new value to be of dimensions {} but got {} instead"
                             .format(expected_shape, new_value.shape))
        received_msg_dict[key].update_values(new_value)

    def update_pi_msg_from_parent(self, parent, new_value):
        """Store `new_value` as the pi message received from `parent`."""
        self._update_received_msg_by_key(received_msg_dict=self.pi_received_msgs,
                                         key=parent,
                                         new_value=new_value,
                                         message_type=MessageType.PI)

    def update_lambda_msg_from_child(self, child, new_value):
        """Store `new_value` as the lambda message received from `child`."""
        self._update_received_msg_by_key(received_msg_dict=self.lambda_received_msgs,
                                         key=child,
                                         new_value=new_value,
                                         message_type=MessageType.LAMBDA)

    def compute_pi_msg_to_child(self, child_k):
        """
        Compute pi_msg to child.

        Args
            child_k: string or int,
                the label_id of the child receiving the pi_msg
        Returns
            np.array of ndim 1 and shape (self.cardinality,)
        Raises
            ValueError: if no lambda message has been received from `child_k`.
        """
        lambda_msg_from_child = self.lambda_received_msgs[child_k].values
        if lambda_msg_from_child is None:
            raise ValueError(
                "Can't compute pi message to child_{} without having received "
                "a lambda message from that child.".format(child_k))

        # Dividing the belief by the child's own lambda message removes that
        # child's contribution from the product.
        with np.errstate(divide='ignore', invalid='ignore'):
            # 0/0 := 0
            return self._normalize(
                np.nan_to_num(np.divide(self.belief, lambda_msg_from_child)))

    def compute_lambda_msg_to_parent(self, parent_k):
        """
        Compute lambda_msg to parent.

        Args
            parent_k: string or int,
                the label_id of the parent receiving the lambda_msg
        Returns
            np.array of ndim 1 and shape (cardinality of parent_k,)
        """
        if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])):
            # Shortcut for an all-ones lambda_agg: send a vector of ones.
            # The message is indexed by the parent's states, so its length
            # must be the parent's cardinality, not this node's.
            parent_cardinality = self.cpd.cardinality[self.cpd.variables.index(parent_k)]
            return np.ones([parent_cardinality])

        factors_to_multiply = [self.cpd]
        pi_msgs_excl_k = [msg for par_id, msg in self.pi_received_msgs.items()
                          if par_id != parent_k]
        factors_to_multiply.extend(pi_msgs_excl_k)
        factor_product = reduce(lambda phi1, phi2: phi1*phi2, factors_to_multiply)

        # Sum out every parent except parent_k, weight by lambda_agg, then sum
        # out this node's own variable to leave a factor over parent_k.
        new_factor = factor_product.marginalize(list(set(self.parents) - set([parent_k])))
        lambda_msg_to_k = (self.lambda_agg * new_factor).marginalize([self.lambda_agg.variables[0]])
        return self._normalize(lambda_msg_to_k.values)

    @property
    def is_fully_initialized(self):
        """
        Returns True if all lambda and pi messages and lambda_agg and
        pi_agg are not None, else False.
        """
        # The received messages are factor objects, so "not received yet" is
        # signalled by their `values` being None, not by the object itself.
        if any(msg.values is None for msg in self.lambda_received_msgs.values()):
            return False
        if any(msg.values is None for msg in self.pi_received_msgs.values()):
            return False
        return (self.pi_agg.values is not None) and (self.lambda_agg.values is not None)
class BernoulliOrNode(Node):
    """
    Node for a Bernoulli random variable with state_names ['False', 'True']
    whose conditional probability distribution follows 'Or' logic.
    """

    def __init__(self, label_id, children, parents):
        """
        Args
            label_id: string or int,
                label of this node's variable
            children: list,
                label_ids of the children of this node
            parents: list,
                label_ids of the parents of this node
        """
        or_cpd = BernoulliOrCPD(label_id, parents)
        super().__init__(children=children, cpd=or_cpd)

    def compute_and_update_pi_agg(self):
        """
        Recompute pi_agg from the pi messages currently received from the
        parents, without building an explicit factor product.

        The 'False' entry is the product of the parents' 'False' values and
        the 'True' entry is its complement.
        """
        if len(self.parents) == 0:
            self.update_pi_agg(self.cpd.values)
            return

        received = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
        false_values = (msg.get_value_for_state_vector({msg.variables[0]: 'False'})
                        for msg in received)
        prob_false = reduce(lambda acc, val: acc*val, false_values)
        prob_true = 1 - prob_false
        self.update_pi_agg(np.array([prob_false, prob_true]))

    def compute_lambda_msg_to_parent(self, parent_k):
        """
        Compute the lambda message for one parent without building an
        explicit factor product.

        Args
            parent_k: string or int,
                the label_id of the parent receiving the lambda_msg
        Returns
            np.array of ndim 1 and shape (cardinality of parent_k,)
        Raises
            InvalidLambdaMsgToParent: if every entry of the message is zero.
        """
        ones_vector = np.ones([self.cardinality])
        if np.array_equal(self.lambda_agg.values, ones_vector):
            return np.ones([self.cardinality])

        # TODO: cleanup this validation
        # Called only for its side effect of raising when a pi message is missing.
        self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)

        # Product of the 'False' values of every parent other than parent_k.
        others_all_false = 1
        for par_id, pi_msg in self.pi_received_msgs.items():
            if par_id == parent_k:
                continue
            others_all_false = others_all_false * pi_msg.get_value_for_state_vector(
                {pi_msg.variables[0]: 'False'})

        agg_false = self.lambda_agg.get_value_for_state_vector({self.label_id: 'False'})
        agg_true = self.lambda_agg.get_value_for_state_vector({self.label_id: 'True'})

        # parent_k 'False': mix of agg_false and agg_true weighted by
        # others_all_false; parent_k 'True': agg_true.
        msg_false = agg_true + (agg_false - agg_true)*others_all_false
        msg_true = agg_true
        lambda_msg = np.array([msg_false, msg_true])

        if not any(lambda_msg):
            raise InvalidLambdaMsgToParent
        return self._normalize(lambda_msg)
class BernoulliAndNode(Node):
    """
    Node for a Bernoulli random variable with state_names ['False', 'True']
    whose conditional probability distribution follows 'And' logic.
    """

    def __init__(self, label_id, children, parents):
        """
        Args
            label_id: string or int,
                label of this node's variable
            children: list,
                label_ids of the children of this node
            parents: list,
                label_ids of the parents of this node
        """
        and_cpd = BernoulliAndCPD(label_id, parents)
        super().__init__(children=children, cpd=and_cpd)

    def compute_and_update_pi_agg(self):
        """
        Recompute pi_agg from the pi messages currently received from the
        parents, without building an explicit factor product.

        The 'True' entry is the product of the parents' 'True' values and
        the 'False' entry is its complement.
        """
        if len(self.parents) == 0:
            self.update_pi_agg(self.cpd.values)
            return

        received = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
        true_values = (msg.get_value_for_state_vector({msg.variables[0]: 'True'})
                       for msg in received)
        prob_true = reduce(lambda acc, val: acc*val, true_values)
        prob_false = 1 - prob_true
        self.update_pi_agg(np.array([prob_false, prob_true]))

    def compute_lambda_msg_to_parent(self, parent_k):
        """
        Compute the lambda message for one parent without building an
        explicit factor product.

        Args
            parent_k: string or int,
                the label_id of the parent receiving the lambda_msg
        Returns
            np.array of ndim 1 and shape (cardinality of parent_k,)
        Raises
            InvalidLambdaMsgToParent: if every entry of the message is zero.
        """
        ones_vector = np.ones([self.cardinality])
        if np.array_equal(self.lambda_agg.values, ones_vector):
            return np.ones([self.cardinality])

        # TODO: cleanup this validation
        # Called only for its side effect of raising when a pi message is missing.
        self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)

        # Product of the 'True' values of every parent other than parent_k.
        others_all_true = 1
        for par_id, pi_msg in self.pi_received_msgs.items():
            if par_id == parent_k:
                continue
            others_all_true = others_all_true * pi_msg.get_value_for_state_vector(
                {pi_msg.variables[0]: 'True'})

        agg_false = self.lambda_agg.get_value_for_state_vector({self.label_id: 'False'})
        agg_true = self.lambda_agg.get_value_for_state_vector({self.label_id: 'True'})

        # parent_k 'False': agg_false; parent_k 'True': mix of agg_false and
        # agg_true weighted by others_all_true.
        msg_false = agg_false
        msg_true = agg_false + (agg_true - agg_false)*others_all_true
        lambda_msg = np.array([msg_false, msg_true])

        if not any(lambda_msg):
            raise InvalidLambdaMsgToParent
        return self._normalize(lambda_msg)