import networkx as nx from beliefs.utils.math_helper import is_kronecker_delta class DirectedGraph(nx.DiGraph): """ Base class for all directed graphical models. """ def __init__(self, edges=None, node_labels=None): """ Input: edges: an edge list, e.g. [(parent1, child1), (parent1, child2)] node_labels: a list of strings of node labels """ super().__init__() if edges is not None: self.add_edges_from(edges) if node_labels is not None: self.add_nodes_from(node_labels) def get_leaves(self): """ Returns a list of leaves of the graph. """ return [node for node, out_degree in self.out_degree() if out_degree == 0] def get_roots(self): """ Returns a list of roots of the graph. """ return [node for node, in_degree in self.in_degree() if in_degree == 0] def get_topologically_sorted_nodes(self, reverse=False): if reverse: return list(reversed(list(nx.topological_sort(self)))) else: return nx.topological_sort(self) class BayesianModel(DirectedGraph): """ Bayesian model stores nodes and edges described by conditional probability distributions. """ def __init__(self, edges=[], variables=[], cpds=[]): """ Base class for Bayesian model. Input: edges: (optional) list of edges, tuples of form ('parent', 'child') variables: (optional) list of str or int labels for variables cpds: (optional) list of CPDs TabularCPD class or subclass """ super().__init__() super().add_edges_from(edges) super().add_nodes_from(variables) self.cpds = cpds def copy(self): """ Returns a copy of the model. """ copy_model = self.__class__(edges=list(self.edges()).copy(), variables=list(self.nodes()).copy(), cpds=[cpd.copy() for cpd in self.cpds]) return copy_model def get_variables_in_definite_state(self): """ Returns a set of labels of all nodes in a definite state, i.e. with label values that are kronecker deltas. RETURNS set of strings (labels) """ return {label for label, node in self.nodes_dict.items() if is_kronecker_delta(node.belief)} def get_unobserved_variables_in_definite_state(self, observed=set()): """ Returns a set of labels that are inferred to be in definite state, given list of labels that were directly observed (e.g. YES/NOs, but not MAYBEs). INPUT observed: set of strings, directly observed labels RETURNS set of strings, labels inferred to be in a definite state """ # Assert that beliefs of directly observed vars are kronecker deltas for label in observed: assert is_kronecker_delta(self.nodes_dict[label].belief), \ ("Observed label has belief {} but should be kronecker delta" .format(self.nodes_dict[label].belief)) vars_in_definite_state = self.get_variables_in_definite_state() assert observed <= vars_in_definite_state, \ "Expected set of observed labels to be a subset of labels in definite state." return vars_in_definite_state - observed def _get_ancestors_of(self, observed): """Return list of ancestors of observed labels""" ancestors = set() for label in observed: ancestors.update(nx.ancestors(self, label)) return ancestors def reachable_observed_variables(self, source, observed=set()): """ Returns list of observed labels (labels with direct evidence to be in a definite state) that are reachable from the source. INPUT source: string, label of node for which to evaluate reachable observed labels observed: set of strings, directly observed labels RETURNS reachable_observed_vars: set of strings, observed labels (variables with direct evidence) that are reachable from the source label. """ # ancestors of observed labels, including observed labels ancestors_of_observed = self._get_ancestors_of(observed) ancestors_of_observed.update(observed) visit_list = set() visit_list.add((source, 'up')) traversed_list = set() reachable_observed_vars = set() while visit_list: node, direction = visit_list.pop() if (node, direction) not in traversed_list: if node in observed: reachable_observed_vars.add(node) traversed_list.add((node, direction)) if direction == 'up' and node not in observed: for parent in self.predecessors(node): # causal flow visit_list.add((parent, 'up')) for child in self.successors(node): # common cause flow visit_list.add((child, 'down')) elif direction == 'down': if node not in observed: # evidential flow for child in self.successors(node): visit_list.add((child, 'down')) if node in ancestors_of_observed: # common effect flow (activated v-structure) for parent in self.predecessors(node): visit_list.add((parent, 'up')) return reachable_observed_vars