aboutsummaryrefslogblamecommitdiff
path: root/beliefs/models/base_models.py
blob: 71af0cb15520800c3365bc9c636d9834cb765458 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11

                     


                                                        





                                                     




                                                                                 







                                            
                                                  


                                                                                  
                                                 


                                                                               
                                                              





                                                                  




                                                                              
                                                        
           
                                      
 

                                            
                                                  
                                                    
                                    
                                         





                                            

                   



                                                                     


                                              

                                                                           
 
               

                                 
                                                                                                    





                                                                                  




                                                                         
           
                              
                                                                          
                                                                       
                                                                             
                                                        





                                                                                         











                                                                 
                         
                            




                                                                   
                                                                                

                                                  








                                                                                

                                                                
                                                                         




























                                                                    
import networkx as nx

from beliefs.utils.math_helper import is_kronecker_delta


class DirectedGraph(nx.DiGraph):
    """
    Base class for all directed graphical models.

    Extends ``networkx.DiGraph`` with convenience accessors for the leaves,
    the roots and a topological ordering of the graph.
    """
    def __init__(self, edges=None, node_labels=None):
        """
        Args
            edges: (optional) list,
               a list of edge tuples, e.g. [(parent1, child1), (parent1, child2)]
            node_labels: (optional) list,
               a list of strings or integers representing node label ids
        """
        super().__init__()
        if edges is not None:
            self.add_edges_from(edges)
        # Added after the edges so that isolated nodes (labels that do not
        # appear in any edge) still become part of the graph.
        if node_labels is not None:
            self.add_nodes_from(node_labels)

    def get_leaves(self):
        """Return a list of leaves of the graph (nodes with no outgoing edges)"""
        return [node for node, out_degree in self.out_degree() if out_degree == 0]

    def get_roots(self):
        """Return a list of roots of the graph (nodes with no incoming edges)"""
        return [node for node, in_degree in self.in_degree() if in_degree == 0]

    def get_topologically_sorted_nodes(self, reverse=False):
        """
        Return a list of nodes in topological sort order.

        Args
            reverse: bool,
                if True, return the nodes in reverse topological order
        Returns
            list of node labels; always a concrete list, regardless of
            the value of `reverse`
        """
        # Materialise the result in both branches so the return type does not
        # depend on `reverse` (the underlying networkx call may be lazy).
        ordered_nodes = list(nx.topological_sort(self))
        if reverse:
            ordered_nodes.reverse()
        return ordered_nodes


class BayesianModel(DirectedGraph):
    """
    Bayesian model stores nodes and edges described by conditional probability
    distributions.

    NOTE(review): several methods read `self.nodes_dict`, which is not defined
    in this class -- presumably a subclass provides it as a mapping from label
    to a node object exposing a `belief` attribute; confirm against subclasses.
    """
    def __init__(self, edges=None, variables=None, cpds=None):
        """
        Base class for Bayesian model.

        Args
            edges: (optional) list of edges,
                tuples of form ('parent', 'child')
            variables: (optional) list of str or int
                labels for variables
            cpds: (optional) list of CPDs
                TabularCPD class or subclass
        """
        super().__init__()
        if edges is not None:
            super().add_edges_from(edges)
        if variables is not None:
            super().add_nodes_from(variables)
        # A fresh list per instance: a mutable default in the signature would
        # be shared by every model constructed without explicit cpds.
        self.cpds = [] if cpds is None else cpds

    def copy(self):
        """
        Return a copy of the model.

        The copy is built through the constructor of the instance's own class
        from new edge/node lists and from `cpd.copy()` of every stored CPD.
        """
        return self.__class__(edges=list(self.edges()),
                              variables=list(self.nodes()),
                              cpds=[cpd.copy() for cpd in self.cpds])

    def get_variables_in_definite_state(self):
        """
        Get labels of all nodes in a definite state, i.e. with label values
        that are kronecker deltas.

        Returns
          set of strings (labels)
        """
        return {label for label, node in self.nodes_dict.items() if is_kronecker_delta(node.belief)}

    def get_unobserved_variables_in_definite_state(self, observed=None):
        """
        Returns a set of labels that are inferred to be in definite state, given
        the labels that were directly observed (e.g. YES/NOs, but not MAYBEs).

        Args
            observed: (optional) iterable,
                strings, directly observed labels; defaults to no observations
        Returns
            set of strings, the labels inferred to be in a definite state
        Raises
            AssertionError: if an observed label's belief is not a kronecker
                delta, or if an observed label is not in a definite state
        """
        # Normalise to a set so any iterable of labels supports the subset
        # test and the set difference below.
        observed = set() if observed is None else set(observed)

        for label in observed:
            # beliefs of directly observed vars should be kronecker deltas
            assert is_kronecker_delta(self.nodes_dict[label].belief), \
                ("Observed label has belief {} but should be kronecker delta"
                 .format(self.nodes_dict[label].belief))

        vars_in_definite_state = self.get_variables_in_definite_state()
        assert observed <= vars_in_definite_state, \
            "Expected set of observed labels to be a subset of labels in definite state."
        return vars_in_definite_state - observed

    def _get_ancestors_of(self, labels):
        """
        Get set of ancestors of an iterable of labels.

        Args
            labels: iterable,
                label ids for which ancestors should be retrieved

        Returns
            ancestors: set,
                set of label ids of ancestors of the input labels
        """
        ancestors = set()
        for label in labels:
            ancestors.update(nx.ancestors(self, label))
        return ancestors

    def reachable_observed_variables(self, source, observed=None):
        """
        Get set of directly observed labels (labels with evidence in a definite
        state) that are reachable from the source.

        Args
            source: string,
                label of node for which to evaluate reachable observed labels
            observed: (optional) iterable,
                strings, directly observed labels; defaults to no observations
        Returns
            reachable_observed_vars: set,
                set of strings, observed labels (variables with direct evidence)
                that are reachable from the source label
        """
        # Normalise to a set: membership is tested for every visited node.
        observed = set() if observed is None else set(observed)

        ancestors_of_observed = self._get_ancestors_of(observed)
        ancestors_of_observed.update(observed)  # include observed labels

        # Each entry is (node, direction): 'up' means the node is reached from
        # one of its children, 'down' means it is reached from a parent.
        visit_list = {(source, 'up')}
        traversed_list = set()
        reachable_observed_vars = set()

        while visit_list:
            node, direction = visit_list.pop()
            if (node, direction) in traversed_list:
                continue
            traversed_list.add((node, direction))

            if node in observed:
                reachable_observed_vars.add(node)

            if direction == 'up':
                # An observed node blocks any further flow in this direction.
                if node not in observed:
                    for parent in self.predecessors(node):
                        # causal flow
                        visit_list.add((parent, 'up'))
                    for child in self.successors(node):
                        # common cause flow
                        visit_list.add((child, 'down'))
            elif direction == 'down':
                if node not in observed:
                    # evidential flow
                    for child in self.successors(node):
                        visit_list.add((child, 'down'))
                if node in ancestors_of_observed:
                    # common effect flow (activated v-structure)
                    for parent in self.predecessors(node):
                        visit_list.add((parent, 'up'))
        return reachable_observed_vars