aboutsummaryrefslogtreecommitdiff
path: root/beliefs/factors/bernoulli_and_cpd.py
blob: 291398f0aae4d18468516ae93b86e67901c338cc (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import numpy as np

from beliefs.factors.cpd import TabularCPD


class BernoulliAndCPD(TabularCPD):
    """CPD class for a Bernoulli random variable whose relationship to its
    parents (also Bernoulli random variables) is described by AND logic.

    If all of the variable's parents are True, then the variable
    is True, and False otherwise. A variable with no parents is instead
    given the uniform prior P(False) = P(True) = 0.5.

    The child and every parent have cardinality 2, with state names
    ['False', 'True'] in that order (index 0 is False, index 1 is True).
    """
    def __init__(self, variable, parents=None):
        """
        Args
            variable: int or string
            parents: list,
                (optional) list of int and/or strings. Any iterable is
                accepted and copied into a new list; defaults to no parents.
        """
        # Avoid a shared mutable default argument: the parents list is handed
        # to the base class, so each instance must own its own copy.
        parents = [] if parents is None else list(parents)
        super().__init__(variable=variable,
                         variable_card=2,
                         parents=parents,
                         parents_card=[2] * len(parents),
                         values=None,
                         state_names={var: ['False', 'True']
                                      for var in [variable] + parents})
        # The probability table is deterministic, so it is built lazily on
        # the first access of `values` rather than passed to the base class.
        self._values = None

    @property
    def values(self):
        """The AND-logic probability table, reshaped to ``self.cardinality``.

        Built on first access and cached in ``self._values``.
        """
        if self._values is None:
            flat = self._build_kwise_values_array(len(self.variables))
            # Only populate the cache once the reshape has succeeded, so a
            # failure never leaves a half-initialised flat array cached.
            self._values = flat.reshape(self.cardinality)
        return self._values

    @staticmethod
    def _build_kwise_values_array(k):
        """Return the flat AND-logic probability table for ``k`` variables.

        Args
            k: int, number of variables in the factor (the child plus its
                parents); must be >= 1.

        Returns
            1-D float numpy array of length 2**k.

        Raises
            ValueError: if ``k`` is less than 1.
        """
        if k < 1:
            raise ValueError("k must be >= 1, got {!r}".format(k))

        # special case a completely independent factor, and
        # return the uniform prior
        if k == 1:
            return np.array([0.5, 0.5])

        # Values are stored as a flat vector using an ordering such that the
        # right-most variables of [variable] + parents cycle through their
        # values the fastest. The first half is therefore P(variable=False |
        # parents) and the second half P(variable=True | parents), each
        # indexed by parent assignment with the all-True assignment last.
        n_parent_states = 2 ** (k - 1)
        false_half = [1.] * (n_parent_states - 1) + [0.]
        true_half = [0.] * (n_parent_states - 1) + [1.]
        return np.array(false_half + true_half)