aboutsummaryrefslogtreecommitdiff
path: root/beliefs/models/base_models.py
diff options
context:
space:
mode:
Diffstat (limited to 'beliefs/models/base_models.py')
-rw-r--r--beliefs/models/base_models.py90
1 files changed, 49 insertions, 41 deletions
diff --git a/beliefs/models/base_models.py b/beliefs/models/base_models.py
index cb91566..71af0cb 100644
--- a/beliefs/models/base_models.py
+++ b/beliefs/models/base_models.py
@@ -9,9 +9,11 @@ class DirectedGraph(nx.DiGraph):
"""
def __init__(self, edges=None, node_labels=None):
"""
- Input:
- edges: an edge list, e.g. [(parent1, child1), (parent1, child2)]
- node_labels: a list of strings of node labels
+ Args
+ edges: list,
+ a list of edge tuples, e.g. [(parent1, child1), (parent1, child2)]
+ node_labels: list,
+ a list of strings or integers representing node label ids
"""
super().__init__()
if edges is not None:
@@ -20,18 +22,15 @@ class DirectedGraph(nx.DiGraph):
self.add_nodes_from(node_labels)
def get_leaves(self):
- """
- Returns a list of leaves of the graph.
- """
+ """Return a list of leaves of the graph"""
return [node for node, out_degree in self.out_degree() if out_degree == 0]
def get_roots(self):
- """
- Returns a list of roots of the graph.
- """
+ """Return a list of roots of the graph"""
return [node for node, in_degree in self.in_degree() if in_degree == 0]
def get_topologically_sorted_nodes(self, reverse=False):
+ """Return a list of nodes in topological sort order"""
if reverse:
return list(reversed(list(nx.topological_sort(self))))
else:
@@ -47,12 +46,12 @@ class BayesianModel(DirectedGraph):
"""
Base class for Bayesian model.
- Input:
- edges: (optional) list of edges,
+ Args
+ edges: (optional) list of edges,
tuples of form ('parent', 'child')
- variables: (optional) list of str or int
+ variables: (optional) list of str or int
labels for variables
- cpds: (optional) list of CPDs
+ cpds: (optional) list of CPDs
TabularCPD class or subclass
"""
super().__init__()
@@ -61,20 +60,17 @@ class BayesianModel(DirectedGraph):
self.cpds = cpds
def copy(self):
- """
- Returns a copy of the model.
- """
- copy_model = self.__class__(edges=list(self.edges()).copy(),
- variables=list(self.nodes()).copy(),
- cpds=[cpd.copy() for cpd in self.cpds])
- return copy_model
+ """Return a copy of the model"""
+ return self.__class__(edges=list(self.edges()).copy(),
+ variables=list(self.nodes()).copy(),
+ cpds=[cpd.copy() for cpd in self.cpds])
def get_variables_in_definite_state(self):
"""
- Returns a set of labels of all nodes in a definite state, i.e. with
- label values that are kronecker deltas.
+ Get labels of all nodes in a definite state, i.e. with label values
+ that are kronecker deltas.
- RETURNS
+ Returns
set of strings (labels)
"""
return {label for label, node in self.nodes_dict.items() if is_kronecker_delta(node.belief)}
@@ -84,14 +80,14 @@ class BayesianModel(DirectedGraph):
Returns a set of labels that are inferred to be in definite state, given
list of labels that were directly observed (e.g. YES/NOs, but not MAYBEs).
- INPUT
- observed: set of strings, directly observed labels
- RETURNS
- set of strings, labels inferred to be in a definite state
+ Args
+ observed: set,
+ set of strings, directly observed labels
+ Returns
+ set of strings, the labels inferred to be in a definite state
"""
-
- # Assert that beliefs of directly observed vars are kronecker deltas
for label in observed:
+ # beliefs of directly observed vars should be kronecker deltas
assert is_kronecker_delta(self.nodes_dict[label].belief), \
("Observed label has belief {} but should be kronecker delta"
.format(self.nodes_dict[label].belief))
@@ -101,28 +97,40 @@ class BayesianModel(DirectedGraph):
"Expected set of observed labels to be a subset of labels in definite state."
return vars_in_definite_state - observed
- def _get_ancestors_of(self, observed):
- """Return list of ancestors of observed labels"""
+ def _get_ancestors_of(self, labels):
+ """
+ Get set of ancestors of an iterable of labels.
+
+ Args
+ observed: iterable,
+ label ids for which ancestors should be retrieved
+
+ Returns
+ ancestors: set,
+ set of label ids of ancestors of the input labels
+ """
ancestors = set()
- for label in observed:
+ for label in labels:
ancestors.update(nx.ancestors(self, label))
return ancestors
def reachable_observed_variables(self, source, observed=set()):
"""
- Returns list of observed labels (labels with direct evidence to be in a definite
+ Get list of directly observed labels (labels with evidence in a definite
state) that are reachable from the source.
- INPUT
- source: string, label of node for which to evaluate reachable observed labels
- observed: set of strings, directly observed labels
- RETURNS
- reachable_observed_vars: set of strings, observed labels (variables with direct
- evidence) that are reachable from the source label.
+ Args
+ source: string,
+ label of node for which to evaluate reachable observed labels
+ observed: set,
+ set of strings, directly observed labels
+ Returns
+ reachable_observed_vars: set,
+ set of strings, observed labels (variables with direct evidence)
+ that are reachable from the source label
"""
- # ancestors of observed labels, including observed labels
ancestors_of_observed = self._get_ancestors_of(observed)
- ancestors_of_observed.update(observed)
+ ancestors_of_observed.update(observed) # include observed labels
visit_list = set()
visit_list.add((source, 'up'))