aboutsummaryrefslogtreecommitdiff
path: root/beliefs/factors
diff options
context:
space:
mode:
authorCathy Yeh <cathy@driver.xyz>2017-11-13 14:42:52 -0800
committerCathy Yeh <cathy@driver.xyz>2017-11-17 13:48:16 -0800
commit77d8b323d4f6e05ca97d9cbef43ac85fd8040d61 (patch)
treebd589afff10efce13b6f017e544958454f3a8ef7 /beliefs/factors
parent6a1b35f5bf122232d058ed0f3ea19c15629c0cbc (diff)
downloadbeliefs-77d8b323d4f6e05ca97d9cbef43ac85fd8040d61.tar.gz
beliefs-77d8b323d4f6e05ca97d9cbef43ac85fd8040d61.tar.bz2
beliefs-77d8b323d4f6e05ca97d9cbef43ac85fd8040d61.zip
copy scripts from lgs branch
Diffstat (limited to 'beliefs/factors')
-rw-r--r--beliefs/factors/BernoulliOrFactor.py42
-rw-r--r--beliefs/factors/__init__.py0
2 files changed, 42 insertions, 0 deletions
diff --git a/beliefs/factors/BernoulliOrFactor.py b/beliefs/factors/BernoulliOrFactor.py
new file mode 100644
index 0000000..4f973ae
--- /dev/null
+++ b/beliefs/factors/BernoulliOrFactor.py
@@ -0,0 +1,42 @@
+import numpy as np
+
+
+class BernoulliOrFactor:
+ """CPD class for a Bernoulli random variable whose relationship to its
+ parents is described by OR logic.
+
+ If at least one of a child's parents is True, then the child is True, and
+ False otherwise."""
+ def __init__(self, child, parents=set()):
+ self.child = child
+ self.parents = set(parents)
+ self.variables = set([child] + list(parents))
+ self.cardinality = [2]*len(self.variables)
+ self._values = None
+
+ @property
+ def values(self):
+ if self._values is None:
+ self._values = self._build_kwise_values_array(len(self.variables))
+ self._values = self._values.reshape(self.cardinality)
+ return self._values
+
+ def get_values(self):
+ """
+ Returns the tabular cpd form of the values.
+ """
+ if len(self.cardinality) == 1:
+ return self.values.reshape(1, np.prod(self.cardinality))
+ else:
+ return self.values.reshape(self.cardinality[0], np.prod(self.cardinality[1:]))
+
+ @staticmethod
+ def _build_kwise_values_array(k):
+ # special case a completely independent factor, and
+ # return the uniform prior
+ if k == 1:
+ return np.array([0.5, 0.5])
+
+ return np.array(
+ [1.,] + [0.]*(2**(k-1)-1) + [0.,] + [1.]*(2**(k-1)-1)
+ )
diff --git a/beliefs/factors/__init__.py b/beliefs/factors/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/beliefs/factors/__init__.py