aboutsummaryrefslogblamecommitdiff
path: root/beliefs/types/Node.py
blob: a496acf6b9d0a66d780ab925bce941123f90e5de (plain) (tree)


































                                                                                    
                                                  














































































































































                                                                                                       
import numpy as np
from functools import reduce
from enum import Enum


class InvalidLambdaMsgToParent(Exception):
    """Computed invalid lambda msg to send to parent."""
    pass


class MessageType(Enum):
    LAMBDA = 'lambda'
    PI = 'pi'


class Node:
    """A node in a DAG with methods to compute the belief (marginal probability
    of the node given evidence) and compute pi/lambda messages to/from its neighbors
    to update its belief.

    Implemented from Pearl's belief propagation algorithm.
    """
    def __init__(self,
                 label_id,
                 children,
                 parents,
                 cardinality,
                 cpd):
        """
        Input:
            label_id: str
            children: str
            parents: set of strings
            cardinality: int, cardinality of the random variable the node represents
            cpd: an instance of a conditional probability distribution,
                 e.g. BernoulliOrCPD or TabularCPD
        """
        self.label_id = label_id
        self.children = children
        self.parents = parents
        self.cardinality = cardinality
        self.cpd = cpd

        self.pi_agg = None  # np.array dimensions [1, cardinality]
        self.lambda_agg = None  # np.array dimensions [1, cardinality]

        self.pi_received_msgs = self._init_received_msgs(parents)
        self.lambda_received_msgs = self._init_received_msgs(children)

    @classmethod
    def from_cpd_class(cls,
                       label_id,
                       children,
                       parents,
                       cardinality,
                       cpd_class):
        cpd = cpd_class(label_id, parents)
        return cls(label_id, children, parents, cardinality, cpd)

    @property
    def belief(self):
        if self.pi_agg.any() and self.lambda_agg.any():
            belief = np.multiply(self.pi_agg, self.lambda_agg)
            return self._normalize(belief)
        else:
            return None

    def _normalize(self, value):
        return value/value.sum()

    @staticmethod
    def _init_received_msgs(keys):
        return {k: None for k in keys}

    def _return_msgs_received_for_msg_type(self, message_type):
        """
        Input:
          message_type: MessageType enum

        Returns:
          msg_values: list of message values (each an np.array)
        """
        if message_type == MessageType.LAMBDA:
            msg_values = [msg for msg in self.lambda_received_msgs.values()]
        elif message_type == MessageType.PI:
            msg_values = [msg for msg in self.pi_received_msgs.values()]
        return msg_values

    def validate_and_return_msgs_received_for_msg_type(self, message_type):
        """
        Check that all messages have been received from children (parents).
        Raise error if all messages have not been received.  Called
        before calculating lambda_agg (pi_agg).

        Input:
          message_type: MessageType enum

        Returns:
          msg_values: list of message values (each an np.array)
        """
        msg_values = self._return_msgs_received_for_msg_type(message_type)

        if any(msg is None for msg in msg_values):
            raise ValueError(
                "Missing value for {msg_type} msg from child: can't compute {msg_type}_agg.".
                format(msg_type=message_type.value)
            )
        else:
            return msg_values

    def compute_pi_agg(self):
        # TODO: implement explict factor product operation
        raise NotImplementedError

    def compute_lambda_agg(self):
        if not self.children:
            return self.lambda_agg
        else:
            lambda_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA)
            self.lambda_agg = reduce(np.multiply, lambda_msg_values)
        return self.lambda_agg

    def _update_received_msg_by_key(self, received_msg_dict, key, new_value):
        if key not in received_msg_dict.keys():
            raise ValueError("Label id '{}' to update message isn't in allowed set of keys: {}".
                             format(key, received_msg_dict.keys()))

        if not isinstance(new_value, np.ndarray):
            raise TypeError("Expected a new value of type numpy.ndarray, but got type {}".
                            format(type(new_value)))

        if new_value.shape != (self.cardinality,):
            raise ValueError("Expected new value to be of dimensions ({},) but got {} instead".
                             format(self.cardinality, new_value.shape))
        received_msg_dict[key] = new_value

    def update_pi_msg_from_parent(self, parent, new_value):
        self._update_received_msg_by_key(received_msg_dict=self.pi_received_msgs,
                                         key=parent,
                                         new_value=new_value)

    def update_lambda_msg_from_child(self, child, new_value):
        self._update_received_msg_by_key(received_msg_dict=self.lambda_received_msgs,
                                         key=child,
                                         new_value=new_value)

    def compute_pi_msg_to_child(self, child_k):
        lambda_msg_from_child = self.lambda_received_msgs[child_k]
        if lambda_msg_from_child is not None:
            with np.errstate(divide='ignore', invalid='ignore'):
                # 0/0 := 0
                return self._normalize(
                    np.nan_to_num(np.divide(self.belief, lambda_msg_from_child)))
        else:
            raise ValueError("Can't compute pi message to child_{} without having received" \
                             "a lambda message from that child.")

    def compute_lambda_msg_to_parent(self, parent_k):
        # TODO: implement explict factor product operation
        raise NotImplementedError

    @property
    def is_fully_initialized(self):
        """
        Returns True if all lambda and pi messages and lambda_agg and
        pi_agg are not None, else False.
        """
        lambda_msgs = self._return_msgs_received_for_msg_type(MessageType.LAMBDA)
        if any(msg is None for msg in lambda_msgs):
            return False

        pi_msgs = self._return_msgs_received_for_msg_type(MessageType.PI)
        if any(msg is None for msg in pi_msgs):
            return False

        if (self.pi_agg is None) or (self.lambda_agg is None):
            return False

        return True