aboutsummaryrefslogtreecommitdiff
path: root/beliefs/factors/BernoulliOrFactor.py
blob: 4f973ae20b86ae70a29302afd296288f8ee6cc37 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import numpy as np


class BernoulliOrFactor:
    """CPD class for a Bernoulli random variable whose relationship to its
    parents is described by OR logic.

    If at least one of a child's parents is True, then the child is True, and
    False otherwise. A factor with no parents is instead given a uniform prior
    (0.5 / 0.5).

    Attributes:
        child: the child variable.
        parents: set of parent variables.
        variables: set containing the child and all parents. Being a set, it
            carries no axis order for the value table.
        cardinality: list with one entry of 2 per variable in ``variables``.
    """
    def __init__(self, child, parents=()):
        """
        Args:
            child: the child variable (must be hashable; it is stored in a set).
            parents: iterable of hashable parent variables. Any iterable,
                including a one-shot generator, is accepted; it is consumed
                exactly once. Defaults to no parents.
        """
        self.child = child
        self.parents = set(parents)
        # Derive from the materialised set: re-iterating `parents` here would
        # see nothing if the caller passed a generator.
        self.variables = {child} | self.parents
        self.cardinality = [2]*len(self.variables)
        self._values = None  # built lazily by the `values` property

    @property
    def values(self):
        """The OR table as an ndarray of shape ``self.cardinality``.

        Axis 0 indexes the child's state (0 = False, 1 = True) and the
        remaining axes index the parents' states; because OR is symmetric in
        its inputs, the order of the parent axes does not matter. The array is
        built on first access and cached, so later changes to ``parents`` are
        not reflected.
        """
        if self._values is None:
            flat = self._build_kwise_values_array(len(self.variables))
            self._values = flat.reshape(self.cardinality)
        return self._values

    def get_values(self):
        """
        Returns the tabular cpd form of the values.

        With parents, the result is 2-D with one row per child state and one
        column per joint parent configuration, i.e. shape
        ``(2, 2**(len(variables) - 1))``. A parentless factor returns shape
        ``(1, 2)``.
        """
        if len(self.cardinality) == 1:
            # NOTE(review): this is (1, 2) rather than the (2, 1)
            # child-by-configuration layout used below; kept as-is for
            # compatibility -- confirm what callers expect.
            return self.values.reshape(1, np.prod(self.cardinality))
        return self.values.reshape(self.cardinality[0], np.prod(self.cardinality[1:]))

    @staticmethod
    def _build_kwise_values_array(k):
        """Return the flattened OR table for ``k`` binary variables.

        The first ``2**(k-1)`` entries form the child-False row and the rest
        the child-True row; within each row, entry 0 is the configuration in
        which every parent is False. For ``k == 1`` (no parents) a uniform
        prior is returned instead.
        """
        # special case a completely independent factor, and
        # return the uniform prior
        if k == 1:
            return np.array([0.5, 0.5])

        table = np.zeros((2, 2 ** (k - 1)))
        table[0, 0] = 1.0   # all parents False -> child False
        table[1, 1:] = 1.0  # any parent True -> child True
        return table.ravel()