aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCathy Yeh <cathy@driver.xyz>2018-01-17 17:32:07 -0800
committerCathy Yeh <cathy@driver.xyz>2018-01-17 17:32:07 -0800
commit1a9286b0c1698fe5329a8e0b2a886f0a98286d2b (patch)
tree24942e42448c4175acbf1ab4e4e2d5d0cea44e88
parentd92ed9f14baead60fdd6c1d823345cc3ddd1bc04 (diff)
downloadbeliefs-1a9286b0c1698fe5329a8e0b2a886f0a98286d2b.tar.gz
beliefs-1a9286b0c1698fe5329a8e0b2a886f0a98286d2b.tar.bz2
beliefs-1a9286b0c1698fe5329a8e0b2a886f0a98286d2b.zip
compute_pi_agg -> compute_and_update_pi_agg, compute_lambda_agg -> compute_and_update_lambda_agg
-rw-r--r--beliefs/inference/belief_propagation.py6
-rw-r--r--beliefs/models/belief_update_node_model.py8
2 files changed, 7 insertions, 7 deletions
diff --git a/beliefs/inference/belief_propagation.py b/beliefs/inference/belief_propagation.py
index acd93d4..e6e7b18 100644
--- a/beliefs/inference/belief_propagation.py
+++ b/beliefs/inference/belief_propagation.py
@@ -72,9 +72,9 @@ class BeliefPropagation:
assert len(node.parents) + len(node.children) - 1 == len(parent_ids) + len(child_ids)
if node_to_update_label_id not in evidence:
- node.compute_pi_agg()
+ node.compute_and_update_pi_agg()
logging.debug("belief propagation pi_agg: %s", np.array2string(node.pi_agg.values))
- node.compute_lambda_agg()
+ node.compute_and_update_lambda_agg()
logging.debug("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg.values))
for parent_id in parent_ids:
@@ -130,7 +130,7 @@ class BeliefPropagation:
child_ids = node_sending_msg.children
if node_sending_msg.pi_agg.values is None:
- node_sending_msg.compute_pi_agg()
+ node_sending_msg.compute_and_update_pi_agg()
for child_id in child_ids:
logging.debug("child: %s", str(child_id))
diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py
index 743bbcb..ec329ca 100644
--- a/beliefs/models/belief_update_node_model.py
+++ b/beliefs/models/belief_update_node_model.py
@@ -236,7 +236,7 @@ class Node:
else:
return msgs
- def compute_pi_agg(self):
+ def compute_and_update_pi_agg(self):
"""
Compute and update pi_agg, the prior probability, given the current state
of messages received from parents.
@@ -252,7 +252,7 @@ class Node:
self.update_pi_agg(factor_product.marginalize(self.parents).values)
pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
- def compute_lambda_agg(self):
+ def compute_and_update_lambda_agg(self):
"""
Compute and update lambda_agg, the likelihood, given the current state
of messages received from children.
@@ -370,7 +370,7 @@ class BernoulliOrNode(Node):
def __init__(self, label_id, children, parents):
super().__init__(children=children, cpd=BernoulliOrCPD(label_id, parents))
- def compute_pi_agg(self):
+ def compute_and_update_pi_agg(self):
"""
Compute and update pi_agg, the prior probability, given the current state
of messages received from parents. Sidestep explicit factor product and
@@ -424,7 +424,7 @@ class BernoulliAndNode(Node):
def __init__(self, label_id, children, parents):
super().__init__(children=children, cpd=BernoulliAndCPD(label_id, parents))
- def compute_pi_agg(self):
+ def compute_and_update_pi_agg(self):
"""
Compute and update pi_agg, the prior probability, given the current state
of messages received from parents. Sidestep explicit factor product and