diff options
author | Cathy Yeh <cathy@driver.xyz> | 2017-12-11 18:56:15 -0800 |
---|---|---|
committer | Cathy Yeh <cathy@driver.xyz> | 2017-12-11 18:56:15 -0800 |
commit | 65d822247e30b6e104a8c09d3b930487b9f20a58 (patch) | |
tree | d44b83f001ab352b30e17ab981295c2ee70a4d56 /beliefs/factors | |
parent | 26b43410569044aff46053cae7c68862825dd4ec (diff) | |
parent | 7b5c17c316481edbbd13815390d0b34fb50a03a6 (diff) | |
download | beliefs-e3e0589969b0660d7fa94bf55515d5ba31f5c6e7.tar.gz beliefs-e3e0589969b0660d7fa94bf55515d5ba31f5c6e7.tar.bz2 beliefs-e3e0589969b0660d7fa94bf55515d5ba31f5c6e7.zip |
LGS-173 Merge branch 'bernoulli_and_node'v0.0.3
Diffstat (limited to 'beliefs/factors')
-rw-r--r-- | beliefs/factors/bernoulli_and_cpd.py | 45 | ||||
-rw-r--r-- | beliefs/factors/bernoulli_or_cpd.py | 7 |
2 files changed, 50 insertions, 2 deletions
diff --git a/beliefs/factors/bernoulli_and_cpd.py b/beliefs/factors/bernoulli_and_cpd.py new file mode 100644 index 0000000..fdb0c25 --- /dev/null +++ b/beliefs/factors/bernoulli_and_cpd.py @@ -0,0 +1,45 @@ +import numpy as np + +from beliefs.factors.cpd import TabularCPD + + +class BernoulliAndCPD(TabularCPD): + """CPD class for a Bernoulli random variable whose relationship to its + parents (also Bernoulli random variables) is described by AND logic. + + If all of the variable's parents are True, then the variable + is True, and False otherwise. + """ + def __init__(self, variable, parents=[]): + """ + Args: + variable: int or string + parents: optional, list of int and/or strings + """ + super().__init__(variable=variable, + variable_card=2, + parents=parents, + parents_card=[2]*len(parents), + values=[]) + self._values = None + + @property + def values(self): + if self._values is None: + self._values = self._build_kwise_values_array(len(self.variables)) + self._values = self._values.reshape(self.cardinality) + return self._values + + @staticmethod + def _build_kwise_values_array(k): + # special case a completely independent factor, and + # return the uniform prior + if k == 1: + return np.array([0.5, 0.5]) + + # values are stored as a row vector using an ordering such that + # the right-most variables as defined in [variable].extend(parents) + # cycle through their values the fastest. + return np.array( + [1.]*(2**(k-1)-1) + [0.] + [0.,]*(2**(k-1)-1) + [1.] + ) diff --git a/beliefs/factors/bernoulli_or_cpd.py b/beliefs/factors/bernoulli_or_cpd.py index bfb3a95..12ee2f6 100644 --- a/beliefs/factors/bernoulli_or_cpd.py +++ b/beliefs/factors/bernoulli_or_cpd.py @@ -21,11 +21,11 @@ class BernoulliOrCPD(TabularCPD): parents=parents, parents_card=[2]*len(parents), values=[]) - self._values = [] + self._values = None @property def values(self): - if not any(self._values): + if self._values is None: self._values = self._build_kwise_values_array(len(self.variables)) self._values = self._values.reshape(self.cardinality) return self._values @@ -37,6 +37,9 @@ class BernoulliOrCPD(TabularCPD): if k == 1: return np.array([0.5, 0.5]) + # values are stored as a row vector using an ordering such that + # the right-most variables as defined in [variable].extend(parents) + # cycle through their values the fastest. return np.array( [1.,] + [0.]*(2**(k-1)-1) + [0.,] + [1.]*(2**(k-1)-1) ) |