aboutsummaryrefslogtreecommitdiff
path: root/beliefs/factors
diff options
context:
space:
mode:
Diffstat (limited to 'beliefs/factors')
-rw-r--r--beliefs/factors/__init__.py0
-rw-r--r--beliefs/factors/bernoulli_or_cpd.py42
-rw-r--r--beliefs/factors/cpd.py45
3 files changed, 87 insertions, 0 deletions
diff --git a/beliefs/factors/__init__.py b/beliefs/factors/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/beliefs/factors/__init__.py
diff --git a/beliefs/factors/bernoulli_or_cpd.py b/beliefs/factors/bernoulli_or_cpd.py
new file mode 100644
index 0000000..bfb3a95
--- /dev/null
+++ b/beliefs/factors/bernoulli_or_cpd.py
@@ -0,0 +1,42 @@
+import numpy as np
+
+from beliefs.factors.cpd import TabularCPD
+
+
+class BernoulliOrCPD(TabularCPD):
+ """CPD class for a Bernoulli random variable whose relationship to its
+ parents (also Bernoulli random variables) is described by OR logic.
+
+ If at least one of the variable's parents is True, then the variable
+ is True, and False otherwise.
+ """
+ def __init__(self, variable, parents=[]):
+ """
+ Args:
+ variable: int or string
+ parents: optional, list of int and/or strings
+ """
+ super().__init__(variable=variable,
+ variable_card=2,
+ parents=parents,
+ parents_card=[2]*len(parents),
+ values=[])
+ self._values = []
+
+ @property
+ def values(self):
+ if not any(self._values):
+ self._values = self._build_kwise_values_array(len(self.variables))
+ self._values = self._values.reshape(self.cardinality)
+ return self._values
+
+ @staticmethod
+ def _build_kwise_values_array(k):
+ # special case a completely independent factor, and
+ # return the uniform prior
+ if k == 1:
+ return np.array([0.5, 0.5])
+
+ return np.array(
+ [1.,] + [0.]*(2**(k-1)-1) + [0.,] + [1.]*(2**(k-1)-1)
+ )
diff --git a/beliefs/factors/cpd.py b/beliefs/factors/cpd.py
new file mode 100644
index 0000000..a286aaa
--- /dev/null
+++ b/beliefs/factors/cpd.py
@@ -0,0 +1,45 @@
+import numpy as np
+
+
+class TabularCPD:
+ """
+ Defines the conditional probability table for a discrete variable
+ whose parents are also discrete.
+
+ TODO: have this inherit from DiscreteFactor implementing explicit factor methods
+ """
+ def __init__(self, variable, variable_card,
+ parents=[], parents_card=[], values=[]):
+ """
+ Args:
+ variable: int or string
+ variable_card: int
+ parents: optional, list of int and/or strings
+ parents_card: optional, list of int
+ values: optional, 2d list or array
+ """
+ self.variable = variable
+ self.parents = parents
+ self.variables = [variable] + parents
+ self.cardinality = [variable_card] + parents_card
+ self._values = np.array(values)
+
+ @property
+ def values(self):
+ return self._values
+
+ def get_values(self):
+ """
+ Returns the tabular cpd form of the values.
+ """
+ if len(self.cardinality) == 1:
+ return self.values.reshape(1, np.prod(self.cardinality))
+ else:
+ return self.values.reshape(self.cardinality[0], np.prod(self.cardinality[1:]))
+
+ def copy(self):
+ return self.__class__(self.variable,
+ self.cardinality[0],
+ self.parents,
+ self.cardinality[1:],
+ self._values)