aboutsummaryrefslogtreecommitdiff
path: root/beliefs/models/beliefupdate/BernoulliOrNode.py
blob: 338627556dc324e11105678aa255f77c0c82d979 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import numpy as np
from functools import reduce

from beliefs.models.beliefupdate.Node import (
    Node,
    MessageType,
    InvalidLambdaMsgToParent
)
from beliefs.factors.BernoulliOrCPD import BernoulliOrCPD


class BernoulliOrNode(Node):
    def __init__(self,
                 label_id,
                 children,
                 parents):
        super().__init__(label_id=label_id,
                         children=children,
                         parents=parents,
                         cardinality=2,
                         cpd=BernoulliOrCPD(label_id, parents))

    def compute_pi_agg(self):
        if not self.parents:
            self.pi_agg = self.cpd.values
        else:
            pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
            parents_p0 = [p[0] for p in pi_msg_values]
            p_0 = reduce(lambda x, y: x*y, parents_p0)
            p_1 = 1 - p_0
            self.pi_agg = np.array([p_0, p_1])
        return self.pi_agg

    def compute_lambda_msg_to_parent(self, parent_k):
        if np.array_equal(self.lambda_agg, np.ones([self.cardinality])):
            return np.ones([self.cardinality])
        else:
            # TODO: cleanup this validation
            _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
            p0_excluding_k = [msg[0] for par_id, msg in self.pi_received_msgs.items() if par_id != parent_k]
            p0_product = reduce(lambda x, y: x*y, p0_excluding_k, 1)
            lambda_0 = self.lambda_agg[1] + (self.lambda_agg[0] - self.lambda_agg[1])*p0_product
            lambda_1 = self.lambda_agg[1]
            lambda_msg = np.array([lambda_0, lambda_1])
            if not any(lambda_msg):
                raise InvalidLambdaMsgToParent
            return self._normalize(lambda_msg)