blob: 4f973ae20b86ae70a29302afd296288f8ee6cc37 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
|
import numpy as np
class BernoulliOrFactor:
    """CPD class for a Bernoulli random variable whose relationship to its
    parents is described by OR logic.

    If at least one of a child's parents is True, then the child is True, and
    False otherwise. A child with no parents is given a uniform prior instead.

    Attributes:
        child: The child variable.
        parents: Set of parent variables.
        variables: Unordered set containing the child and its parents.
        cardinality: List of 2s, one entry per element of ``variables``.

    Layout of ``values``: axis 0 is the child (index 0 = False, 1 = True) and
    the remaining axes are the parents. The OR table is symmetric in its
    parent axes, so their relative order does not matter; only the child has
    to come first.
    """

    def __init__(self, child, parents=frozenset()):
        """Create an OR-CPD for ``child`` given ``parents``.

        Args:
            child: Variable the CPD is defined over.
            parents: Iterable of parent variables. Any iterable is accepted,
                including one-shot iterators such as generators; it is
                iterated exactly once.
        """
        self.child = child
        self.parents = set(parents)
        # Build the scope from the already-materialized ``self.parents``
        # instead of iterating ``parents`` a second time: a generator would
        # already be exhausted by the ``set()`` call above.
        # NOTE(review): if ``child`` also appears among the parents it is
        # de-duplicated here, removing one axis from the table -- confirm
        # whether such input should be rejected instead.
        self.variables = {child} | self.parents
        self.cardinality = [2] * len(self.variables)
        self._values = None  # filled lazily by the ``values`` property

    @property
    def values(self):
        """The CPD as an ndarray of shape ``tuple(self.cardinality)``.

        Built on first access and cached on the instance afterwards.
        """
        if self._values is None:
            flat = self._build_kwise_values_array(len(self.variables))
            self._values = flat.reshape(self.cardinality)
        return self._values

    def get_values(self):
        """
        Returns the tabular cpd form of the values.

        With k > 1 variables the result has shape ``(2, 2**(k-1))``: one row
        per child state and one column per joint parent assignment. For a
        parentless child the result has shape ``(1, 2)``.
        """
        if len(self.cardinality) == 1:
            # The general branch would need ``np.prod([])``, which is the
            # float 1.0 and is not accepted by ``reshape``.
            # NOTE(review): this yields a single row ``(1, 2)``, whereas the
            # general branch puts child states on rows -- confirm callers
            # rely on this orientation.
            return self.values.reshape(1, np.prod(self.cardinality))
        return self.values.reshape(
            self.cardinality[0], np.prod(self.cardinality[1:])
        )

    @staticmethod
    def _build_kwise_values_array(k):
        """Return the flat OR table for ``k`` binary variables, child first.

        The result holds ``2**k`` floats in C order over
        ``(child, parent_1, ..., parent_{k-1})``.
        """
        # Special case a completely independent factor, and return the
        # uniform prior.
        if k == 1:
            return np.array([0.5, 0.5])
        half = 2 ** (k - 1)
        # Child-False half: probability 1 only for the all-parents-False
        # assignment (the first entry), 0 otherwise.
        # Child-True half: the exact complement of the child-False half.
        return np.array(
            [1.0] + [0.0] * (half - 1) + [0.0] + [1.0] * (half - 1)
        )
|