import numpy as np
from functools import reduce
from enum import Enum
class InvalidLambdaMsgToParent(Exception):
    """Raised when the lambda message computed for a parent is invalid."""
class MessageType(Enum):
    """Kinds of belief-propagation messages a node can receive."""

    # Message sent from a child to its parent.
    LAMBDA = 'lambda'
    # Message sent from a parent to its child.
    PI = 'pi'
class Node:
    """A node in a DAG with methods to compute the belief (marginal probability
    of the node given evidence) and compute pi/lambda messages to/from its neighbors
    to update its belief.

    Implemented from Pearl's belief propagation algorithm.
    """

    def __init__(self,
                 label_id,
                 children,
                 parents,
                 cardinality,
                 cpd):
        """
        Args
            label_id: str
            children: set of strings, label ids of this node's children
            parents: set of strings, label ids of this node's parents
            cardinality: int, cardinality of the random variable the node represents
            cpd: an instance of a conditional probability distribution,
                e.g. BernoulliOrCPD or TabularCPD
        """
        self.label_id = label_id
        self.children = children
        self.parents = parents
        self.cardinality = cardinality
        self.cpd = cpd
        # Aggregated messages: None until computed/assigned. lambda_agg is the
        # elementwise product of (cardinality,)-shaped lambda messages, so it
        # has shape (cardinality,) once computed.
        self.pi_agg = None
        self.lambda_agg = None
        # One slot per neighbor label id; None means "not received yet".
        self.pi_received_msgs = self._init_received_msgs(parents)
        self.lambda_received_msgs = self._init_received_msgs(children)

    @classmethod
    def from_cpd_class(cls,
                       label_id,
                       children,
                       parents,
                       cardinality,
                       cpd_class):
        """Alternate constructor that instantiates the CPD from its class.

        ``cpd_class`` is called as ``cpd_class(label_id, parents)`` and the
        resulting object is passed to ``__init__`` as ``cpd``.
        """
        cpd = cpd_class(label_id, parents)
        return cls(label_id, children, parents, cardinality, cpd)

    @property
    def belief(self):
        """Normalized elementwise product of ``pi_agg`` and ``lambda_agg``.

        Returns None when either aggregate has not been set yet, or when
        either aggregate is all zeros (normalizing would divide by zero).
        """
        # Check for None first: calling .any() on an unset (None) aggregate
        # used to raise AttributeError instead of returning None.
        if self.pi_agg is None or self.lambda_agg is None:
            return None
        if not (self.pi_agg.any() and self.lambda_agg.any()):
            return None
        return self._normalize(np.multiply(self.pi_agg, self.lambda_agg))

    def _normalize(self, value):
        """Scale ``value`` so its entries sum to 1."""
        return value / value.sum()

    @staticmethod
    def _init_received_msgs(keys):
        """Return a dict mapping each key to None (no message received yet)."""
        return {k: None for k in keys}

    def _return_msgs_received_for_msg_type(self, message_type):
        """
        Input:
            message_type: MessageType enum
        Returns:
            msg_values: list of message values (each an np.array, or None for
                a neighbor whose message has not been received yet)
        Raises:
            ValueError: if message_type is not MessageType.LAMBDA or MessageType.PI.
        """
        if message_type == MessageType.LAMBDA:
            return list(self.lambda_received_msgs.values())
        if message_type == MessageType.PI:
            return list(self.pi_received_msgs.values())
        # Previously this fell through to an UnboundLocalError.
        raise ValueError("Unknown message type: {!r}".format(message_type))

    def validate_and_return_msgs_received_for_msg_type(self, message_type):
        """
        Check that all messages have been received from children (parents).
        Raise error if all messages have not been received. Called
        before calculating lambda_agg (pi_agg).

        Input:
            message_type: MessageType enum
        Returns:
            msg_values: list of message values (each an np.array)
        Raises:
            ValueError: if any expected message is still None.
        """
        msg_values = self._return_msgs_received_for_msg_type(message_type)
        if any(msg is None for msg in msg_values):
            # Lambda messages come from children, pi messages from parents.
            sender = 'child' if message_type == MessageType.LAMBDA else 'parent'
            raise ValueError(
                "Missing value for {msg_type} msg from {sender}: can't compute {msg_type}_agg.".
                format(msg_type=message_type.value, sender=sender)
            )
        return msg_values

    def compute_pi_agg(self):
        # TODO: implement explict factor product operation
        raise NotImplementedError

    def compute_lambda_agg(self):
        """Store and return the elementwise product of all children's lambda msgs.

        A node with no children has nothing to aggregate, so its current
        ``lambda_agg`` (possibly None) is returned unchanged.

        Raises:
            ValueError: if a lambda message from some child is still missing.
        """
        if not self.children:
            return self.lambda_agg
        lambda_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA)
        self.lambda_agg = reduce(np.multiply, lambda_msg_values)
        return self.lambda_agg

    def _update_received_msg_by_key(self, received_msg_dict, key, new_value):
        """Validate ``new_value`` and store it in ``received_msg_dict[key]``.

        Raises:
            ValueError: if ``key`` is not an expected sender, or if ``new_value``
                does not have shape (cardinality,).
            TypeError: if ``new_value`` is not a numpy.ndarray.
        """
        if key not in received_msg_dict:
            raise ValueError("Label id '{}' to update message isn't in allowed set of keys: {}".
                             format(key, received_msg_dict.keys()))
        if not isinstance(new_value, np.ndarray):
            raise TypeError("Expected a new value of type numpy.ndarray, but got type {}".
                            format(type(new_value)))
        # NOTE(review): pi messages from parents are also checked against this
        # node's cardinality; a pi message is usually over the parent's states,
        # so this may be too strict when cardinalities differ -- confirm.
        if new_value.shape != (self.cardinality,):
            raise ValueError("Expected new value to be of dimensions ({},) but got {} instead".
                             format(self.cardinality, new_value.shape))
        received_msg_dict[key] = new_value

    def update_pi_msg_from_parent(self, parent, new_value):
        """Record the pi message received from ``parent``."""
        self._update_received_msg_by_key(received_msg_dict=self.pi_received_msgs,
                                         key=parent,
                                         new_value=new_value)

    def update_lambda_msg_from_child(self, child, new_value):
        """Record the lambda message received from ``child``."""
        self._update_received_msg_by_key(received_msg_dict=self.lambda_received_msgs,
                                         key=child,
                                         new_value=new_value)

    def compute_pi_msg_to_child(self, child_k):
        """Return the normalized pi message to send to ``child_k``.

        Computed as this node's belief divided elementwise by the lambda
        message received from ``child_k``, with NaNs (0/0) replaced by 0.

        Raises:
            KeyError: if ``child_k`` is not one of this node's children.
            ValueError: if no lambda message has been received from ``child_k``.
        """
        lambda_msg_from_child = self.lambda_received_msgs[child_k]
        if lambda_msg_from_child is None:
            raise ValueError("Can't compute pi message to child_{} without having received "
                             "a lambda message from that child.".format(child_k))
        with np.errstate(divide='ignore', invalid='ignore'):
            # 0/0 := 0
            return self._normalize(
                np.nan_to_num(np.divide(self.belief, lambda_msg_from_child)))

    def compute_lambda_msg_to_parent(self, parent_k):
        # TODO: implement explict factor product operation
        raise NotImplementedError

    @property
    def is_fully_initialized(self):
        """
        Returns True if all lambda and pi messages and lambda_agg and
        pi_agg are not None, else False.
        """
        lambda_msgs = self._return_msgs_received_for_msg_type(MessageType.LAMBDA)
        if any(msg is None for msg in lambda_msgs):
            return False
        pi_msgs = self._return_msgs_received_for_msg_type(MessageType.PI)
        if any(msg is None for msg in pi_msgs):
            return False
        return self.pi_agg is not None and self.lambda_agg is not None