diff options
author | Cathy Yeh <cathy@driver.xyz> | 2017-11-13 14:42:52 -0800 |
---|---|---|
committer | Cathy Yeh <cathy@driver.xyz> | 2017-11-17 13:48:16 -0800 |
commit | 77d8b323d4f6e05ca97d9cbef43ac85fd8040d61 (patch) | |
tree | bd589afff10efce13b6f017e544958454f3a8ef7 /beliefs/models/BayesianModel.py | |
parent | 6a1b35f5bf122232d058ed0f3ea19c15629c0cbc (diff) | |
download | beliefs-77d8b323d4f6e05ca97d9cbef43ac85fd8040d61.tar.gz beliefs-77d8b323d4f6e05ca97d9cbef43ac85fd8040d61.tar.bz2 beliefs-77d8b323d4f6e05ca97d9cbef43ac85fd8040d61.zip |
copy scripts from lgs branch
Diffstat (limited to 'beliefs/models/BayesianModel.py')
-rw-r--r-- | beliefs/models/BayesianModel.py | 165 |
1 files changed, 165 insertions, 0 deletions
diff --git a/beliefs/models/BayesianModel.py b/beliefs/models/BayesianModel.py new file mode 100644 index 0000000..bdfd037 --- /dev/null +++ b/beliefs/models/BayesianModel.py @@ -0,0 +1,165 @@ +import copy +import numpy as np +import networkx as nx + +from beliefs.models.DirectedGraph import DirectedGraph +from beliefs.utils.edges_helper import EdgesHelper +from beliefs.utils.math_helper import is_kronecker_delta + + +class BayesianModel(DirectedGraph): + """ + Bayesian model stores nodes and edges described by conditional probability + distributions. + """ + def __init__(self, edges, nodes=None): + """ + Input: + edges: list of edge tuples of form ('parent', 'child') + nodes: (optional) dict + a dict key, value pair as {label_id: instance_of_node_class_or_subclass} + """ + if nodes is not None: + super().__init__(edges, nodes.keys()) + else: + super().__init__(edges) + self.nodes = nodes + + @classmethod + def from_node_class(cls, edges, node_class): + """Automatically create all nodes from the same node class + + Input: + edges: list of edge tuples of form ('parent', 'child') + node_class: (optional) the Node class or subclass from which to + create all the nodes from edges. + """ + nodes = cls.create_nodes(edges, node_class) + return cls.__init__(edges=edges, nodes=nodes) + + @staticmethod + def create_nodes(edges, node_class): + """Returns list of Node instances created from edges using + the default node_class""" + edges_helper = EdgesHelper(edges) + nodes = edges_helper.create_nodes_from_edges(node_class=node_class) + label_to_node = dict() + for node in nodes: + label_to_node[node.label_id] = node + return label_to_node + + def set_boundary_conditions(self): + """ + 1. Root nodes: if x is a node with no parents, set Pi(x) = prior + probability of x. + + 2. Leaf nodes: if x is a node with no children, set Lambda(x) + to an (unnormalized) unit vector, of length the cardinality of x. + """ + for root in self.get_roots(): + self.nodes[root].pi_agg = self.nodes[root].cpd.values + + for leaf in self.get_leaves(): + self.nodes[leaf].lambda_agg = np.ones([self.nodes[leaf].cardinality]) + + @property + def all_nodes_are_fully_initialized(self): + """ + Returns True if, for all nodes in the model, all lambda and pi + messages and lambda_agg and pi_agg are not None, else False. + """ + for node in self.nodes.values(): + if not node.is_fully_initialized: + return False + return True + + def copy(self): + """ + Returns a copy of the model. + """ + copy_edges = self.edges().copy() + copy_nodes = copy.deepcopy(self.nodes) + copy_model = self.__class__(edges=copy_edges, nodes=copy_nodes) + return copy_model + + def get_variables_in_definite_state(self): + """ + Returns a set of labels of all nodes in a definite state, i.e. with + label values that are kronecker deltas. + + RETURNS + set of strings (labels) + """ + return {label for label, node in self.nodes.items() if is_kronecker_delta(node.belief)} + + def get_unobserved_variables_in_definite_state(self, observed=set()): + """ + Returns a set of labels that are inferred to be in definite state, given + list of labels that were directly observed (e.g. YES/NOs, but not MAYBEs). + + INPUT + observed: set of strings, directly observed labels + RETURNS + set of strings, labels inferred to be in a definite state + """ + + # Assert that beliefs of directly observed vars are kronecker deltas + for label in observed: + assert is_kronecker_delta(self.nodes[label].belief), \ + ("Observed label has belief {} but should be kronecker delta" + .format(self.nodes[label].belief)) + + vars_in_definite_state = self.get_variables_in_definite_state() + assert observed <= vars_in_definite_state, \ + "Expected set of observed labels to be a subset of labels in definite state." + return vars_in_definite_state - observed + + def _get_ancestors_of(self, observed): + """Return list of ancestors of observed labels, including the observed labels themselves.""" + ancestors = observed.copy() + for label in observed: + ancestors.update(nx.ancestors(self, label)) + return ancestors + + def reachable_observed_variables(self, source, observed=set()): + """ + Returns list of observed labels (labels with direct evidence to be in a definite + state) that are reachable from the source. + + INPUT + source: string, label of node for which to evaluate reachable observed labels + observed: set of strings, directly observed labels + RETURNS + reachable_observed_vars: set of strings, observed labels (variables with direct + evidence) that are reachable from the source label. + """ + ancestors_of_observed = self._get_ancestors_of(observed) + + visit_list = set() + visit_list.add((source, 'up')) + traversed_list = set() + reachable_observed_vars = set() + + while visit_list: + node, direction = visit_list.pop() + if (node, direction) not in traversed_list: + if node in observed: + reachable_observed_vars.add(node) + traversed_list.add((node, direction)) + if direction == 'up' and node not in observed: + for parent in self.predecessors(node): + # causal flow + visit_list.add((parent, 'up')) + for child in self.successors(node): + # common cause flow + visit_list.add((child, 'down')) + elif direction == 'down': + if node not in observed: + # evidential flow + for child in self.successors(node): + visit_list.add((child, 'down')) + if node in ancestors_of_observed: + # common effect flow (activated v-structure) + for parent in self.predecessors(node): + visit_list.add((parent, 'up')) + return reachable_observed_vars |