aboutsummaryrefslogtreecommitdiff
path: root/beliefs/factors
diff options
context:
space:
mode:
authorCathy Yeh <cathy@driver.xyz>2017-12-11 11:39:04 -0800
committerCathy Yeh <cathy@driver.xyz>2017-12-11 18:50:06 -0800
commit00dfdd7a897b2606ceeabf5323e71d8e80a446fc (patch)
treef74ca47246481f444ca369747da5c55c25455027 /beliefs/factors
parent06626854ca893b44c128ca333fb5623591134746 (diff)
downloadbeliefs-00dfdd7a897b2606ceeabf5323e71d8e80a446fc.tar.gz
beliefs-00dfdd7a897b2606ceeabf5323e71d8e80a446fc.tar.bz2
beliefs-00dfdd7a897b2606ceeabf5323e71d8e80a446fc.zip
PR comments
Diffstat (limited to 'beliefs/factors')
-rw-r--r--beliefs/factors/bernoulli_and_cpd.py7
-rw-r--r--beliefs/factors/bernoulli_or_cpd.py7
2 files changed, 10 insertions, 4 deletions
diff --git a/beliefs/factors/bernoulli_and_cpd.py b/beliefs/factors/bernoulli_and_cpd.py
index fb86135..fdb0c25 100644
--- a/beliefs/factors/bernoulli_and_cpd.py
+++ b/beliefs/factors/bernoulli_and_cpd.py
@@ -21,11 +21,11 @@ class BernoulliAndCPD(TabularCPD):
parents=parents,
parents_card=[2]*len(parents),
values=[])
- self._values = []
+ self._values = None
@property
def values(self):
- if len(self._values) == 0:
+ if self._values is None:
self._values = self._build_kwise_values_array(len(self.variables))
self._values = self._values.reshape(self.cardinality)
return self._values
@@ -37,6 +37,9 @@ class BernoulliAndCPD(TabularCPD):
if k == 1:
return np.array([0.5, 0.5])
+ # values are stored as a row vector using an ordering such that
+ # the right-most variables as defined in [variable].extend(parents)
+ # cycle through their values the fastest.
return np.array(
[1.]*(2**(k-1)-1) + [0.] + [0.,]*(2**(k-1)-1) + [1.]
)
diff --git a/beliefs/factors/bernoulli_or_cpd.py b/beliefs/factors/bernoulli_or_cpd.py
index 162e156..12ee2f6 100644
--- a/beliefs/factors/bernoulli_or_cpd.py
+++ b/beliefs/factors/bernoulli_or_cpd.py
@@ -21,11 +21,11 @@ class BernoulliOrCPD(TabularCPD):
parents=parents,
parents_card=[2]*len(parents),
values=[])
- self._values = []
+ self._values = None
@property
def values(self):
- if len(self._values) == 0:
+ if self._values is None:
self._values = self._build_kwise_values_array(len(self.variables))
self._values = self._values.reshape(self.cardinality)
return self._values
@@ -37,6 +37,9 @@ class BernoulliOrCPD(TabularCPD):
if k == 1:
return np.array([0.5, 0.5])
+ # values are stored as a row vector using an ordering such that
+ # the right-most variables as defined in [variable].extend(parents)
+ # cycle through their values the fastest.
return np.array(
[1.,] + [0.]*(2**(k-1)-1) + [0.,] + [1.]*(2**(k-1)-1)
)