aboutsummaryrefslogtreecommitdiff
path: root/beliefs/types/Node.py
diff options
context:
space:
mode:
Diffstat (limited to 'beliefs/types/Node.py')
-rw-r--r--beliefs/types/Node.py179
1 files changed, 179 insertions, 0 deletions
diff --git a/beliefs/types/Node.py b/beliefs/types/Node.py
new file mode 100644
index 0000000..a8dca7c
--- /dev/null
+++ b/beliefs/types/Node.py
@@ -0,0 +1,179 @@
+import numpy as np
+from functools import reduce
+from enum import Enum
+
+
+class InvalidLambdaMsgToParent(Exception):
+ """Computed invalid lambda msg to send to parent."""
+ pass
+
+
+class MessageType(Enum):
+ LAMBDA = 'lambda'
+ PI = 'pi'
+
+
+class Node:
+ """A node in a DAG with methods to compute the belief (marginal probability
+ of the node given evidence) and compute pi/lambda messages to/from its neighbors
+ to update its belief.
+
+ Implemented from Pearl's belief propagation algorithm.
+ """
+ def __init__(self,
+ label_id,
+ children,
+ parents,
+ cardinality,
+ cpd):
+ """
+ Input:
+ label_id: str
+ children: str
+ parents: set of strings
+ cardinality: int, cardinality of the random variable the node represents
+ cpd: an instance of a conditional probability distribution,
+ e.g. BernoulliOrFactor or pgmpy's TabularCPD
+ """
+ self.label_id = label_id
+ self.children = children
+ self.parents = parents
+ self.cardinality = cardinality
+ self.cpd = cpd
+
+ self.pi_agg = None # np.array dimensions [1, cardinality]
+ self.lambda_agg = None # np.array dimensions [1, cardinality]
+
+ self.pi_received_msgs = self._init_received_msgs(parents)
+ self.lambda_received_msgs = self._init_received_msgs(children)
+
+ @classmethod
+ def from_cpd_class(cls,
+ label_id,
+ children,
+ parents,
+ cardinality,
+ cpd_class):
+ cpd = cpd_class(label_id, parents)
+ return cls(label_id, children, parents, cardinality, cpd)
+
+ @property
+ def belief(self):
+ if self.pi_agg.any() and self.lambda_agg.any():
+ belief = np.multiply(self.pi_agg, self.lambda_agg)
+ return self._normalize(belief)
+ else:
+ return None
+
+ def _normalize(self, value):
+ return value/value.sum()
+
+ @staticmethod
+ def _init_received_msgs(keys):
+ return {k: None for k in keys}
+
+ def _return_msgs_received_for_msg_type(self, message_type):
+ """
+ Input:
+ message_type: MessageType enum
+
+ Returns:
+ msg_values: list of message values (each an np.array)
+ """
+ if message_type == MessageType.LAMBDA:
+ msg_values = [msg for msg in self.lambda_received_msgs.values()]
+ elif message_type == MessageType.PI:
+ msg_values = [msg for msg in self.pi_received_msgs.values()]
+ return msg_values
+
+ def validate_and_return_msgs_received_for_msg_type(self, message_type):
+ """
+ Check that all messages have been received from children (parents).
+ Raise error if all messages have not been received. Called
+ before calculating lambda_agg (pi_agg).
+
+ Input:
+ message_type: MessageType enum
+
+ Returns:
+ msg_values: list of message values (each an np.array)
+ """
+ msg_values = self._return_msgs_received_for_msg_type(message_type)
+
+ if any(msg is None for msg in msg_values):
+ raise ValueError(
+ "Missing value for {msg_type} msg from child: can't compute {msg_type}_agg.".
+ format(msg_type=message_type.value)
+ )
+ else:
+ return msg_values
+
+ def compute_pi_agg(self):
+ # TODO: implement explict factor product operation
+ raise NotImplementedError
+
+ def compute_lambda_agg(self):
+ if not self.children:
+ return self.lambda_agg
+ else:
+ lambda_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA)
+ self.lambda_agg = reduce(np.multiply, lambda_msg_values)
+ return self.lambda_agg
+
+ def _update_received_msg_by_key(self, received_msg_dict, key, new_value):
+ if key not in received_msg_dict.keys():
+ raise ValueError("Label id '{}' to update message isn't in allowed set of keys: {}".
+ format(key, received_msg_dict.keys()))
+
+ if not isinstance(new_value, np.ndarray):
+ raise TypeError("Expected a new value of type numpy.ndarray, but got type {}".
+ format(type(new_value)))
+
+ if new_value.shape != (self.cardinality,):
+ raise ValueError("Expected new value to be of dimensions ({},) but got {} instead".
+ format(self.cardinality, new_value.shape))
+ received_msg_dict[key] = new_value
+
+ def update_pi_msg_from_parent(self, parent, new_value):
+ self._update_received_msg_by_key(received_msg_dict=self.pi_received_msgs,
+ key=parent,
+ new_value=new_value)
+
+ def update_lambda_msg_from_child(self, child, new_value):
+ self._update_received_msg_by_key(received_msg_dict=self.lambda_received_msgs,
+ key=child,
+ new_value=new_value)
+
+ def compute_pi_msg_to_child(self, child_k):
+ lambda_msg_from_child = self.lambda_received_msgs[child_k]
+ if lambda_msg_from_child is not None:
+ with np.errstate(divide='ignore', invalid='ignore'):
+ # 0/0 := 0
+ return self._normalize(
+ np.nan_to_num(np.divide(self.belief, lambda_msg_from_child)))
+ else:
+ raise ValueError("Can't compute pi message to child_{} without having received" \
+ "a lambda message from that child.")
+
+ def compute_lambda_msg_to_parent(self, parent_k):
+ # TODO: implement explict factor product operation
+ raise NotImplementedError
+
+ @property
+ def is_fully_initialized(self):
+ """
+ Returns True if all lambda and pi messages and lambda_agg and
+ pi_agg are not None, else False.
+ """
+ lambda_msgs = self._return_msgs_received_for_msg_type(MessageType.LAMBDA)
+ if any(msg is None for msg in lambda_msgs):
+ return False
+
+ pi_msgs = self._return_msgs_received_for_msg_type(MessageType.PI)
+ if any(msg is None for msg in pi_msgs):
+ return False
+
+ if (self.pi_agg is None) or (self.lambda_agg is None):
+ return False
+
+ return True