aboutsummaryrefslogtreecommitdiff
path: root/beliefs/models/base_models.py
diff options
context:
space:
mode:
Diffstat (limited to 'beliefs/models/base_models.py')
-rw-r--r--beliefs/models/base_models.py154
1 files changed, 154 insertions, 0 deletions
diff --git a/beliefs/models/base_models.py b/beliefs/models/base_models.py
new file mode 100644
index 0000000..cb91566
--- /dev/null
+++ b/beliefs/models/base_models.py
@@ -0,0 +1,154 @@
+import networkx as nx
+
+from beliefs.utils.math_helper import is_kronecker_delta
+
+
+class DirectedGraph(nx.DiGraph):
+ """
+ Base class for all directed graphical models.
+ """
+ def __init__(self, edges=None, node_labels=None):
+ """
+ Input:
+ edges: an edge list, e.g. [(parent1, child1), (parent1, child2)]
+ node_labels: a list of strings of node labels
+ """
+ super().__init__()
+ if edges is not None:
+ self.add_edges_from(edges)
+ if node_labels is not None:
+ self.add_nodes_from(node_labels)
+
+ def get_leaves(self):
+ """
+ Returns a list of leaves of the graph.
+ """
+ return [node for node, out_degree in self.out_degree() if out_degree == 0]
+
+ def get_roots(self):
+ """
+ Returns a list of roots of the graph.
+ """
+ return [node for node, in_degree in self.in_degree() if in_degree == 0]
+
+ def get_topologically_sorted_nodes(self, reverse=False):
+ if reverse:
+ return list(reversed(list(nx.topological_sort(self))))
+ else:
+ return nx.topological_sort(self)
+
+
+class BayesianModel(DirectedGraph):
+ """
+ Bayesian model stores nodes and edges described by conditional probability
+ distributions.
+ """
+ def __init__(self, edges=[], variables=[], cpds=[]):
+ """
+ Base class for Bayesian model.
+
+ Input:
+ edges: (optional) list of edges,
+ tuples of form ('parent', 'child')
+ variables: (optional) list of str or int
+ labels for variables
+ cpds: (optional) list of CPDs
+ TabularCPD class or subclass
+ """
+ super().__init__()
+ super().add_edges_from(edges)
+ super().add_nodes_from(variables)
+ self.cpds = cpds
+
+ def copy(self):
+ """
+ Returns a copy of the model.
+ """
+ copy_model = self.__class__(edges=list(self.edges()).copy(),
+ variables=list(self.nodes()).copy(),
+ cpds=[cpd.copy() for cpd in self.cpds])
+ return copy_model
+
+ def get_variables_in_definite_state(self):
+ """
+ Returns a set of labels of all nodes in a definite state, i.e. with
+ label values that are kronecker deltas.
+
+ RETURNS
+ set of strings (labels)
+ """
+ return {label for label, node in self.nodes_dict.items() if is_kronecker_delta(node.belief)}
+
+ def get_unobserved_variables_in_definite_state(self, observed=set()):
+ """
+ Returns a set of labels that are inferred to be in definite state, given
+ list of labels that were directly observed (e.g. YES/NOs, but not MAYBEs).
+
+ INPUT
+ observed: set of strings, directly observed labels
+ RETURNS
+ set of strings, labels inferred to be in a definite state
+ """
+
+ # Assert that beliefs of directly observed vars are kronecker deltas
+ for label in observed:
+ assert is_kronecker_delta(self.nodes_dict[label].belief), \
+ ("Observed label has belief {} but should be kronecker delta"
+ .format(self.nodes_dict[label].belief))
+
+ vars_in_definite_state = self.get_variables_in_definite_state()
+ assert observed <= vars_in_definite_state, \
+ "Expected set of observed labels to be a subset of labels in definite state."
+ return vars_in_definite_state - observed
+
+ def _get_ancestors_of(self, observed):
+ """Return list of ancestors of observed labels"""
+ ancestors = set()
+ for label in observed:
+ ancestors.update(nx.ancestors(self, label))
+ return ancestors
+
+ def reachable_observed_variables(self, source, observed=set()):
+ """
+ Returns list of observed labels (labels with direct evidence to be in a definite
+ state) that are reachable from the source.
+
+ INPUT
+ source: string, label of node for which to evaluate reachable observed labels
+ observed: set of strings, directly observed labels
+ RETURNS
+ reachable_observed_vars: set of strings, observed labels (variables with direct
+ evidence) that are reachable from the source label.
+ """
+ # ancestors of observed labels, including observed labels
+ ancestors_of_observed = self._get_ancestors_of(observed)
+ ancestors_of_observed.update(observed)
+
+ visit_list = set()
+ visit_list.add((source, 'up'))
+ traversed_list = set()
+ reachable_observed_vars = set()
+
+ while visit_list:
+ node, direction = visit_list.pop()
+ if (node, direction) not in traversed_list:
+ if node in observed:
+ reachable_observed_vars.add(node)
+ traversed_list.add((node, direction))
+ if direction == 'up' and node not in observed:
+ for parent in self.predecessors(node):
+ # causal flow
+ visit_list.add((parent, 'up'))
+ for child in self.successors(node):
+ # common cause flow
+ visit_list.add((child, 'down'))
+ elif direction == 'down':
+ if node not in observed:
+ # evidential flow
+ for child in self.successors(node):
+ visit_list.add((child, 'down'))
+ if node in ancestors_of_observed:
+ # common effect flow (activated v-structure)
+ for parent in self.predecessors(node):
+ visit_list.add((parent, 'up'))
+ return reachable_observed_vars