From 8dc7ae89677fca16ee974a30cff8c4df53c955ce Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Sun, 3 Dec 2017 19:16:32 -0800 Subject: PR comments --- beliefs/models/base_models.py | 154 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) create mode 100644 beliefs/models/base_models.py (limited to 'beliefs/models/base_models.py') diff --git a/beliefs/models/base_models.py b/beliefs/models/base_models.py new file mode 100644 index 0000000..cb91566 --- /dev/null +++ b/beliefs/models/base_models.py @@ -0,0 +1,154 @@ +import networkx as nx + +from beliefs.utils.math_helper import is_kronecker_delta + + +class DirectedGraph(nx.DiGraph): + """ + Base class for all directed graphical models. + """ + def __init__(self, edges=None, node_labels=None): + """ + Input: + edges: an edge list, e.g. [(parent1, child1), (parent1, child2)] + node_labels: a list of strings of node labels + """ + super().__init__() + if edges is not None: + self.add_edges_from(edges) + if node_labels is not None: + self.add_nodes_from(node_labels) + + def get_leaves(self): + """ + Returns a list of leaves of the graph. + """ + return [node for node, out_degree in self.out_degree() if out_degree == 0] + + def get_roots(self): + """ + Returns a list of roots of the graph. + """ + return [node for node, in_degree in self.in_degree() if in_degree == 0] + + def get_topologically_sorted_nodes(self, reverse=False): + if reverse: + return list(reversed(list(nx.topological_sort(self)))) + else: + return nx.topological_sort(self) + + +class BayesianModel(DirectedGraph): + """ + Bayesian model stores nodes and edges described by conditional probability + distributions. + """ + def __init__(self, edges=[], variables=[], cpds=[]): + """ + Base class for Bayesian model. + + Input: + edges: (optional) list of edges, + tuples of form ('parent', 'child') + variables: (optional) list of str or int + labels for variables + cpds: (optional) list of CPDs + TabularCPD class or subclass + """ + super().__init__() + super().add_edges_from(edges) + super().add_nodes_from(variables) + self.cpds = cpds + + def copy(self): + """ + Returns a copy of the model. + """ + copy_model = self.__class__(edges=list(self.edges()).copy(), + variables=list(self.nodes()).copy(), + cpds=[cpd.copy() for cpd in self.cpds]) + return copy_model + + def get_variables_in_definite_state(self): + """ + Returns a set of labels of all nodes in a definite state, i.e. with + label values that are kronecker deltas. + + RETURNS + set of strings (labels) + """ + return {label for label, node in self.nodes_dict.items() if is_kronecker_delta(node.belief)} + + def get_unobserved_variables_in_definite_state(self, observed=set()): + """ + Returns a set of labels that are inferred to be in definite state, given + list of labels that were directly observed (e.g. YES/NOs, but not MAYBEs). + + INPUT + observed: set of strings, directly observed labels + RETURNS + set of strings, labels inferred to be in a definite state + """ + + # Assert that beliefs of directly observed vars are kronecker deltas + for label in observed: + assert is_kronecker_delta(self.nodes_dict[label].belief), \ + ("Observed label has belief {} but should be kronecker delta" + .format(self.nodes_dict[label].belief)) + + vars_in_definite_state = self.get_variables_in_definite_state() + assert observed <= vars_in_definite_state, \ + "Expected set of observed labels to be a subset of labels in definite state." + return vars_in_definite_state - observed + + def _get_ancestors_of(self, observed): + """Return list of ancestors of observed labels""" + ancestors = set() + for label in observed: + ancestors.update(nx.ancestors(self, label)) + return ancestors + + def reachable_observed_variables(self, source, observed=set()): + """ + Returns list of observed labels (labels with direct evidence to be in a definite + state) that are reachable from the source. + + INPUT + source: string, label of node for which to evaluate reachable observed labels + observed: set of strings, directly observed labels + RETURNS + reachable_observed_vars: set of strings, observed labels (variables with direct + evidence) that are reachable from the source label. + """ + # ancestors of observed labels, including observed labels + ancestors_of_observed = self._get_ancestors_of(observed) + ancestors_of_observed.update(observed) + + visit_list = set() + visit_list.add((source, 'up')) + traversed_list = set() + reachable_observed_vars = set() + + while visit_list: + node, direction = visit_list.pop() + if (node, direction) not in traversed_list: + if node in observed: + reachable_observed_vars.add(node) + traversed_list.add((node, direction)) + if direction == 'up' and node not in observed: + for parent in self.predecessors(node): + # causal flow + visit_list.add((parent, 'up')) + for child in self.successors(node): + # common cause flow + visit_list.add((child, 'down')) + elif direction == 'down': + if node not in observed: + # evidential flow + for child in self.successors(node): + visit_list.add((child, 'down')) + if node in ancestors_of_observed: + # common effect flow (activated v-structure) + for parent in self.predecessors(node): + visit_list.add((parent, 'up')) + return reachable_observed_vars -- cgit v1.2.3