aboutsummaryrefslogtreecommitdiff
path: root/beliefs/inference/belief_propagation.py
diff options
context:
space:
mode:
Diffstat (limited to 'beliefs/inference/belief_propagation.py')
-rw-r--r--beliefs/inference/belief_propagation.py99
1 files changed, 48 insertions, 51 deletions
diff --git a/beliefs/inference/belief_propagation.py b/beliefs/inference/belief_propagation.py
index 7ec648d..e6e7b18 100644
--- a/beliefs/inference/belief_propagation.py
+++ b/beliefs/inference/belief_propagation.py
@@ -28,10 +28,10 @@ class ConflictingEvidenceError(Exception):
class BeliefPropagation:
def __init__(self, model, inplace=True):
"""
- Input:
- model: an instance of BeliefUpdateNodeModel
- inplace: bool
- modify in-place the nodes in the model during belief propagation
+ Args
+ model: an instance of BeliefUpdateNodeModel
+ inplace: bool,
+ modify in-place the nodes in the model during belief propagation
"""
if not isinstance(model, BeliefUpdateNodeModel):
raise TypeError("Model must be an instance of BeliefUpdateNodeModel")
@@ -43,21 +43,20 @@ class BeliefPropagation:
def _belief_propagation(self, nodes_to_update, evidence):
"""
Implementation of Pearl's belief propagation algorithm for polytrees.
-
ref: "Fusion, Propagation, and Structuring in Belief Networks"
Artificial Intelligence 29 (1986) 241-288
- Input:
- nodes_to_update: list
- list of MsgPasser namedtuples.
- evidence: dict,
- a dict key, value pair as {var: state_of_var observed}
+ Args
+ nodes_to_update: list,
+ list of MsgPasser namedtuples.
+ evidence: dict,
+ a dict key, value pair as {var: state_of_var observed}
"""
if len(nodes_to_update) == 0:
return
node_to_update_label_id, msg_sender_label_id = nodes_to_update.pop()
- logging.info("Node: %s", node_to_update_label_id)
+ logging.debug("Node: %s", node_to_update_label_id)
node = self.model.nodes_dict[node_to_update_label_id]
@@ -65,18 +64,18 @@ class BeliefPropagation:
# outgoing msg from the node to update
parent_ids = set(node.parents) - set([msg_sender_label_id])
child_ids = set(node.children) - set([msg_sender_label_id])
- logging.info("parent_ids: %s", str(parent_ids))
- logging.info("child_ids: %s", str(child_ids))
+ logging.debug("parent_ids: %s", str(parent_ids))
+ logging.debug("child_ids: %s", str(child_ids))
if msg_sender_label_id is not None:
# update triggered by receiving a message, not pinning to evidence
assert len(node.parents) + len(node.children) - 1 == len(parent_ids) + len(child_ids)
if node_to_update_label_id not in evidence:
- node.compute_pi_agg()
- logging.info("belief propagation pi_agg: %s", np.array2string(node.pi_agg))
- node.compute_lambda_agg()
- logging.info("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg))
+ node.compute_and_update_pi_agg()
+ logging.debug("belief propagation pi_agg: %s", np.array2string(node.pi_agg.values))
+ node.compute_and_update_lambda_agg()
+ logging.debug("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg.values))
for parent_id in parent_ids:
try:
@@ -97,47 +96,46 @@ class BeliefPropagation:
new_value=new_pi_msg)
nodes_to_update.add(MsgPassers(msg_receiver=child_id,
msg_sender=node_to_update_label_id))
-
self._belief_propagation(nodes_to_update, evidence)
def initialize_model(self):
"""
- Apply boundary conditions:
+ 1. Apply boundary conditions:
- Set pi_agg equal to prior probabilities for root nodes.
- Set lambda_agg equal to vector of ones for leaf nodes.
- - Set lambda_agg, lambda_received_msgs to vectors of ones (same effect as
- actually passing lambda messages up from leaf nodes to root nodes).
- - Calculate pi_agg and pi_received_msgs for all nodes without evidence.
- (Without evidence, belief equals pi_agg.)
+ 2. Set lambda_agg, lambda_received_msgs to vectors of ones (same effect as
+ actually passing lambda messages up from leaf nodes to root nodes).
+ 3. Calculate pi_agg and pi_received_msgs for all nodes without evidence.
+ (Without evidence, belief equals pi_agg.)
"""
self.model.set_boundary_conditions()
for node in self.model.nodes_dict.values():
ones_vector = np.ones([node.cardinality])
+ node.update_lambda_agg(ones_vector)
- node.lambda_agg = ones_vector
for child in node.lambda_received_msgs.keys():
node.update_lambda_msg_from_child(child=child,
new_value=ones_vector)
- logging.info("Finished initializing Lambda(x) and lambda_received_msgs per node.")
+ logging.debug("Finished initializing Lambda(x) and lambda_received_msgs per node.")
- logging.info("Start downward sweep from nodes. Sending Pi messages only.")
+ logging.debug("Start downward sweep from nodes. Sending Pi messages only.")
topdown_order = self.model.get_topologically_sorted_nodes(reverse=False)
for node_id in topdown_order:
- logging.info('label in iteration through top-down order: %s', str(node_id))
+ logging.debug('label in iteration through top-down order: %s', str(node_id))
node_sending_msg = self.model.nodes_dict[node_id]
child_ids = node_sending_msg.children
- if node_sending_msg.pi_agg is None:
- node_sending_msg.compute_pi_agg()
+ if node_sending_msg.pi_agg.values is None:
+ node_sending_msg.compute_and_update_pi_agg()
for child_id in child_ids:
- logging.info("child: %s", str(child_id))
+ logging.debug("child: %s", str(child_id))
new_pi_msg = node_sending_msg.compute_pi_msg_to_child(child_k=child_id)
- logging.info("new_pi_msg: %s", np.array2string(new_pi_msg))
+ logging.debug("new_pi_msg: %s", np.array2string(new_pi_msg))
child_node = self.model.nodes_dict[child_id]
child_node.update_pi_msg_from_parent(parent=node_id,
@@ -145,42 +143,41 @@ class BeliefPropagation:
def _run_belief_propagation(self, evidence):
"""
- Input:
- evidence: dict
- a dict key, value pair as {var: state_of_var observed}
+ Sequentially perturb nodes with observed values, running belief propagation
+ after each perturbation.
+
+ Args
+ evidence: dict,
+ a dict key, value pair as {var: state_of_var observed}
"""
for evidence_id, observed_value in evidence.items():
- nodes_to_update = set()
-
if evidence_id not in self.model.nodes_dict.keys():
raise KeyError("Evidence supplied for non-existent label_id: {}"
.format(evidence_id))
if is_kronecker_delta(observed_value):
# specific evidence
- self.model.nodes_dict[evidence_id].lambda_agg = observed_value
+ self.model.nodes_dict[evidence_id].update_lambda_agg(observed_value)
else:
# virtual evidence
- self.model.nodes_dict[evidence_id].lambda_agg = \
- self.model.nodes_dict[evidence_id].lambda_agg * observed_value
-
+ self.model.nodes_dict[evidence_id].update_lambda_agg(
+ self.model.nodes_dict[evidence_id].lambda_agg.values * observed_value
+ )
nodes_to_update = [MsgPassers(msg_receiver=evidence_id, msg_sender=None)]
-
- self._belief_propagation(nodes_to_update=set(nodes_to_update),
- evidence=evidence)
+ self._belief_propagation(nodes_to_update=set(nodes_to_update), evidence=evidence)
def query(self, evidence={}):
"""
- Run belief propagation given evidence.
+ Run belief propagation given 0 or more pieces of evidence.
- Input:
- evidence: dict
- a dict key, value pair as {var: state_of_var observed},
- e.g. {'3': np.array([0,1])} if label '3' is True.
+ Args
+ evidence: dict,
+ a dict key, value pair as {var: state_of_var observed},
+ e.g. {'3': np.array([0,1])} if label '3' is True.
- Returns:
- beliefs: dict
- a dict key, value pair as {var: belief}
+ Returns
+ a dict key, value pair as {var: belief}, where belief is an np.array of the
+ marginal probability of each state of the variable given the evidence.
Example
-------