aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCathy Yeh <cathy@driver.xyz>2017-12-08 16:00:14 -0800
committerCathy Yeh <cathy@driver.xyz>2017-12-08 16:00:14 -0800
commit8cdb00cdb10200e824015ece4a94485e93857352 (patch)
treeed0aa04912fd227825797b0fa6b89bbab76d5eaa
parent26b43410569044aff46053cae7c68862825dd4ec (diff)
downloadbeliefs-8cdb00cdb10200e824015ece4a94485e93857352.tar.gz
beliefs-8cdb00cdb10200e824015ece4a94485e93857352.tar.bz2
beliefs-8cdb00cdb10200e824015ece4a94485e93857352.zip
bernoulli AND cpd
-rw-r--r--beliefs/factors/bernoulli_and_cpd.py42
1 files changed, 42 insertions, 0 deletions
diff --git a/beliefs/factors/bernoulli_and_cpd.py b/beliefs/factors/bernoulli_and_cpd.py
new file mode 100644
index 0000000..fb86135
--- /dev/null
+++ b/beliefs/factors/bernoulli_and_cpd.py
@@ -0,0 +1,42 @@
+import numpy as np
+
+from beliefs.factors.cpd import TabularCPD
+
+
+class BernoulliAndCPD(TabularCPD):
+ """CPD class for a Bernoulli random variable whose relationship to its
+ parents (also Bernoulli random variables) is described by AND logic.
+
+ If all of the variable's parents are True, then the variable
+ is True, and False otherwise.
+ """
+ def __init__(self, variable, parents=[]):
+ """
+ Args:
+ variable: int or string
+ parents: optional, list of int and/or strings
+ """
+ super().__init__(variable=variable,
+ variable_card=2,
+ parents=parents,
+ parents_card=[2]*len(parents),
+ values=[])
+ self._values = []
+
+ @property
+ def values(self):
+ if len(self._values) == 0:
+ self._values = self._build_kwise_values_array(len(self.variables))
+ self._values = self._values.reshape(self.cardinality)
+ return self._values
+
+ @staticmethod
+ def _build_kwise_values_array(k):
+ # special case a completely independent factor, and
+ # return the uniform prior
+ if k == 1:
+ return np.array([0.5, 0.5])
+
+ return np.array(
+ [1.]*(2**(k-1)-1) + [0.] + [0.,]*(2**(k-1)-1) + [1.]
+ )