aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCathy Yeh <cathy@driver.xyz>2017-12-08 16:00:52 -0800
committerCathy Yeh <cathy@driver.xyz>2017-12-08 16:00:57 -0800
commit4373157138e85d2dbad9672cef5963a27a3d962c (patch)
tree7fc7d5a6ec8bf4224a0c8cc1dd9cbe3d851b0cef
parent8cdb00cdb10200e824015ece4a94485e93857352 (diff)
downloadbeliefs-4373157138e85d2dbad9672cef5963a27a3d962c.tar.gz
beliefs-4373157138e85d2dbad9672cef5963a27a3d962c.tar.bz2
beliefs-4373157138e85d2dbad9672cef5963a27a3d962c.zip
BernoulliAndNode with custom msg passing methods
-rw-r--r--beliefs/models/belief_update_node_model.py39
1 files changed, 39 insertions, 0 deletions
diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py
index 667e0f1..4747530 100644
--- a/beliefs/models/belief_update_node_model.py
+++ b/beliefs/models/belief_update_node_model.py
@@ -8,6 +8,7 @@ import networkx as nx
from beliefs.models.base_models import BayesianModel
from beliefs.factors.bernoulli_or_cpd import BernoulliOrCPD
+from beliefs.factors.bernoulli_and_cpd import BernoulliAndCPD
class InvalidLambdaMsgToParent(Exception):
@@ -313,3 +314,41 @@ class BernoulliOrNode(Node):
if not any(lambda_msg):
raise InvalidLambdaMsgToParent
return self._normalize(lambda_msg)
+
+
+class BernoulliAndNode(Node):
+ def __init__(self,
+ label_id,
+ children,
+ parents):
+ super().__init__(label_id=label_id,
+ children=children,
+ parents=parents,
+ cardinality=2,
+ cpd=BernoulliAndCPD(label_id, parents))
+
+ def compute_pi_agg(self):
+ if not self.parents:
+ self.pi_agg = self.cpd.values
+ else:
+ pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
+ parents_p1 = [p[1] for p in pi_msg_values]
+ p_1 = reduce(lambda x, y: x*y, parents_p1)
+ p_0 = 1 - p_1
+ self.pi_agg = np.array([p_0, p_1])
+ return self.pi_agg
+
+ def compute_lambda_msg_to_parent(self, parent_k):
+ if np.array_equal(self.lambda_agg, np.ones([self.cardinality])):
+ return np.ones([self.cardinality])
+ else:
+ # TODO: cleanup this validation
+ _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
+ p1_excluding_k = [msg[1] for par_id, msg in self.pi_received_msgs.items() if par_id != parent_k]
+ p1_product = reduce(lambda x, y: x*y, p1_excluding_k, 1)
+ lambda_0 = self.lambda_agg[0]
+ lambda_1 = self.lambda_agg[0] + (self.lambda_agg[1] - self.lambda_agg[0])*p1_product
+ lambda_msg = np.array([lambda_0, lambda_1])
+ if not any(lambda_msg):
+ raise InvalidLambdaMsgToParent
+ return self._normalize(lambda_msg)