aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCathy Yeh <cathy@driver.xyz>2017-12-03 20:38:28 -0800
committerCathy Yeh <cathy@driver.xyz>2017-12-03 20:38:28 -0800
commit26b43410569044aff46053cae7c68862825dd4ec (patch)
treeb184df84d416e2ddf837b25baadff4f9feaaa250
parent6a1b35f5bf122232d058ed0f3ea19c15629c0cbc (diff)
parentc906bd37fba63ba706cc3b7802bfb18ffb05ee9a (diff)
downloadbeliefs-26b43410569044aff46053cae7c68862825dd4ec.tar.gz
beliefs-26b43410569044aff46053cae7c68862825dd4ec.tar.bz2
beliefs-26b43410569044aff46053cae7c68862825dd4ec.zip
LGS-164 belief propagation for polytrees, special case of OR cpds, refactored from LGSv0.0.2
-rw-r--r--Makefile2
-rw-r--r--VERSION2
-rw-r--r--beliefs/factors/__init__.py0
-rw-r--r--beliefs/factors/bernoulli_or_cpd.py42
-rw-r--r--beliefs/factors/cpd.py45
-rw-r--r--beliefs/inference/__init__.py0
-rw-r--r--beliefs/inference/belief_propagation.py201
-rw-r--r--beliefs/models/DirectedGraph.py35
-rw-r--r--beliefs/models/base_models.py154
-rw-r--r--beliefs/models/belief_update_node_model.py315
-rw-r--r--beliefs/utils/__init__.py0
-rw-r--r--beliefs/utils/math_helper.py19
-rw-r--r--beliefs/utils/random_variables.py21
-rw-r--r--tests/test_belief_propagation.py255
-rw-r--r--tests/test_get_reachable_observed_variables.py129
15 files changed, 1183 insertions, 37 deletions
diff --git a/Makefile b/Makefile
index 805f519..1f33fc4 100644
--- a/Makefile
+++ b/Makefile
@@ -104,7 +104,7 @@ test-in-clean-env: verify-conda-build-installed
# run tests in the current environment
test-in-current-env:
git lfs fetch
- echo TEST
+ pytest tests -vv
####################################################################################################
# helper commands
diff --git a/VERSION b/VERSION
index 8acdd82..4e379d2 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.0.1
+0.0.2
diff --git a/beliefs/factors/__init__.py b/beliefs/factors/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/beliefs/factors/__init__.py
diff --git a/beliefs/factors/bernoulli_or_cpd.py b/beliefs/factors/bernoulli_or_cpd.py
new file mode 100644
index 0000000..bfb3a95
--- /dev/null
+++ b/beliefs/factors/bernoulli_or_cpd.py
@@ -0,0 +1,42 @@
+import numpy as np
+
+from beliefs.factors.cpd import TabularCPD
+
+
+class BernoulliOrCPD(TabularCPD):
+ """CPD class for a Bernoulli random variable whose relationship to its
+ parents (also Bernoulli random variables) is described by OR logic.
+
+ If at least one of the variable's parents is True, then the variable
+ is True, and False otherwise.
+ """
+ def __init__(self, variable, parents=[]):
+ """
+ Args:
+ variable: int or string
+ parents: optional, list of int and/or strings
+ """
+ super().__init__(variable=variable,
+ variable_card=2,
+ parents=parents,
+ parents_card=[2]*len(parents),
+ values=[])
+ self._values = []
+
+ @property
+ def values(self):
+ if not any(self._values):
+ self._values = self._build_kwise_values_array(len(self.variables))
+ self._values = self._values.reshape(self.cardinality)
+ return self._values
+
+ @staticmethod
+ def _build_kwise_values_array(k):
+ # special case a completely independent factor, and
+ # return the uniform prior
+ if k == 1:
+ return np.array([0.5, 0.5])
+
+ return np.array(
+ [1.,] + [0.]*(2**(k-1)-1) + [0.,] + [1.]*(2**(k-1)-1)
+ )
diff --git a/beliefs/factors/cpd.py b/beliefs/factors/cpd.py
new file mode 100644
index 0000000..a286aaa
--- /dev/null
+++ b/beliefs/factors/cpd.py
@@ -0,0 +1,45 @@
+import numpy as np
+
+
+class TabularCPD:
+ """
+ Defines the conditional probability table for a discrete variable
+ whose parents are also discrete.
+
+ TODO: have this inherit from DiscreteFactor implementing explicit factor methods
+ """
+ def __init__(self, variable, variable_card,
+ parents=[], parents_card=[], values=[]):
+ """
+ Args:
+ variable: int or string
+ variable_card: int
+ parents: optional, list of int and/or strings
+ parents_card: optional, list of int
+ values: optional, 2d list or array
+ """
+ self.variable = variable
+ self.parents = parents
+ self.variables = [variable] + parents
+ self.cardinality = [variable_card] + parents_card
+ self._values = np.array(values)
+
+ @property
+ def values(self):
+ return self._values
+
+ def get_values(self):
+ """
+ Returns the tabular cpd form of the values.
+ """
+ if len(self.cardinality) == 1:
+ return self.values.reshape(1, np.prod(self.cardinality))
+ else:
+ return self.values.reshape(self.cardinality[0], np.prod(self.cardinality[1:]))
+
+ def copy(self):
+ return self.__class__(self.variable,
+ self.cardinality[0],
+ self.parents,
+ self.cardinality[1:],
+ self._values)
diff --git a/beliefs/inference/__init__.py b/beliefs/inference/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/beliefs/inference/__init__.py
diff --git a/beliefs/inference/belief_propagation.py b/beliefs/inference/belief_propagation.py
new file mode 100644
index 0000000..7ec648d
--- /dev/null
+++ b/beliefs/inference/belief_propagation.py
@@ -0,0 +1,201 @@
+import numpy as np
+from collections import namedtuple
+import logging
+
+from beliefs.models.belief_update_node_model import (
+ InvalidLambdaMsgToParent,
+ BeliefUpdateNodeModel
+)
+from beliefs.utils.math_helper import is_kronecker_delta
+
+
+logger = logging.getLogger(__name__)
+
+
+MsgPassers = namedtuple('MsgPassers', ['msg_receiver', 'msg_sender'])
+
+
+class ConflictingEvidenceError(Exception):
+ """Failed to run belief propagation on label graph because of conflicting evidence."""
+ def __init__(self, evidence):
+ message = (
+ "Can't run belief propagation with conflicting evidence: {}"
+ .format(evidence)
+ )
+ super().__init__(message)
+
+
+class BeliefPropagation:
+ def __init__(self, model, inplace=True):
+ """
+ Input:
+ model: an instance of BeliefUpdateNodeModel
+ inplace: bool
+ modify in-place the nodes in the model during belief propagation
+ """
+ if not isinstance(model, BeliefUpdateNodeModel):
+ raise TypeError("Model must be an instance of BeliefUpdateNodeModel")
+ if inplace is False:
+ self.model = model.copy()
+ else:
+ self.model = model
+
+ def _belief_propagation(self, nodes_to_update, evidence):
+ """
+ Implementation of Pearl's belief propagation algorithm for polytrees.
+
+ ref: "Fusion, Propagation, and Structuring in Belief Networks"
+ Artificial Intelligence 29 (1986) 241-288
+
+ Input:
+ nodes_to_update: list
+ list of MsgPasser namedtuples.
+ evidence: dict,
+ a dict key, value pair as {var: state_of_var observed}
+ """
+ if len(nodes_to_update) == 0:
+ return
+
+ node_to_update_label_id, msg_sender_label_id = nodes_to_update.pop()
+ logging.info("Node: %s", node_to_update_label_id)
+
+ node = self.model.nodes_dict[node_to_update_label_id]
+
+ # exclude the message sender (either a parent or child) from getting an
+ # outgoing msg from the node to update
+ parent_ids = set(node.parents) - set([msg_sender_label_id])
+ child_ids = set(node.children) - set([msg_sender_label_id])
+ logging.info("parent_ids: %s", str(parent_ids))
+ logging.info("child_ids: %s", str(child_ids))
+
+ if msg_sender_label_id is not None:
+ # update triggered by receiving a message, not pinning to evidence
+ assert len(node.parents) + len(node.children) - 1 == len(parent_ids) + len(child_ids)
+
+ if node_to_update_label_id not in evidence:
+ node.compute_pi_agg()
+ logging.info("belief propagation pi_agg: %s", np.array2string(node.pi_agg))
+ node.compute_lambda_agg()
+ logging.info("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg))
+
+ for parent_id in parent_ids:
+ try:
+ new_lambda_msg = node.compute_lambda_msg_to_parent(parent_k=parent_id)
+ except InvalidLambdaMsgToParent:
+ raise ConflictingEvidenceError(evidence=evidence)
+
+ parent_node = self.model.nodes_dict[parent_id]
+ parent_node.update_lambda_msg_from_child(child=node_to_update_label_id,
+ new_value=new_lambda_msg)
+ nodes_to_update.add(MsgPassers(msg_receiver=parent_id,
+ msg_sender=node_to_update_label_id))
+
+ for child_id in child_ids:
+ new_pi_msg = node.compute_pi_msg_to_child(child_k=child_id)
+ child_node = self.model.nodes_dict[child_id]
+ child_node.update_pi_msg_from_parent(parent=node_to_update_label_id,
+ new_value=new_pi_msg)
+ nodes_to_update.add(MsgPassers(msg_receiver=child_id,
+ msg_sender=node_to_update_label_id))
+
+ self._belief_propagation(nodes_to_update, evidence)
+
+ def initialize_model(self):
+ """
+ Apply boundary conditions:
+ - Set pi_agg equal to prior probabilities for root nodes.
+ - Set lambda_agg equal to vector of ones for leaf nodes.
+
+ - Set lambda_agg, lambda_received_msgs to vectors of ones (same effect as
+ actually passing lambda messages up from leaf nodes to root nodes).
+ - Calculate pi_agg and pi_received_msgs for all nodes without evidence.
+ (Without evidence, belief equals pi_agg.)
+ """
+ self.model.set_boundary_conditions()
+
+ for node in self.model.nodes_dict.values():
+ ones_vector = np.ones([node.cardinality])
+
+ node.lambda_agg = ones_vector
+ for child in node.lambda_received_msgs.keys():
+ node.update_lambda_msg_from_child(child=child,
+ new_value=ones_vector)
+ logging.info("Finished initializing Lambda(x) and lambda_received_msgs per node.")
+
+ logging.info("Start downward sweep from nodes. Sending Pi messages only.")
+ topdown_order = self.model.get_topologically_sorted_nodes(reverse=False)
+
+ for node_id in topdown_order:
+ logging.info('label in iteration through top-down order: %s', str(node_id))
+
+ node_sending_msg = self.model.nodes_dict[node_id]
+ child_ids = node_sending_msg.children
+
+ if node_sending_msg.pi_agg is None:
+ node_sending_msg.compute_pi_agg()
+
+ for child_id in child_ids:
+ logging.info("child: %s", str(child_id))
+ new_pi_msg = node_sending_msg.compute_pi_msg_to_child(child_k=child_id)
+ logging.info("new_pi_msg: %s", np.array2string(new_pi_msg))
+
+ child_node = self.model.nodes_dict[child_id]
+ child_node.update_pi_msg_from_parent(parent=node_id,
+ new_value=new_pi_msg)
+
+ def _run_belief_propagation(self, evidence):
+ """
+ Input:
+ evidence: dict
+ a dict key, value pair as {var: state_of_var observed}
+ """
+ for evidence_id, observed_value in evidence.items():
+ nodes_to_update = set()
+
+ if evidence_id not in self.model.nodes_dict.keys():
+ raise KeyError("Evidence supplied for non-existent label_id: {}"
+ .format(evidence_id))
+
+ if is_kronecker_delta(observed_value):
+ # specific evidence
+ self.model.nodes_dict[evidence_id].lambda_agg = observed_value
+ else:
+ # virtual evidence
+ self.model.nodes_dict[evidence_id].lambda_agg = \
+ self.model.nodes_dict[evidence_id].lambda_agg * observed_value
+
+ nodes_to_update = [MsgPassers(msg_receiver=evidence_id, msg_sender=None)]
+
+ self._belief_propagation(nodes_to_update=set(nodes_to_update),
+ evidence=evidence)
+
+ def query(self, evidence={}):
+ """
+ Run belief propagation given evidence.
+
+ Input:
+ evidence: dict
+ a dict key, value pair as {var: state_of_var observed},
+ e.g. {'3': np.array([0,1])} if label '3' is True.
+
+ Returns:
+ beliefs: dict
+ a dict key, value pair as {var: belief}
+
+ Example
+ -------
+ >> import numpy as np
+ >> from beliefs.inference.belief_propagation import BeliefPropagation
+ >> from beliefs.models.belief_update_node_model import BeliefUpdateNodeModel, BernoulliOrNode
+ >> edges = [('1', '3'), ('2', '3'), ('3', '5')]
+ >> model = BeliefUpdateNodeModel.init_from_edges(edges, BernoulliOrNode)
+ >> infer = BeliefPropagation(model)
+ >> result = infer.query(evidence={'2': np.array([0, 1])})
+ """
+ if not self.model.all_nodes_are_fully_initialized:
+ self.initialize_model()
+
+ if evidence:
+ self._run_belief_propagation(evidence)
+
+ return {label_id: node.belief for label_id, node in self.model.nodes_dict.items()}
diff --git a/beliefs/models/DirectedGraph.py b/beliefs/models/DirectedGraph.py
deleted file mode 100644
index 8dfb9bd..0000000
--- a/beliefs/models/DirectedGraph.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import networkx as nx
-
-
-class DirectedGraph(nx.DiGraph):
- """
- Base class for all directed graphical models.
- """
- def __init__(self, edges, node_labels):
- """
- Input:
- edges: an edge list, e.g. [(parent1, child1), (parent1, child2)]
- node_labels: a list of strings of node labels
- """
- super().__init__()
- if edges is not None:
- self.add_edges_from(edges)
- if node_labels is not None:
- self.add_nodes_from(node_labels)
-
- def get_leaves(self):
- """
- Returns a list of leaves of the graph.
- """
- return [node for node, out_degree in self.out_degree_iter() if
- out_degree == 0]
-
- def get_roots(self):
- """
- Returns a list of roots of the graph.
- """
- return [node for node, in_degree in self.in_degree().items() if
- in_degree == 0]
-
- def get_topologically_sorted_nodes(self, reverse=False):
- return nx.topological_sort(self, reverse=reverse)
diff --git a/beliefs/models/base_models.py b/beliefs/models/base_models.py
new file mode 100644
index 0000000..cb91566
--- /dev/null
+++ b/beliefs/models/base_models.py
@@ -0,0 +1,154 @@
+import networkx as nx
+
+from beliefs.utils.math_helper import is_kronecker_delta
+
+
+class DirectedGraph(nx.DiGraph):
+ """
+ Base class for all directed graphical models.
+ """
+ def __init__(self, edges=None, node_labels=None):
+ """
+ Input:
+ edges: an edge list, e.g. [(parent1, child1), (parent1, child2)]
+ node_labels: a list of strings of node labels
+ """
+ super().__init__()
+ if edges is not None:
+ self.add_edges_from(edges)
+ if node_labels is not None:
+ self.add_nodes_from(node_labels)
+
+ def get_leaves(self):
+ """
+ Returns a list of leaves of the graph.
+ """
+ return [node for node, out_degree in self.out_degree() if out_degree == 0]
+
+ def get_roots(self):
+ """
+ Returns a list of roots of the graph.
+ """
+ return [node for node, in_degree in self.in_degree() if in_degree == 0]
+
+ def get_topologically_sorted_nodes(self, reverse=False):
+ if reverse:
+ return list(reversed(list(nx.topological_sort(self))))
+ else:
+ return nx.topological_sort(self)
+
+
+class BayesianModel(DirectedGraph):
+ """
+ Bayesian model stores nodes and edges described by conditional probability
+ distributions.
+ """
+ def __init__(self, edges=[], variables=[], cpds=[]):
+ """
+ Base class for Bayesian model.
+
+ Input:
+ edges: (optional) list of edges,
+ tuples of form ('parent', 'child')
+ variables: (optional) list of str or int
+ labels for variables
+ cpds: (optional) list of CPDs
+ TabularCPD class or subclass
+ """
+ super().__init__()
+ super().add_edges_from(edges)
+ super().add_nodes_from(variables)
+ self.cpds = cpds
+
+ def copy(self):
+ """
+ Returns a copy of the model.
+ """
+ copy_model = self.__class__(edges=list(self.edges()).copy(),
+ variables=list(self.nodes()).copy(),
+ cpds=[cpd.copy() for cpd in self.cpds])
+ return copy_model
+
+ def get_variables_in_definite_state(self):
+ """
+ Returns a set of labels of all nodes in a definite state, i.e. with
+ label values that are kronecker deltas.
+
+ RETURNS
+ set of strings (labels)
+ """
+ return {label for label, node in self.nodes_dict.items() if is_kronecker_delta(node.belief)}
+
+ def get_unobserved_variables_in_definite_state(self, observed=set()):
+ """
+ Returns a set of labels that are inferred to be in definite state, given
+ list of labels that were directly observed (e.g. YES/NOs, but not MAYBEs).
+
+ INPUT
+ observed: set of strings, directly observed labels
+ RETURNS
+ set of strings, labels inferred to be in a definite state
+ """
+
+ # Assert that beliefs of directly observed vars are kronecker deltas
+ for label in observed:
+ assert is_kronecker_delta(self.nodes_dict[label].belief), \
+ ("Observed label has belief {} but should be kronecker delta"
+ .format(self.nodes_dict[label].belief))
+
+ vars_in_definite_state = self.get_variables_in_definite_state()
+ assert observed <= vars_in_definite_state, \
+ "Expected set of observed labels to be a subset of labels in definite state."
+ return vars_in_definite_state - observed
+
+ def _get_ancestors_of(self, observed):
+ """Return list of ancestors of observed labels"""
+ ancestors = set()
+ for label in observed:
+ ancestors.update(nx.ancestors(self, label))
+ return ancestors
+
+ def reachable_observed_variables(self, source, observed=set()):
+ """
+ Returns list of observed labels (labels with direct evidence to be in a definite
+ state) that are reachable from the source.
+
+ INPUT
+ source: string, label of node for which to evaluate reachable observed labels
+ observed: set of strings, directly observed labels
+ RETURNS
+ reachable_observed_vars: set of strings, observed labels (variables with direct
+ evidence) that are reachable from the source label.
+ """
+ # ancestors of observed labels, including observed labels
+ ancestors_of_observed = self._get_ancestors_of(observed)
+ ancestors_of_observed.update(observed)
+
+ visit_list = set()
+ visit_list.add((source, 'up'))
+ traversed_list = set()
+ reachable_observed_vars = set()
+
+ while visit_list:
+ node, direction = visit_list.pop()
+ if (node, direction) not in traversed_list:
+ if node in observed:
+ reachable_observed_vars.add(node)
+ traversed_list.add((node, direction))
+ if direction == 'up' and node not in observed:
+ for parent in self.predecessors(node):
+ # causal flow
+ visit_list.add((parent, 'up'))
+ for child in self.successors(node):
+ # common cause flow
+ visit_list.add((child, 'down'))
+ elif direction == 'down':
+ if node not in observed:
+ # evidential flow
+ for child in self.successors(node):
+ visit_list.add((child, 'down'))
+ if node in ancestors_of_observed:
+ # common effect flow (activated v-structure)
+ for parent in self.predecessors(node):
+ visit_list.add((parent, 'up'))
+ return reachable_observed_vars
diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py
new file mode 100644
index 0000000..667e0f1
--- /dev/null
+++ b/beliefs/models/belief_update_node_model.py
@@ -0,0 +1,315 @@
+import copy
+from enum import Enum
+import numpy as np
+import itertools
+from functools import reduce
+
+import networkx as nx
+
+from beliefs.models.base_models import BayesianModel
+from beliefs.factors.bernoulli_or_cpd import BernoulliOrCPD
+
+
+class InvalidLambdaMsgToParent(Exception):
+ """Computed invalid lambda msg to send to parent."""
+ pass
+
+
+class MessageType(Enum):
+ LAMBDA = 'lambda'
+ PI = 'pi'
+
+
+class BeliefUpdateNodeModel(BayesianModel):
+ """
+ A Bayesian model storing nodes (e.g. Node or BernoulliOrNode) implementing properties
+ and methods for Pearl's belief update algorithm.
+
+ ref: "Fusion, Propagation, and Structuring in Belief Networks"
+ Artificial Intelligence 29 (1986) 241-288
+
+ """
+ def __init__(self, nodes_dict):
+ """
+ Input:
+ nodes_dict: dict
+ a dict key, value pair as {label_id: instance_of_node_class_or_subclass}
+ """
+ super().__init__(edges=self._get_edges_from_nodes(nodes_dict.values()),
+ variables=list(nodes_dict.keys()),
+ cpds=[node.cpd for node in nodes_dict.values()])
+
+ self.nodes_dict = nodes_dict
+
+ @classmethod
+ def init_from_edges(cls, edges, node_class):
+ """Create nodes from the same node class.
+
+ Input:
+ edges: list of edge tuples of form ('parent', 'child')
+ node_class: the Node class or subclass from which to
+ create all the nodes from edges.
+ """
+ nodes = set()
+ g = nx.DiGraph(edges)
+
+ for label in set(itertools.chain(*edges)):
+ node = node_class(label_id=label,
+ children=list(g.successors(label)),
+ parents=list(g.predecessors(label)))
+ nodes.add(node)
+ nodes_dict = {node.label_id: node for node in nodes}
+ return cls(nodes_dict)
+
+ @staticmethod
+ def _get_edges_from_nodes(nodes):
+ """
+ Return list of all directed edges in nodes.
+
+ Args:
+ nodes: an iterable of objects of the Node class or subclass
+ Returns:
+ edges: list of edge tuples
+ """
+ edges = set()
+ for node in nodes:
+ if node.parents:
+ edge_tuples = zip(node.parents, [node.label_id]*len(node.parents))
+ edges.update(edge_tuples)
+ return list(edges)
+
+ def set_boundary_conditions(self):
+ """
+ 1. Root nodes: if x is a node with no parents, set Pi(x) = prior
+ probability of x.
+
+ 2. Leaf nodes: if x is a node with no children, set Lambda(x)
+ to an (unnormalized) unit vector, of length the cardinality of x.
+ """
+ for root in self.get_roots():
+ self.nodes_dict[root].pi_agg = self.nodes_dict[root].cpd.values
+
+ for leaf in self.get_leaves():
+ self.nodes_dict[leaf].lambda_agg = np.ones([self.nodes_dict[leaf].cardinality])
+
+ @property
+ def all_nodes_are_fully_initialized(self):
+ """
+ Returns True if, for all nodes in the model, all lambda and pi
+ messages and lambda_agg and pi_agg are not None, else False.
+ """
+ for node in self.nodes_dict.values():
+ if not node.is_fully_initialized:
+ return False
+ return True
+
+ def copy(self):
+ """
+ Returns a copy of the model.
+ """
+ copy_nodes = copy.deepcopy(self.nodes_dict)
+ copy_model = self.__class__(nodes_dict=copy_nodes)
+ return copy_model
+
+
+class Node:
+ """A node in a DAG with methods to compute the belief (marginal probability
+ of the node given evidence) and compute pi/lambda messages to/from its neighbors
+ to update its belief.
+
+ Implemented from Pearl's belief propagation algorithm.
+ """
+ def __init__(self,
+ label_id,
+ children,
+ parents,
+ cardinality,
+ cpd):
+ """
+ Args
+ label_id: str
+ children: set of strings
+ parents: set of strings
+ cardinality: int, cardinality of the random variable the node represents
+ cpd: an instance of a conditional probability distribution,
+ e.g. BernoulliOrCPD or TabularCPD
+ """
+ self.label_id = label_id
+ self.children = children
+ self.parents = parents
+ self.cardinality = cardinality
+ self.cpd = cpd
+
+ self.pi_agg = None # np.array dimensions [1, cardinality]
+ self.lambda_agg = None # np.array dimensions [1, cardinality]
+
+ self.pi_received_msgs = self._init_received_msgs(parents)
+ self.lambda_received_msgs = self._init_received_msgs(children)
+
+ @classmethod
+ def from_cpd_class(cls,
+ label_id,
+ children,
+ parents,
+ cardinality,
+ cpd_class):
+ cpd = cpd_class(label_id, parents)
+ return cls(label_id, children, parents, cardinality, cpd)
+
+ @property
+ def belief(self):
+ if self.pi_agg.any() and self.lambda_agg.any():
+ belief = np.multiply(self.pi_agg, self.lambda_agg)
+ return self._normalize(belief)
+ else:
+ return None
+
+ def _normalize(self, value):
+ return value/value.sum()
+
+ @staticmethod
+ def _init_received_msgs(keys):
+ return {k: None for k in keys}
+
+ def _return_msgs_received_for_msg_type(self, message_type):
+ """
+ Input:
+ message_type: MessageType enum
+
+ Returns:
+ msg_values: list of message values (each an np.array)
+ """
+ if message_type == MessageType.LAMBDA:
+ msg_values = [msg for msg in self.lambda_received_msgs.values()]
+ elif message_type == MessageType.PI:
+ msg_values = [msg for msg in self.pi_received_msgs.values()]
+ return msg_values
+
+ def validate_and_return_msgs_received_for_msg_type(self, message_type):
+ """
+ Check that all messages have been received from children (parents).
+ Raise error if all messages have not been received. Called
+ before calculating lambda_agg (pi_agg).
+
+ Input:
+ message_type: MessageType enum
+
+ Returns:
+ msg_values: list of message values (each an np.array)
+ """
+ msg_values = self._return_msgs_received_for_msg_type(message_type)
+
+ if any(msg is None for msg in msg_values):
+ raise ValueError(
+ "Missing value for {msg_type} msg from child: can't compute {msg_type}_agg."
+ .format(msg_type=message_type.value)
+ )
+ else:
+ return msg_values
+
+ def compute_pi_agg(self):
+ # TODO: implement explict factor product operation
+ raise NotImplementedError
+
+ def compute_lambda_agg(self):
+ if not self.children:
+ return self.lambda_agg
+ else:
+ lambda_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA)
+ self.lambda_agg = reduce(np.multiply, lambda_msg_values)
+ return self.lambda_agg
+
+ def _update_received_msg_by_key(self, received_msg_dict, key, new_value):
+ if key not in received_msg_dict.keys():
+ raise ValueError("Label id '{}' to update message isn't in allowed set of keys: {}"
+ .format(key, received_msg_dict.keys()))
+
+ if not isinstance(new_value, np.ndarray):
+ raise TypeError("Expected a new value of type numpy.ndarray, but got type {}"
+ .format(type(new_value)))
+
+ if new_value.shape != (self.cardinality,):
+ raise ValueError("Expected new value to be of dimensions ({},) but got {} instead"
+ .format(self.cardinality, new_value.shape))
+ received_msg_dict[key] = new_value
+
+ def update_pi_msg_from_parent(self, parent, new_value):
+ self._update_received_msg_by_key(received_msg_dict=self.pi_received_msgs,
+ key=parent,
+ new_value=new_value)
+
+ def update_lambda_msg_from_child(self, child, new_value):
+ self._update_received_msg_by_key(received_msg_dict=self.lambda_received_msgs,
+ key=child,
+ new_value=new_value)
+
+ def compute_pi_msg_to_child(self, child_k):
+ lambda_msg_from_child = self.lambda_received_msgs[child_k]
+ if lambda_msg_from_child is not None:
+ with np.errstate(divide='ignore', invalid='ignore'):
+ # 0/0 := 0
+ return self._normalize(
+ np.nan_to_num(np.divide(self.belief, lambda_msg_from_child)))
+ else:
+ raise ValueError("Can't compute pi message to child_{} without having received a lambda message from that child.")
+
+ def compute_lambda_msg_to_parent(self, parent_k):
+ # TODO: implement explict factor product operation
+ raise NotImplementedError
+
+ @property
+ def is_fully_initialized(self):
+ """
+ Returns True if all lambda and pi messages and lambda_agg and
+ pi_agg are not None, else False.
+ """
+ lambda_msgs = self._return_msgs_received_for_msg_type(MessageType.LAMBDA)
+ if any(msg is None for msg in lambda_msgs):
+ return False
+
+ pi_msgs = self._return_msgs_received_for_msg_type(MessageType.PI)
+ if any(msg is None for msg in pi_msgs):
+ return False
+
+ if (self.pi_agg is None) or (self.lambda_agg is None):
+ return False
+
+ return True
+
+
+class BernoulliOrNode(Node):
+ def __init__(self,
+ label_id,
+ children,
+ parents):
+ super().__init__(label_id=label_id,
+ children=children,
+ parents=parents,
+ cardinality=2,
+ cpd=BernoulliOrCPD(label_id, parents))
+
+ def compute_pi_agg(self):
+ if not self.parents:
+ self.pi_agg = self.cpd.values
+ else:
+ pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
+ parents_p0 = [p[0] for p in pi_msg_values]
+ p_0 = reduce(lambda x, y: x*y, parents_p0)
+ p_1 = 1 - p_0
+ self.pi_agg = np.array([p_0, p_1])
+ return self.pi_agg
+
+ def compute_lambda_msg_to_parent(self, parent_k):
+ if np.array_equal(self.lambda_agg, np.ones([self.cardinality])):
+ return np.ones([self.cardinality])
+ else:
+ # TODO: cleanup this validation
+ _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
+ p0_excluding_k = [msg[0] for par_id, msg in self.pi_received_msgs.items() if par_id != parent_k]
+ p0_product = reduce(lambda x, y: x*y, p0_excluding_k, 1)
+ lambda_0 = self.lambda_agg[1] + (self.lambda_agg[0] - self.lambda_agg[1])*p0_product
+ lambda_1 = self.lambda_agg[1]
+ lambda_msg = np.array([lambda_0, lambda_1])
+ if not any(lambda_msg):
+ raise InvalidLambdaMsgToParent
+ return self._normalize(lambda_msg)
diff --git a/beliefs/utils/__init__.py b/beliefs/utils/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/beliefs/utils/__init__.py
diff --git a/beliefs/utils/math_helper.py b/beliefs/utils/math_helper.py
new file mode 100644
index 0000000..a25ea68
--- /dev/null
+++ b/beliefs/utils/math_helper.py
@@ -0,0 +1,19 @@
+"""Random math utils."""
+
+
+def is_kronecker_delta(vector):
+ """Returns True if vector is a kronecker delta vector, False otherwise.
+ Specific evidence ('YES' or 'NO') is a kronecker delta vector, whereas
+ virtual evidence ('MAYBE') is not.
+ """
+ count = 0
+ for x in vector:
+ if x == 1:
+ count += 1
+ elif x != 0:
+ return False
+
+ if count == 1:
+ return True
+ else:
+ return False
diff --git a/beliefs/utils/random_variables.py b/beliefs/utils/random_variables.py
new file mode 100644
index 0000000..1a0b0f7
--- /dev/null
+++ b/beliefs/utils/random_variables.py
@@ -0,0 +1,21 @@
+
+
+def get_reachable_observed_variables_for_inferred_variables(model, observed=set()):
+ """
+ After performing inference on a BayesianModel, get the labels of observed variables
+ ("reachable observed variables") that influenced the beliefs of variables inferred
+ to be in a definite state.
+
+ INPUT
+ model: instance of BayesianModel class or subclass
+ observed: set of labels (strings) corresponding to vars pinned to definite
+ state during inference.
+ RETURNS
+ dict, of form key - source label (a string), value - a list of strings
+ """
+ if not observed:
+ return {}
+
+ source_vars = model.get_unobserved_variables_in_definite_state(observed)
+
+ return {var: model.reachable_observed_variables(var, observed) for var in source_vars}
diff --git a/tests/test_belief_propagation.py b/tests/test_belief_propagation.py
new file mode 100644
index 0000000..5c5a612
--- /dev/null
+++ b/tests/test_belief_propagation.py
@@ -0,0 +1,255 @@
+import numpy as np
+import pytest
+from pytest import approx
+
+from beliefs.inference.belief_propagation import BeliefPropagation, ConflictingEvidenceError
+from beliefs.models.belief_update_node_model import (
+ BeliefUpdateNodeModel,
+ BernoulliOrNode
+)
+
+
+@pytest.fixture(scope='module')
+def edges_four_nodes():
+ """Edges define a polytree with 4 nodes (connected in an X-shape with the
+ node, 'x', at the center of the X."""
+ edges = [('u', 'x'), ('v', 'x'), ('x', 'y'), ('x', 'z')]
+ return edges
+
+
+@pytest.fixture(scope='module')
+def simple_edges():
+ """Edges define a polytree with 15 nodes."""
+ edges = [('1', '3'), ('2', '3'), ('3', '5'), ('4', '5'), ('5', '10'),
+ ('5', '9'), ('6', '8'), ('7', '8'), ('8', '9'), ('9', '11'),
+ ('9', 'x'), ('14', 'x'), ('x', '12'), ('x', '13')]
+ return edges
+
+
+@pytest.fixture(scope='module')
+def many_parents_edges():
+ """Node 62 has 18 parents and no children."""
+ edges = [('96', '62'), ('80', '62'), ('98', '62'),
+ ('100', '62'), ('86', '62'), ('102', '62'), ('104', '62'),
+ ('64', '62'), ('106', '62'), ('108', '62'), ('110', '62'),
+ ('112', '62'), ('114', '62'), ('116', '62'), ('118', '62'),
+ ('122', '62'), ('70', '62'), ('94', '62')]
+ return edges
+
+
+@pytest.fixture(scope='function')
+def four_node_model(edges_four_nodes):
+ return BeliefUpdateNodeModel.init_from_edges(edges_four_nodes, BernoulliOrNode)
+
+
+@pytest.fixture(scope='function')
+def simple_model(simple_edges):
+ return BeliefUpdateNodeModel.init_from_edges(simple_edges, BernoulliOrNode)
+
+
+@pytest.fixture(scope='function')
+def many_parents_model(many_parents_edges):
+ return BeliefUpdateNodeModel.init_from_edges(many_parents_edges, BernoulliOrNode)
+
+
+@pytest.fixture(scope='function')
+def one_node_model():
+ a_node = BernoulliOrNode(label_id='x', children=[], parents=[])
+ return BeliefUpdateNodeModel(nodes_dict={'x': a_node})
+
+
+def get_label_mapped_to_positive_belief(query_result):
+ """Return a dictionary mapping each label_id to the probability of
+ the label being True."""
+ return {label_id: belief[1] for label_id, belief in query_result.items()}
+
+
+def compare_dictionaries(expected, observed):
+ for key, expected_value in expected.items():
+ observed_value = observed.get(key)
+ if observed_value is None:
+ raise KeyError("Expected key {} not in observed.")
+ assert observed_value == approx(expected_value), \
+ "Expected {} but got {}".format(expected_value, observed_value)
+
+
+#==============================================================================================
+# Tests of single Bernoulli node model
+
+def test_no_evidence_one_node_model(one_node_model):
+ expected = {'x': 0.5}
+ infer = BeliefPropagation(one_node_model)
+ query_result = infer.query(evidence={})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+def test_virtual_evidence_one_node_model(one_node_model):
+ """Curator thinks YES is 10x more likely than NO based on virtual evidence."""
+ expected = {'x': 5/(0.5+5)}
+ infer = BeliefPropagation(one_node_model)
+ query_result = infer.query(evidence={'x': np.array([1, 10])})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+def test_MAYBE_default_evidence_one_node_model(one_node_model):
+ expected = {'x': 0.5}
+ infer = BeliefPropagation(one_node_model)
+ query_result = infer.query(evidence={'x': np.array([0.5, 0.5])})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+def test_YES_evidence_one_node_model(one_node_model):
+ expected = {'x': 1}
+ infer = BeliefPropagation(one_node_model)
+ query_result = infer.query(evidence={'x': np.array([0, 1])})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+def test_NO_evidence_one_node_model(one_node_model):
+ expected = {'x': 0}
+ infer = BeliefPropagation(one_node_model)
+ query_result = infer.query(evidence={'x': np.array([1, 0])})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+#==============================================================================================
+# Tests of 4-node, 4-edge model
+
+def test_no_evidence_four_node_model(four_node_model):
+ expected = {'x': 1-0.5**2}
+ infer = BeliefPropagation(four_node_model)
+ query_result = infer.query(evidence={})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+def test_virtual_evidence_for_node_x_four_node_model(four_node_model):
+ """Virtual evidence for node x."""
+ expected = {'x': 0.967741935483871, 'y': 0.967741935483871, 'z': 0.967741935483871,
+ 'u': 0.6451612903225806, 'v': 0.6451612903225806}
+ infer = BeliefPropagation(four_node_model)
+ query_result = infer.query(evidence={'x': np.array([1, 10])})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+#==============================================================================================
+# Tests of simple BernoulliOr polytree model
+
+def test_no_evidence_simple_model(simple_model):
+ expected = {'x': 0.984375, '14': 0.5, '7': 0.5, '2': 0.5, '3':
+ 0.75, '13': 0.984375, '6': 0.5, '4': 0.5, '8': 0.75, '10': 0.875,
+ '1': 0.5, '9': 0.96875, '12': 0.984375, '5': 0.875, '11': 0.96875}
+ infer = BeliefPropagation(simple_model)
+ query_result = infer.query(evidence={})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+def test_belief_propagation_no_modify_model_inplace(simple_model):
+ assert not simple_model.all_nodes_are_fully_initialized
+ infer = BeliefPropagation(simple_model, inplace=False)
+ _ = infer.query(evidence={})
+ # after belief propagation, model node values should be unchanged
+ assert not simple_model.all_nodes_are_fully_initialized
+
+
+def test_belief_propagation_modify_model_inplace(simple_model):
+ assert not simple_model.all_nodes_are_fully_initialized
+ expected = {'x': 0.984375, '14': 0.5, '7': 0.5, '2': 0.5, '3':
+ 0.75, '13': 0.984375, '6': 0.5, '4': 0.5, '8': 0.75, '10': 0.875,
+ '1': 0.5, '9': 0.96875, '12': 0.984375, '5': 0.875, '11': 0.96875}
+ infer = BeliefPropagation(simple_model, inplace=True)
+ _ = infer.query(evidence={})
+
+ assert simple_model.all_nodes_are_fully_initialized
+ beliefs_from_model = {node_id: node.belief[1] for
+ node_id, node in simple_model.nodes_dict.items()}
+ compare_dictionaries(expected, beliefs_from_model)
+
+
+def test_positive_evidence_node_13(simple_model):
+ expected = {'6': 0.50793650793650791, '3': 0.76190476190476186,
+ '9': 0.98412698412698407, '8': 0.76190476190476186,
+ 'x': 1.0, '4': 0.50793650793650791, '11': 0.98412698412698407,
+ '1': 0.50793650793650791, '5': 0.88888888888888884,
+ '2': 0.50793650793650791, '12': 1.0,
+ '14': 0.50793650793650791, '13': 1,
+ '10': 0.88888888888888884, '7': 0.50793650793650791}
+ infer = BeliefPropagation(simple_model)
+ query_result = infer.query(evidence={'13': np.array([0, 1])})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+def test_positive_evidence_node_5(simple_model):
+ expected = {'1': 0.5714285714285714, '5': 1, '3':
+ 0.8571428571428571, '10': 1.0, '8': 0.75, '2': 0.5714285714285714,
+ '4': 0.5714285714285714, '6': 0.5, '7': 0.5, '14': 0.5, '12': 1.0,
+ '13': 1.0, '11': 1.0, '9': 1.0, 'x': 1.0}
+ infer = BeliefPropagation(simple_model)
+ query_result = infer.query(evidence={'5': np.array([0, 1])})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+def test_positive_evidence_node_5_negative_evidence_node_14(simple_model):
+ expected = {'6': 0.5, '7': 0.5, '9': 1.0, '3': 0.8571428571428571,
+ '1': 0.57142857142857151, '12': 1.0, 'x': 1.0, '11': 1.0, '14':
+ 0.0, '2': 0.57142857142857151, '4': 0.5714285714285714, '5': 1.0,
+ '10': 1.0, '13': 1.0, '8': 0.75}
+ infer = BeliefPropagation(simple_model)
+ query_result = infer.query(evidence={'5': np.array([0, 1]), '14': np.array([1, 0])})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+def test_conflicting_evidence(simple_model):
+ infer = BeliefPropagation(simple_model)
+ with pytest.raises(ConflictingEvidenceError) as err:
+ query_result = infer.query(evidence={'x': np.array([1, 0]), '5': np.array([0, 1])})
+ assert "Can't run belief propagation with conflicting evidence" in str(err)
+
+
+#==============================================================================================
+# Tests of model with 18 parents sharing a single child
+
+def test_no_evidence_many_parents_model(many_parents_model):
+ expected = {'64': 0.5, '86': 0.5, '62': 0.99999618530273438,
+ '116': 0.5, '100': 0.5, '108': 0.5, '122': 0.5, '114': 0.5, '98':
+ 0.5, '106': 0.5, '94': 0.5, '80': 0.5, '102': 0.5, '70': 0.5,
+ '118': 0.5, '96': 0.5, '104': 0.5, '110': 0.5, '112': 0.5}
+ infer = BeliefPropagation(many_parents_model)
+ query_result = infer.query(evidence={})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+def test_positive_evidence_node_112(many_parents_model):
+ """If a single parent (112) is True, then (62) has to be True."""
+ expected = {'64': 0.5, '86': 0.5, '62': 1.0, '116': 0.5, '100':
+ 0.5, '108': 0.5, '122': 0.5, '114': 0.5, '98': 0.5,
+ '106': 0.5, '94': 0.5, '80': 0.5, '102': 0.5, '70':
+ 0.5, '118': 0.5, '96': 0.5, '104': 0.5, '110': 0.5,
+ '112': 1.0}
+ infer = BeliefPropagation(many_parents_model)
+ query_result = infer.query(evidence={'112': np.array([0, 1])})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+def test_negative_evidence_node_62(many_parents_model):
+ """If node 62 is False, then all of its parents must be False."""
+ expected = {'64': 0, '86': 0, '62': 0, '116': 0, '100': 0, '108':
+ 0, '122': 0, '114': 0, '98': 0, '106': 0, '94': 0,
+ '80': 0, '102': 0, '70': 0, '118': 0, '96': 0, '104':
+ 0, '110': 0, '112': 0}
+ infer = BeliefPropagation(many_parents_model)
+ query_result = infer.query(evidence={'62': np.array([1, 0])})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
diff --git a/tests/test_get_reachable_observed_variables.py b/tests/test_get_reachable_observed_variables.py
new file mode 100644
index 0000000..d6590ad
--- /dev/null
+++ b/tests/test_get_reachable_observed_variables.py
@@ -0,0 +1,129 @@
+import numpy as np
+
+from test_belief_propagation import simple_model, simple_edges
+
+from beliefs.inference.belief_propagation import BeliefPropagation
+from beliefs.utils.random_variables import (
+ get_reachable_observed_variables_for_inferred_variables
+)
+
+
+def test_reachable_observed_vars_direct_common_effect(simple_model):
+ observed_vars = {'14': np.array([1,0]), 'x': np.array([1,0])}
+ infer = BeliefPropagation(simple_model)
+ infer.query(evidence=observed_vars)
+
+ expected = {'x', '14'}
+ observed = simple_model.reachable_observed_variables(
+ source='9',
+ observed=set(observed_vars.keys())
+ )
+ assert expected == observed
+
+
+def test_reachable_observed_vars_indirect_common_effect(simple_model):
+ observed_vars = {'12': np.array([1,0]), '14': np.array([1,0])}
+ infer = BeliefPropagation(simple_model)
+ infer.query(evidence=observed_vars)
+
+ expected = {'12', '14'}
+ observed = simple_model.reachable_observed_variables(
+ source='9',
+ observed=set(observed_vars.keys())
+ )
+ assert expected == observed
+
+
+def test_reachable_observed_vars_common_cause(simple_model):
+ observed_vars = {'10': np.array([0,1])}
+ infer = BeliefPropagation(simple_model)
+ infer.query(evidence=observed_vars)
+
+ expected = {'10'}
+ observed = simple_model.reachable_observed_variables(
+ source='9',
+ observed=set(observed_vars.keys())
+ )
+ assert expected == observed
+
+
+def test_reachable_observed_vars_blocked_common_cause(simple_model):
+ observed_vars = {'10': np.array([0,1]), '5': np.array([0,1])}
+ infer = BeliefPropagation(simple_model)
+ infer.query(evidence=observed_vars)
+
+ expected = {'5'}
+ observed = simple_model.reachable_observed_variables(
+ source='9',
+ observed=set(observed_vars.keys())
+ )
+ assert expected == observed
+
+
+def test_reachable_observed_vars_indirect_causal(simple_model):
+ observed_vars = {'1': np.array([0,1]), '2': np.array([1,0])}
+ infer = BeliefPropagation(simple_model)
+ infer.query(evidence=observed_vars)
+
+ expected = {'1', '2'}
+ observed = simple_model.reachable_observed_variables(
+ source='9',
+ observed=set(observed_vars.keys())
+ )
+ assert expected == observed
+
+
+def test_reachable_observed_vars_blocked_causal(simple_model):
+ observed_vars = {'1': np.array([0,1]), '2': np.array([1,0]), '3': np.array([0,1])}
+ infer = BeliefPropagation(simple_model)
+ infer.query(evidence=observed_vars)
+
+ expected = {'3'}
+ observed = simple_model.reachable_observed_variables(
+ source='9',
+ observed=set(observed_vars.keys())
+ )
+ assert expected == observed
+
+
+def test_reachable_observed_vars_indirect_evidential(simple_model):
+ observed_vars = {'13': np.array([1,0])}
+ infer = BeliefPropagation(simple_model)
+ infer.query(evidence=observed_vars)
+
+ expected = {'13'}
+ observed = simple_model.reachable_observed_variables(
+ source='9',
+ observed=set(observed_vars.keys())
+ )
+ assert expected == observed
+
+
+def test_reachable_observed_vars_blocked_evidential(simple_model):
+ observed_vars = {'x': np.array([1,0]), '13': np.array([1,0])}
+ infer = BeliefPropagation(simple_model)
+ infer.query(evidence=observed_vars)
+
+ expected = {'x'}
+ observed = simple_model.reachable_observed_variables(
+ source='9',
+ observed=set(observed_vars.keys())
+ )
+ assert expected == observed
+
+
+def test_get_reachable_obs_vars_for_inferred(simple_model):
+ observed_vars = {'6': np.array([1,0]), '7': np.array([1,0]), '10': np.array([1,0])}
+ infer = BeliefPropagation(simple_model)
+ infer.query(evidence=observed_vars)
+
+ print(set(simple_model.get_unobserved_variables_in_definite_state(observed_vars.keys())))
+ print(simple_model._get_ancestors_of(set(observed_vars.keys())))
+ expected = {'4': {'10'}, '1': {'10'}, '11': {'7', '6', '10'}, '2': {'10'},
+ '8': {'7', '6'}, '5': {'10'}, '3': {'10'}, '9': {'7', '6', '10'}}
+
+ observed = get_reachable_observed_variables_for_inferred_variables(
+ model=simple_model,
+ observed=set(observed_vars.keys())
+ )
+ assert expected == observed