diff options
Diffstat (limited to 'beliefs/models/base_models.py')
-rw-r--r-- | beliefs/models/base_models.py | 90 |
1 files changed, 49 insertions, 41 deletions
diff --git a/beliefs/models/base_models.py b/beliefs/models/base_models.py index cb91566..71af0cb 100644 --- a/beliefs/models/base_models.py +++ b/beliefs/models/base_models.py @@ -9,9 +9,11 @@ class DirectedGraph(nx.DiGraph): """ def __init__(self, edges=None, node_labels=None): """ - Input: - edges: an edge list, e.g. [(parent1, child1), (parent1, child2)] - node_labels: a list of strings of node labels + Args + edges: list, + a list of edge tuples, e.g. [(parent1, child1), (parent1, child2)] + node_labels: list, + a list of strings or integers representing node label ids """ super().__init__() if edges is not None: @@ -20,18 +22,15 @@ class DirectedGraph(nx.DiGraph): self.add_nodes_from(node_labels) def get_leaves(self): - """ - Returns a list of leaves of the graph. - """ + """Return a list of leaves of the graph""" return [node for node, out_degree in self.out_degree() if out_degree == 0] def get_roots(self): - """ - Returns a list of roots of the graph. - """ + """Return a list of roots of the graph""" return [node for node, in_degree in self.in_degree() if in_degree == 0] def get_topologically_sorted_nodes(self, reverse=False): + """Return a list of nodes in topological sort order""" if reverse: return list(reversed(list(nx.topological_sort(self)))) else: @@ -47,12 +46,12 @@ class BayesianModel(DirectedGraph): """ Base class for Bayesian model. - Input: - edges: (optional) list of edges, + Args + edges: (optional) list of edges, tuples of form ('parent', 'child') - variables: (optional) list of str or int + variables: (optional) list of str or int labels for variables - cpds: (optional) list of CPDs + cpds: (optional) list of CPDs TabularCPD class or subclass """ super().__init__() @@ -61,20 +60,17 @@ class BayesianModel(DirectedGraph): self.cpds = cpds def copy(self): - """ - Returns a copy of the model. - """ - copy_model = self.__class__(edges=list(self.edges()).copy(), - variables=list(self.nodes()).copy(), - cpds=[cpd.copy() for cpd in self.cpds]) - return copy_model + """Return a copy of the model""" + return self.__class__(edges=list(self.edges()).copy(), + variables=list(self.nodes()).copy(), + cpds=[cpd.copy() for cpd in self.cpds]) def get_variables_in_definite_state(self): """ - Returns a set of labels of all nodes in a definite state, i.e. with - label values that are kronecker deltas. + Get labels of all nodes in a definite state, i.e. with label values + that are kronecker deltas. - RETURNS + Returns set of strings (labels) """ return {label for label, node in self.nodes_dict.items() if is_kronecker_delta(node.belief)} @@ -84,14 +80,14 @@ class BayesianModel(DirectedGraph): Returns a set of labels that are inferred to be in definite state, given list of labels that were directly observed (e.g. YES/NOs, but not MAYBEs). - INPUT - observed: set of strings, directly observed labels - RETURNS - set of strings, labels inferred to be in a definite state + Args + observed: set, + set of strings, directly observed labels + Returns + set of strings, the labels inferred to be in a definite state """ - - # Assert that beliefs of directly observed vars are kronecker deltas for label in observed: + # beliefs of directly observed vars should be kronecker deltas assert is_kronecker_delta(self.nodes_dict[label].belief), \ ("Observed label has belief {} but should be kronecker delta" .format(self.nodes_dict[label].belief)) @@ -101,28 +97,40 @@ class BayesianModel(DirectedGraph): "Expected set of observed labels to be a subset of labels in definite state." return vars_in_definite_state - observed - def _get_ancestors_of(self, observed): - """Return list of ancestors of observed labels""" + def _get_ancestors_of(self, labels): + """ + Get set of ancestors of an iterable of labels. + + Args + observed: iterable, + label ids for which ancestors should be retrieved + + Returns + ancestors: set, + set of label ids of ancestors of the input labels + """ ancestors = set() - for label in observed: + for label in labels: ancestors.update(nx.ancestors(self, label)) return ancestors def reachable_observed_variables(self, source, observed=set()): """ - Returns list of observed labels (labels with direct evidence to be in a definite + Get list of directly observed labels (labels with evidence in a definite state) that are reachable from the source. - INPUT - source: string, label of node for which to evaluate reachable observed labels - observed: set of strings, directly observed labels - RETURNS - reachable_observed_vars: set of strings, observed labels (variables with direct - evidence) that are reachable from the source label. + Args + source: string, + label of node for which to evaluate reachable observed labels + observed: set, + set of strings, directly observed labels + Returns + reachable_observed_vars: set, + set of strings, observed labels (variables with direct evidence) + that are reachable from the source label """ - # ancestors of observed labels, including observed labels ancestors_of_observed = self._get_ancestors_of(observed) - ancestors_of_observed.update(observed) + ancestors_of_observed.update(observed) # include observed labels visit_list = set() visit_list.add((source, 'up')) |