aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCathy Yeh <cathy@driver.xyz>2017-12-11 11:39:04 -0800
committerCathy Yeh <cathy@driver.xyz>2017-12-11 18:50:06 -0800
commit00dfdd7a897b2606ceeabf5323e71d8e80a446fc (patch)
treef74ca47246481f444ca369747da5c55c25455027
parent06626854ca893b44c128ca333fb5623591134746 (diff)
downloadbeliefs-00dfdd7a897b2606ceeabf5323e71d8e80a446fc.tar.gz
beliefs-00dfdd7a897b2606ceeabf5323e71d8e80a446fc.tar.bz2
beliefs-00dfdd7a897b2606ceeabf5323e71d8e80a446fc.zip
PR comments
-rw-r--r--beliefs/factors/bernoulli_and_cpd.py7
-rw-r--r--beliefs/factors/bernoulli_or_cpd.py7
-rw-r--r--beliefs/models/belief_update_node_model.py14
3 files changed, 19 insertions, 9 deletions
diff --git a/beliefs/factors/bernoulli_and_cpd.py b/beliefs/factors/bernoulli_and_cpd.py
index fb86135..fdb0c25 100644
--- a/beliefs/factors/bernoulli_and_cpd.py
+++ b/beliefs/factors/bernoulli_and_cpd.py
@@ -21,11 +21,11 @@ class BernoulliAndCPD(TabularCPD):
parents=parents,
parents_card=[2]*len(parents),
values=[])
- self._values = []
+ self._values = None
@property
def values(self):
- if len(self._values) == 0:
+ if self._values is None:
self._values = self._build_kwise_values_array(len(self.variables))
self._values = self._values.reshape(self.cardinality)
return self._values
@@ -37,6 +37,9 @@ class BernoulliAndCPD(TabularCPD):
if k == 1:
return np.array([0.5, 0.5])
+ # values are stored as a row vector using an ordering such that
+ # the right-most variables as defined in [variable].extend(parents)
+ # cycle through their values the fastest.
return np.array(
[1.]*(2**(k-1)-1) + [0.] + [0.,]*(2**(k-1)-1) + [1.]
)
diff --git a/beliefs/factors/bernoulli_or_cpd.py b/beliefs/factors/bernoulli_or_cpd.py
index 162e156..12ee2f6 100644
--- a/beliefs/factors/bernoulli_or_cpd.py
+++ b/beliefs/factors/bernoulli_or_cpd.py
@@ -21,11 +21,11 @@ class BernoulliOrCPD(TabularCPD):
parents=parents,
parents_card=[2]*len(parents),
values=[])
- self._values = []
+ self._values = None
@property
def values(self):
- if len(self._values) == 0:
+ if self._values is None:
self._values = self._build_kwise_values_array(len(self.variables))
self._values = self._values.reshape(self.cardinality)
return self._values
@@ -37,6 +37,9 @@ class BernoulliOrCPD(TabularCPD):
if k == 1:
return np.array([0.5, 0.5])
+ # values are stored as a row vector using an ordering such that
+ # the right-most variables as defined in [variable].extend(parents)
+ # cycle through their values the fastest.
return np.array(
[1.,] + [0.]*(2**(k-1)-1) + [0.,] + [1.]*(2**(k-1)-1)
)
diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py
index 4747530..1c3ba6e 100644
--- a/beliefs/models/belief_update_node_model.py
+++ b/beliefs/models/belief_update_node_model.py
@@ -213,7 +213,7 @@ class Node:
raise NotImplementedError
def compute_lambda_agg(self):
- if not self.children:
+ if len(self.children) == 0:
return self.lambda_agg
else:
lambda_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA)
@@ -290,11 +290,13 @@ class BernoulliOrNode(Node):
cpd=BernoulliOrCPD(label_id, parents))
def compute_pi_agg(self):
- if not self.parents:
+ if len(self.parents) == 0:
self.pi_agg = self.cpd.values
else:
pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
parents_p0 = [p[0] for p in pi_msg_values]
+ # Parents are Bernoulli variables with pi_msg_values (surrogate prior probabilities)
+ # of p = [P(False), P(True)]
p_0 = reduce(lambda x, y: x*y, parents_p0)
p_1 = 1 - p_0
self.pi_agg = np.array([p_0, p_1])
@@ -306,7 +308,7 @@ class BernoulliOrNode(Node):
else:
# TODO: cleanup this validation
_ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
- p0_excluding_k = [msg[0] for par_id, msg in self.pi_received_msgs.items() if par_id != parent_k]
+ p0_excluding_k = [p[0] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k]
p0_product = reduce(lambda x, y: x*y, p0_excluding_k, 1)
lambda_0 = self.lambda_agg[1] + (self.lambda_agg[0] - self.lambda_agg[1])*p0_product
lambda_1 = self.lambda_agg[1]
@@ -328,11 +330,13 @@ class BernoulliAndNode(Node):
cpd=BernoulliAndCPD(label_id, parents))
def compute_pi_agg(self):
- if not self.parents:
+ if len(self.parents) == 0:
self.pi_agg = self.cpd.values
else:
pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
parents_p1 = [p[1] for p in pi_msg_values]
+ # Parents are Bernoulli variables with pi_msg_values (surrogate prior probabilities)
+ # of p = [P(False), P(True)]
p_1 = reduce(lambda x, y: x*y, parents_p1)
p_0 = 1 - p_1
self.pi_agg = np.array([p_0, p_1])
@@ -344,7 +348,7 @@ class BernoulliAndNode(Node):
else:
# TODO: cleanup this validation
_ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
- p1_excluding_k = [msg[1] for par_id, msg in self.pi_received_msgs.items() if par_id != parent_k]
+ p1_excluding_k = [p[1] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k]
p1_product = reduce(lambda x, y: x*y, p1_excluding_k, 1)
lambda_0 = self.lambda_agg[0]
lambda_1 = self.lambda_agg[0] + (self.lambda_agg[1] - self.lambda_agg[0])*p1_product