diff options
Diffstat (limited to 'beliefs/types/Node.py')
-rw-r--r-- | beliefs/types/Node.py | 179 |
1 files changed, 179 insertions, 0 deletions
diff --git a/beliefs/types/Node.py b/beliefs/types/Node.py new file mode 100644 index 0000000..a8dca7c --- /dev/null +++ b/beliefs/types/Node.py @@ -0,0 +1,179 @@ +import numpy as np +from functools import reduce +from enum import Enum + + +class InvalidLambdaMsgToParent(Exception): + """Computed invalid lambda msg to send to parent.""" + pass + + +class MessageType(Enum): + LAMBDA = 'lambda' + PI = 'pi' + + +class Node: + """A node in a DAG with methods to compute the belief (marginal probability + of the node given evidence) and compute pi/lambda messages to/from its neighbors + to update its belief. + + Implemented from Pearl's belief propagation algorithm. + """ + def __init__(self, + label_id, + children, + parents, + cardinality, + cpd): + """ + Input: + label_id: str + children: str + parents: set of strings + cardinality: int, cardinality of the random variable the node represents + cpd: an instance of a conditional probability distribution, + e.g. BernoulliOrFactor or pgmpy's TabularCPD + """ + self.label_id = label_id + self.children = children + self.parents = parents + self.cardinality = cardinality + self.cpd = cpd + + self.pi_agg = None # np.array dimensions [1, cardinality] + self.lambda_agg = None # np.array dimensions [1, cardinality] + + self.pi_received_msgs = self._init_received_msgs(parents) + self.lambda_received_msgs = self._init_received_msgs(children) + + @classmethod + def from_cpd_class(cls, + label_id, + children, + parents, + cardinality, + cpd_class): + cpd = cpd_class(label_id, parents) + return cls(label_id, children, parents, cardinality, cpd) + + @property + def belief(self): + if self.pi_agg.any() and self.lambda_agg.any(): + belief = np.multiply(self.pi_agg, self.lambda_agg) + return self._normalize(belief) + else: + return None + + def _normalize(self, value): + return value/value.sum() + + @staticmethod + def _init_received_msgs(keys): + return {k: None for k in keys} + + def _return_msgs_received_for_msg_type(self, message_type): + """ + Input: + message_type: MessageType enum + + Returns: + msg_values: list of message values (each an np.array) + """ + if message_type == MessageType.LAMBDA: + msg_values = [msg for msg in self.lambda_received_msgs.values()] + elif message_type == MessageType.PI: + msg_values = [msg for msg in self.pi_received_msgs.values()] + return msg_values + + def validate_and_return_msgs_received_for_msg_type(self, message_type): + """ + Check that all messages have been received from children (parents). + Raise error if all messages have not been received. Called + before calculating lambda_agg (pi_agg). + + Input: + message_type: MessageType enum + + Returns: + msg_values: list of message values (each an np.array) + """ + msg_values = self._return_msgs_received_for_msg_type(message_type) + + if any(msg is None for msg in msg_values): + raise ValueError( + "Missing value for {msg_type} msg from child: can't compute {msg_type}_agg.". + format(msg_type=message_type.value) + ) + else: + return msg_values + + def compute_pi_agg(self): + # TODO: implement explict factor product operation + raise NotImplementedError + + def compute_lambda_agg(self): + if not self.children: + return self.lambda_agg + else: + lambda_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA) + self.lambda_agg = reduce(np.multiply, lambda_msg_values) + return self.lambda_agg + + def _update_received_msg_by_key(self, received_msg_dict, key, new_value): + if key not in received_msg_dict.keys(): + raise ValueError("Label id '{}' to update message isn't in allowed set of keys: {}". + format(key, received_msg_dict.keys())) + + if not isinstance(new_value, np.ndarray): + raise TypeError("Expected a new value of type numpy.ndarray, but got type {}". + format(type(new_value))) + + if new_value.shape != (self.cardinality,): + raise ValueError("Expected new value to be of dimensions ({},) but got {} instead". + format(self.cardinality, new_value.shape)) + received_msg_dict[key] = new_value + + def update_pi_msg_from_parent(self, parent, new_value): + self._update_received_msg_by_key(received_msg_dict=self.pi_received_msgs, + key=parent, + new_value=new_value) + + def update_lambda_msg_from_child(self, child, new_value): + self._update_received_msg_by_key(received_msg_dict=self.lambda_received_msgs, + key=child, + new_value=new_value) + + def compute_pi_msg_to_child(self, child_k): + lambda_msg_from_child = self.lambda_received_msgs[child_k] + if lambda_msg_from_child is not None: + with np.errstate(divide='ignore', invalid='ignore'): + # 0/0 := 0 + return self._normalize( + np.nan_to_num(np.divide(self.belief, lambda_msg_from_child))) + else: + raise ValueError("Can't compute pi message to child_{} without having received" \ + "a lambda message from that child.") + + def compute_lambda_msg_to_parent(self, parent_k): + # TODO: implement explict factor product operation + raise NotImplementedError + + @property + def is_fully_initialized(self): + """ + Returns True if all lambda and pi messages and lambda_agg and + pi_agg are not None, else False. + """ + lambda_msgs = self._return_msgs_received_for_msg_type(MessageType.LAMBDA) + if any(msg is None for msg in lambda_msgs): + return False + + pi_msgs = self._return_msgs_received_for_msg_type(MessageType.PI) + if any(msg is None for msg in pi_msgs): + return False + + if (self.pi_agg is None) or (self.lambda_agg is None): + return False + + return True |